# Qwen Image Encode Prompt
# (Updated to encode prompts with the multimodal processor, accepting an optional image input.)
def _get_qwen_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    image: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    """Encode text (and an optional image) with the Qwen-VL text encoder.

    Args:
        prompt: A prompt or list of prompts to embed.
        image: Optional image input forwarded to the multimodal processor;
            presumably a PIL image / tensor the processor accepts — TODO confirm.
        device: Target device; defaults to the pipeline's execution device.
        dtype: Target dtype; defaults to the text encoder's dtype.

    Returns:
        Tuple of ``(prompt_embeds, encoder_attention_mask)`` where
        ``prompt_embeds`` is padded to the longest sequence in the batch.
    """
    device = device or self._execution_device
    dtype = dtype or self.text_encoder.dtype

    prompt = [prompt] if isinstance(prompt, str) else prompt

    # Wrap each prompt in the chat template; the first `drop_idx` tokens
    # belong to the template preamble and are stripped from the output below.
    template = self.prompt_template_encode
    drop_idx = self.prompt_template_encode_start_idx
    txt = [template.format(e) for e in prompt]

    # The processor tokenizes text and preprocesses the image together so
    # that vision placeholder tokens line up with pixel patches.
    model_inputs = self.processor(
        text=txt,
        images=image,
        padding=True,
        return_tensors="pt",
    ).to(device)

    outputs = self.text_encoder(
        input_ids=model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        pixel_values=model_inputs.pixel_values,
        image_grid_thw=model_inputs.image_grid_thw,
        output_hidden_states=True,
    )

    # Use the last hidden layer; split per-sample by attention mask, then
    # drop the template preamble tokens from each sequence.
    hidden_states = outputs.hidden_states[-1]
    split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
    split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
    # Re-pad to the longest remaining sequence and rebuild the mask.
    max_seq_len = max(e.size(0) for e in split_hidden_states)
    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
    )
    encoder_attention_mask = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
    )

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    return prompt_embeds, encoder_attention_mask


def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    image: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_mask: Optional[torch.Tensor] = None,
    max_sequence_length: int = 1024,
):
    r"""Encode a prompt (and optional image) into embeddings for the transformer.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        image (`torch.Tensor`, *optional*):
            image to be encoded alongside the prompt
        device: (`torch.device`):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
            provided, text embeddings will be generated from `prompt` input argument.
        prompt_embeds_mask (`torch.Tensor`, *optional*):
            Attention mask matching `prompt_embeds`; required when `prompt_embeds` is supplied.
        max_sequence_length (`int`):
            Maximum sequence length; embeddings and mask are truncated to this length.

    Returns:
        Tuple of `(prompt_embeds, prompt_embeds_mask)` expanded to
        `batch_size * num_images_per_prompt` rows.
    """
    device = device or self._execution_device

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

    if prompt_embeds is None:
        prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)

    # Enforce the sequence-length budget before duplication.
    prompt_embeds = prompt_embeds[:, :max_sequence_length]
    prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]

    # Duplicate embeddings (and mask) for each image generated per prompt,
    # using MPS-friendly repeat + view instead of repeat_interleave.
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
    prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

    return prompt_embeds, prompt_embeds_mask