Qwen Image Encode Prompt

--- prompt encoding (text only)
+++ prompt encoding (text + image)

 def _get_qwen_prompt_embeds(
     self,
     prompt: Union[str, List[str]] = None,
+    image: Optional[torch.Tensor] = None,
     device: Optional[torch.device] = None,
     dtype: Optional[torch.dtype] = None,
 ):
     device = device or self._execution_device
     dtype = dtype or self.text_encoder.dtype

     prompt = [prompt] if isinstance(prompt, str) else prompt

     template = self.prompt_template_encode
     drop_idx = self.prompt_template_encode_start_idx
     txt = [template.format(e) for e in prompt]
-    txt_tokens = self.tokenizer(
-        txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt"
+
+    model_inputs = self.processor(
+        text=txt,
+        images=image,
+        padding=True,
+        return_tensors="pt",
     ).to(device)
-    encoder_hidden_states = self.text_encoder(
-        input_ids=txt_tokens.input_ids,
-        attention_mask=txt_tokens.attention_mask,
+
+    outputs = self.text_encoder(
+        input_ids=model_inputs.input_ids,
+        attention_mask=model_inputs.attention_mask,
+        pixel_values=model_inputs.pixel_values,
+        image_grid_thw=model_inputs.image_grid_thw,
         output_hidden_states=True,
     )
-    hidden_states = encoder_hidden_states.hidden_states[-1]
-    split_hidden_states = self._extract_masked_hidden(hidden_states, txt_tokens.attention_mask)
+
+    hidden_states = outputs.hidden_states[-1]
+    split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
     split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
     attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
     max_seq_len = max([e.size(0) for e in split_hidden_states])
     prompt_embeds = torch.stack(
         [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
     )
     encoder_attention_mask = torch.stack(
         [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
     )

     prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

     return prompt_embeds, encoder_attention_mask
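
In both variants, the encoder's last hidden state is masked per sample, the first drop_idx template tokens are dropped, and the remaining sequences are re-padded to a common length with a fresh attention mask. Below is a minimal sketch of that trim-and-repad step with dummy tensors; the helper _extract_masked_hidden is not shown in the diff, so the inline masking is an assumed equivalent, and the shapes and drop_idx value are illustrative only.

import torch

drop_idx = 3                                # illustrative template-prefix length
hidden = torch.randn(2, 10, 8)              # (batch, seq_len, dim) from the text encoder
mask = torch.tensor([[1] * 10, [1] * 6 + [0] * 4])

# keep only valid tokens per sample, then drop the template prefix
split = [h[m.bool()][drop_idx:] for h, m in zip(hidden, mask)]

# re-pad to the longest remaining sequence and rebuild the attention mask
max_len = max(s.size(0) for s in split)
prompt_embeds = torch.stack(
    [torch.cat([s, s.new_zeros(max_len - s.size(0), s.size(1))]) for s in split]
)
encoder_attention_mask = torch.stack(
    [torch.cat([torch.ones(s.size(0), dtype=torch.long), torch.zeros(max_len - s.size(0), dtype=torch.long)]) for s in split]
)

print(prompt_embeds.shape, encoder_attention_mask.shape)  # torch.Size([2, 7, 8]) torch.Size([2, 7])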


 def encode_prompt(
     self,
     prompt: Union[str, List[str]],
+    image: Optional[torch.Tensor] = None,
     device: Optional[torch.device] = None,
     num_images_per_prompt: int = 1,
     prompt_embeds: Optional[torch.Tensor] = None,
     prompt_embeds_mask: Optional[torch.Tensor] = None,
     max_sequence_length: int = 1024,
 ):
     r"""

     Args:
         prompt (`str` or `List[str]`, *optional*):
             prompt to be encoded
+        image (`torch.Tensor`, *optional*):
+            image to be encoded
         device: (`torch.device`):
             torch device
         num_images_per_prompt (`int`):
             number of images that should be generated per prompt
         prompt_embeds (`torch.Tensor`, *optional*):
             Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
             provided, text embeddings will be generated from `prompt` input argument.
     """
     device = device or self._execution_device

     prompt = [prompt] if isinstance(prompt, str) else prompt
     batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

     if prompt_embeds is None:
-        prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, device)
+        prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)
+
+    prompt_embeds = prompt_embeds[:, :max_sequence_length]
+    prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]

     _, seq_len, _ = prompt_embeds.shape
     prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
     prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
     prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
     prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

     return prompt_embeds, prompt_embeds_mask
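
Read as a whole, the + variant threads an optional image through encode_prompt into the multimodal processor, so the returned prompt_embeds are conditioned on both the instruction text and the reference image, and are additionally truncated to max_sequence_length. A hedged usage sketch follows, assuming these methods live on a diffusers-style pipeline such as QwenImageEditPipeline; the checkpoint id, dtype, and PIL image input are assumptions, not taken from the diff.

import torch
from PIL import Image
from diffusers import QwenImageEditPipeline  # assumed host pipeline for the + variant

pipe = QwenImageEditPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit",                  # assumed checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

image = Image.open("input.png").convert("RGB")

# Encode the edit instruction together with the reference image; the mask marks
# real tokens vs. padding in the returned embeddings.
prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
    prompt="Turn the daytime sky into a sunset",
    image=image,
    num_images_per_prompt=1,
    max_sequence_length=1024,
)
print(prompt_embeds.shape, prompt_embeds_mask.shape)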