# Qwen Image Encode Prompt
# NOTE(review): this file was recovered from a Diffchecker diff-view paste; the
# viewer's UI chrome (deletion/addition counters, "Diffchecker Pro" banners) has
# been condensed into this header. The code below is the reconstructed,
# image-aware (processor-based) side of that diff.
def _get_qwen_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    image: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Encode text prompt(s) (optionally conditioned on an image) with the Qwen VL text encoder.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt(s) to encode. A single string is promoted to a one-element list.
        image (`torch.Tensor`, *optional*):
            Image(s) forwarded to the multimodal processor; may be `None` for text-only encoding.
            NOTE(review): exact expected shape/type is determined by `self.processor` — confirm
            against the pipeline's processor class.
        device (`torch.device`, *optional*):
            Target device; defaults to the pipeline's execution device.
        dtype (`torch.dtype`, *optional*):
            Target dtype for the returned embeddings; defaults to the text encoder's dtype.

    Returns:
        Tuple of `(prompt_embeds, encoder_attention_mask)`, both padded to the longest
        per-prompt sequence in the batch.
    """
    device = device or self._execution_device
    dtype = dtype or self.text_encoder.dtype

    prompt = [prompt] if isinstance(prompt, str) else prompt

    # Wrap each raw prompt in the chat template; the first `drop_idx` tokens belong
    # to the template preamble and are stripped from the encoder output below.
    template = self.prompt_template_encode
    drop_idx = self.prompt_template_encode_start_idx
    txt = [template.format(e) for e in prompt]

    # The multimodal processor tokenizes the text and prepares image patches in one call.
    model_inputs = self.processor(
        text=txt,
        images=image,
        padding=True,
        return_tensors="pt",
    ).to(device)

    outputs = self.text_encoder(
        input_ids=model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        pixel_values=model_inputs.pixel_values,
        image_grid_thw=model_inputs.image_grid_thw,
        output_hidden_states=True,
    )

    # Use the last hidden layer; split per prompt via the attention mask, then drop
    # the template-preamble tokens from each sequence.
    hidden_states = outputs.hidden_states[-1]
    split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
    split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]

    # Right-pad every sequence (and its mask) with zeros to the batch maximum length.
    max_seq_len = max(e.size(0) for e in split_hidden_states)
    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
    )
    encoder_attention_mask = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
    )

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    return prompt_embeds, encoder_attention_mask


def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    image: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_mask: Optional[torch.Tensor] = None,
    max_sequence_length: int = 1024,
):
    r"""
    Encode the prompt (and optional conditioning image) into embeddings for the pipeline.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        image (`torch.Tensor`, *optional*):
            image to be encoded alongside the prompt
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        prompt_embeds_mask (`torch.Tensor`, *optional*):
            Attention mask matching `prompt_embeds`; generated together with them when not supplied.
        max_sequence_length (`int`):
            Maximum token length kept from the encoder output; longer sequences are truncated.

    Returns:
        Tuple of `(prompt_embeds, prompt_embeds_mask)`, each repeated `num_images_per_prompt`
        times along the batch dimension.
    """
    device = device or self._execution_device

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

    if prompt_embeds is None:
        # Pass `image` through so image-conditioned encoding works; the helper
        # accepts `image=None` for the text-only case.
        prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)

    # Truncate both embeddings and mask to the configured maximum sequence length.
    prompt_embeds = prompt_embeds[:, :max_sequence_length]
    prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]

    # Duplicate per-prompt embeddings for each requested image, then flatten the
    # (batch, num_images) dimensions into a single batch dimension.
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
    prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

    return prompt_embeds, prompt_embeds_mask