Qwen Image Encode Prompt
The prompt-encoding helpers from Diffusers' QwenImage Edit pipeline: `_get_qwen_prompt_embeds` runs the text (and optional reference image) through the pipeline's multimodal text encoder and returns padded embeddings plus an attention mask, and `encode_prompt` wraps it with truncation and per-prompt duplication.
from typing import List, Optional, Union

import torch


def _get_qwen_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    image: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    device = device or self._execution_device
    dtype = dtype or self.text_encoder.dtype

    prompt = [prompt] if isinstance(prompt, str) else prompt

    # Wrap each prompt in the chat template expected by the text encoder;
    # drop_idx marks how many template prefix tokens to strip from the output.
    template = self.prompt_template_encode
    drop_idx = self.prompt_template_encode_start_idx
    txt = [template.format(e) for e in prompt]

    # The processor tokenizes the text and preprocesses the optional image jointly.
    model_inputs = self.processor(
        text=txt,
        images=image,
        padding=True,
        return_tensors="pt",
    ).to(device)

    outputs = self.text_encoder(
        input_ids=model_inputs.input_ids,
        attention_mask=model_inputs.attention_mask,
        pixel_values=model_inputs.pixel_values,
        image_grid_thw=model_inputs.image_grid_thw,
        output_hidden_states=True,
    )

    # Take the last hidden layer, split off the padding per sample,
    # and drop the template prefix tokens.
    hidden_states = outputs.hidden_states[-1]
    split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
    split_hidden_states = [e[drop_idx:] for e in split_hidden_states]

    # Re-pad the per-sample sequences to a common length and rebuild the mask.
    attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
    max_seq_len = max([e.size(0) for e in split_hidden_states])
    prompt_embeds = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states]
    )
    encoder_attention_mask = torch.stack(
        [torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list]
    )

    prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

    return prompt_embeds, encoder_attention_mask
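
The method above calls a helper, `_extract_masked_hidden`, that is not shown in this section. A minimal sketch consistent with how it is used (splitting the padded batch into one unpadded sequence per sample, so the template prefix can be dropped sample by sample) could look like this:

# Sketch of the unshown helper, inferred from its call site above: split a
# padded (batch, seq, dim) tensor into per-sample tensors of valid tokens only.
def _extract_masked_hidden(self, hidden_states: torch.Tensor, mask: torch.Tensor):
    bool_mask = mask.bool()
    valid_lengths = bool_mask.sum(dim=1)  # number of real tokens per sample
    selected = hidden_states[bool_mask]   # (total_valid_tokens, dim)
    return torch.split(selected, valid_lengths.tolist(), dim=0)
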
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    image: Optional[torch.Tensor] = None,
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.Tensor] = None,
    prompt_embeds_mask: Optional[torch.Tensor] = None,
    max_sequence_length: int = 1024,
):
    r"""
    Args:
        prompt (`str` or `List[str]`, *optional*):
            prompt to be encoded
        image (`torch.Tensor`, *optional*):
            image to be encoded
        device (`torch.device`, *optional*):
            torch device
        num_images_per_prompt (`int`):
            number of images that should be generated per prompt
        prompt_embeds (`torch.Tensor`, *optional*):
            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If
            not provided, text embeddings will be generated from the `prompt` input argument.
        prompt_embeds_mask (`torch.Tensor`, *optional*):
            Attention mask matching `prompt_embeds`. Should be provided together with `prompt_embeds`.
        max_sequence_length (`int`):
            maximum sequence length the embeddings are truncated to
    """
    device = device or self._execution_device

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt) if prompt_embeds is None else prompt_embeds.shape[0]

    if prompt_embeds is None:
        prompt_embeds, prompt_embeds_mask = self._get_qwen_prompt_embeds(prompt, image, device)

    # Truncate to the maximum sequence length before duplication.
    prompt_embeds = prompt_embeds[:, :max_sequence_length]
    prompt_embeds_mask = prompt_embeds_mask[:, :max_sequence_length]

    # Duplicate embeddings and mask once per image generated for each prompt.
    _, seq_len, _ = prompt_embeds.shape
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
    prompt_embeds_mask = prompt_embeds_mask.repeat(1, num_images_per_prompt, 1)
    prompt_embeds_mask = prompt_embeds_mask.view(batch_size * num_images_per_prompt, seq_len)

    return prompt_embeds, prompt_embeds_mask
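
For context, a usage sketch, assuming these methods belong to Diffusers' `QwenImageEditPipeline` (the surrounding class is not shown above); the checkpoint id and input file are illustrative:

import torch
from diffusers import QwenImageEditPipeline
from diffusers.utils import load_image

pipe = QwenImageEditPipeline.from_pretrained(
    "Qwen/Qwen-Image-Edit", torch_dtype=torch.bfloat16  # assumed checkpoint id
).to("cuda")

image = load_image("input.png")  # hypothetical local file; the processor also
                                 # accepts PIL images despite the tensor annotation
prompt_embeds, prompt_embeds_mask = pipe.encode_prompt(
    prompt="replace the sky with a vivid sunset",
    image=image,
    num_images_per_prompt=2,
)
# prompt_embeds:      (batch * num_images_per_prompt, seq_len, hidden_dim)
# prompt_embeds_mask: (batch * num_images_per_prompt, seq_len)

Note that duplication happens along the batch dimension, so downstream denoising can treat each (prompt, image) copy as an independent sample.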