# Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://github.com/Tencent/Tencent-Hunyuan-Large/blob/main/License.docx
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
""" PyTorch HunYuan model."""
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
from torch import Tensor
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import (
    AttentionMaskConverter,
    _prepare_4d_attention_mask,
    _prepare_4d_causal_attention_mask,
    _prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast
)
from transformers.modeling_utils import PreTrainedModel
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
from transformers.utils.import_utils import is_torch_fx_available
from .configuration_hunyuan import HunYuanConfig
if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
    if not is_torch_greater_or_equal_than_1_13:
        import torch.fx
    _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "HunYuanConfig"
def topkgating(logits: Tensor, topk: int):
    logits = logits.float()
    gates = F.softmax(logits, dim=1)
    # expert_capacity = topk * gates.shape[0]
    expert_capacity = max(topk, topk * gates.shape[0] // gates.shape[1])
    num_experts = int(gates.shape[1])
    # Top-k router probability and corresponding expert indices for each token.
    # Shape: [tokens_per_group, num_selected_experts].
    expert_gate, expert_index = torch.topk(gates, topk)
    expert_mask = F.one_hot(expert_index, num_experts)
    # For a given token, determine if it was routed to a given expert.
    # Shape: [tokens_per_group, num_experts]
    expert_mask_aux = expert_mask.max(dim=-2)[0]
    tokens_per_group_and_expert = torch.mean(expert_mask_aux.float(), dim=-2)
    router_prob_per_group_and_expert = torch.mean(gates.float(), dim=-2)
    l_aux = num_experts**2 * torch.mean(tokens_per_group_and_expert * router_prob_per_group_and_expert)
    gates_s = torch.clamp(
        torch.matmul(expert_mask.float(), gates.unsqueeze(-1)).sum(dim=1), min=torch.finfo(gates.dtype).eps
    )
    router_probs = gates / gates_s
    # Make num_selected_experts the leading axis to ensure that top-1 choices
    # have priority over top-2 choices, which have priority over top-3 choices,
    # etc.
    expert_index = torch.transpose(expert_index, 0, 1)
    # Shape: [num_selected_experts * tokens_per_group]
    expert_index = expert_index.reshape(-1)
    # Create mask out of indices.
    # Shape: [tokens_per_group * num_selected_experts, num_experts].
    expert_mask = F.one_hot(expert_index, num_experts).to(torch.int32)
    exp_counts = torch.sum(expert_mask, dim=0).detach()
    # Experts have a fixed capacity that we cannot exceed. A token's priority
    # within the expert's buffer is given by the masked, cumulative capacity of
    # its target expert.
    # Shape: [tokens_per_group * num_selected_experts, num_experts].
    token_priority = torch.cumsum(expert_mask, dim=0) * expert_mask - 1
    # Shape: [num_selected_experts, tokens_per_group, num_experts].
    token_priority = token_priority.reshape((topk, -1, num_experts))
    # Shape: [tokens_per_group, num_selected_experts, num_experts].
    token_priority = torch.transpose(token_priority, 0, 1)
    # For each token, across all selected experts, select the only non-negative
    # (unmasked) priority. Now, for group G routing to expert E, token T has
    # non-negative priority (i.e. token_priority[G,T,E] >= 0) if and only if E
    # is its targeted expert.
    # Shape: [tokens_per_group, num_experts].
    token_priority = torch.max(token_priority, dim=1)[0]
    # Token T can only be routed to expert E if its priority is positive and
    # less than the expert capacity. One-hot matrix will ignore indices outside
    # the range [0, expert_capacity).
    # Shape: [tokens_per_group, num_experts, expert_capacity].
    valid_mask = torch.logical_and(token_priority >= 0, token_priority < expert_capacity)
    token_priority = torch.masked_fill(token_priority, ~valid_mask, 0)
    dispatch_mask = F.one_hot(token_priority, expert_capacity).to(torch.bool)
    valid_mask = valid_mask.unsqueeze(-1).expand(-1, -1, expert_capacity)
    dispatch_mask = torch.masked_fill(dispatch_mask, ~valid_mask, 0)
    # The combine array will be used for combining expert outputs, scaled by the
    # router probabilities. Shape: [num_groups, tokens_per_group, num_experts,
    # expert_capacity].
    combine_weights = torch.einsum("...te,...tec->...tec", router_probs, dispatch_mask)
    exp_counts_capacity = torch.sum(dispatch_mask)
    exp_capacity_rate = exp_counts_capacity / (logits.shape[0] * topk)
    return [l_aux, exp_capacity_rate], combine_weights, dispatch_mask, exp_counts
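
# Illustrative usage sketch (hedged, not executed; the toy sizes below are assumptions, not model defaults):
#     router_logits = torch.randn(8, 4)   # 8 tokens routed over 4 experts
#     (l_aux, capacity_rate), combine_weights, dispatch_mask, exp_counts = topkgating(router_logits, topk=2)
#     # expert_capacity = max(2, 2 * 8 // 4) = 4, so:
#     # combine_weights.shape == dispatch_mask.shape == (8, 4, 4)   [tokens, experts, capacity]
#     # exp_counts.shape == (4,)  -- routing assignments per expert before capacity is enforced
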
def top1gating(logits: Tensor, random_routing_dropped_token: bool = False):
    """Implements Top1Gating on logits."""
    # everything is in fp32 in this function
    logits = logits.float()
    gates = F.softmax(logits, dim=1)
    capacity = gates.shape[0]
    # Create a mask for each token's top-1 expert
    # noisy gating
    indices1_s = torch.argmax(gates, dim=1)
    num_experts = int(gates.shape[1])
    mask1 = F.one_hot(indices1_s, num_classes=num_experts)
    # gating decisions
    # exp_counts = torch.sum(mask1, dim=0).detach().to('cpu')
    exp_counts = torch.sum(mask1, dim=0).detach()
    # Compute l_aux
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.sum(me * ce) * num_experts
    mask1_rand = mask1
    top_idx = torch.topk(mask1_rand, k=capacity, dim=0)[1]
    new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1)
    mask1 = new_mask1
    mask1_bk = mask1
    if random_routing_dropped_token:
        not_full = capacity - new_mask1.sum(dim=0)
        sorted_notfull, indices_notfull = torch.sort(not_full, descending=True)
        sorted_notfull = sorted_notfull.to(torch.int64)
        not_full_experts_ids = torch.repeat_interleave(indices_notfull, sorted_notfull)
        shuffle_not_full_ids = torch.randperm(not_full_experts_ids.shape[0])
        not_full_experts_ids = not_full_experts_ids[shuffle_not_full_ids]
        indices1_s_after_drop = torch.argmax(new_mask1, dim=1)
        # get drop idx
        drop_mask = 1 - new_mask1.sum(dim=1)
        drop_mask = drop_mask.bool()
        drop_idx = drop_mask.nonzero().view(-1)
        drop_num = drop_mask.sum().to(torch.int64)
        indices1_s_after_drop.scatter_(0, drop_idx, not_full_experts_ids[:drop_num])
        nodrop_mask1 = F.one_hot(indices1_s_after_drop, num_classes=num_experts)
        mask1 = nodrop_mask1
    # Compute locations in capacity buffer
    locations1 = torch.cumsum(mask1, dim=0) - 1
    # Store the capacity location for each token
    locations1_s = torch.sum(locations1 * mask1, dim=1)
    # Normalize gate probabilities
    mask1_float = mask1.float()
    gates = gates * mask1_float
    locations1_sc = F.one_hot(locations1_s, num_classes=capacity).float()  # one hot to float
    combine_weights = torch.einsum("se,sc->sec", gates, locations1_sc)
    dispatch_mask = combine_weights.bool()
    exp_counts_capacity = torch.sum(mask1_bk)
    exp_capacity_rate = exp_counts_capacity / (logits.shape[0])
    return [l_aux, exp_capacity_rate], combine_weights, dispatch_mask, exp_counts
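
# Hedged note: top1gating routes each token to its argmax expert and sets the capacity to the number of
# tokens, so each expert has one slot per token; the `random_routing_dropped_token` branch re-routes any
# dropped tokens to under-filled experts. Toy shapes (not executed):
#     (l_aux, capacity_rate), combine_weights, dispatch_mask, exp_counts = top1gating(torch.randn(8, 4))
#     # combine_weights.shape == dispatch_mask.shape == (8, 4, 8)   [tokens, experts, capacity == tokens]
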
def _get_unpad_data(attention_mask):
    seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    return (
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )
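
# Illustrative example (hedged, not executed) of the unpadding metadata consumed by the varlen
# flash-attention path, for a padding mask holding two sequences of lengths 3 and 1:
#     attention_mask = torch.tensor([[1, 1, 1, 0], [1, 0, 0, 0]])
#     indices, cu_seqlens, max_seqlen_in_batch = _get_unpad_data(attention_mask)
#     # indices            -> positions of real tokens in the flattened batch: tensor([0, 1, 2, 4])
#     # cu_seqlens         -> cumulative sequence lengths with a leading 0: tensor([0, 3, 4], dtype=torch.int32)
#     # max_seqlen_in_batch -> 3
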
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    warnings.warn(
        "Calling `transformers.models.llama.modeling_llama._prepare_4d_attention_mask` is deprecated and will be "
        "removed in v4.37. Use `transformers.modeling_attn_mask_utils._prepare_4d_attention_mask"
    )
    return _prepare_4d_attention_mask(mask=mask, dtype=dtype, tgt_len=tgt_len)
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    warnings.warn(
        "Calling `transformers.models.llama.modeling_llama._make_causal_mask` is deprecated and will be removed in "
        "v4.37. Use `transformers.models.llama.modeling_llama.AttentionMaskConverter._make_causal_mask"
    )
    return AttentionMaskConverter._make_causal_mask(
        input_ids_shape=input_ids_shape, dtype=dtype, device=device, past_key_values_length=past_key_values_length
    )
class HunYuanRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        HunYuanRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)
ALL_LAYERNORM_LAYERS.append(HunYuanRMSNorm)
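
# For reference, the normalization above is a standard RMSNorm (no mean subtraction):
#     y = weight * x / sqrt(mean(x**2, dim=-1) + eps)
# with the reduction done in float32 before casting back to the input dtype.
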
class HunYuanRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        # inv_freq = inv_freq.bfloat16()
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Build here to make `torch.jit.trace` work.
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
        )
    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.float32)
        self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1).float()
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if seq_len > self.max_seq_len_cached or self.inv_freq.dtype != torch.float32:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )
class HunYuanLinearScalingRotaryEmbedding(HunYuanRotaryEmbedding):
    """HunYuanRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)
    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
class HunYuanDynamicNTKScalingRotaryEmbedding(HunYuanRotaryEmbedding):
    """
    HunYuanRotaryEmbedding extended with Dynamic NTK scaling.
    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)
    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        if seq_len > self.max_position_embeddings:
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
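
# Hedged note on the rescaling above: when seq_len exceeds max_position_embeddings, the rotary base is
# enlarged to
#     base' = base * (scaling_factor * seq_len / max_position_embeddings - (scaling_factor - 1)) ** (dim / (dim - 2))
# which stretches the low-frequency components so longer contexts fit without retraining ("dynamic NTK").
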
class HunYuanDynamicNTKAlphaRotaryEmbedding(HunYuanRotaryEmbedding):
    """
    HunYuanRotaryEmbedding extended with Dynamic NTK scaling driven by a fixed alpha.
    Credits to the Reddit users /u/bloc97 and /u/emozilla
    """
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_alpha=1.0):
        self.scaling_alpha = scaling_alpha
        super().__init__(dim, max_position_embeddings, base, device)
    def _set_cos_sin_cache(self, seq_len, device, dtype):
        self.max_seq_len_cached = seq_len
        base = self.base * self.scaling_alpha ** (self.dim / (self.dim - 2))
        inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
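
# Hedged note: unlike the seq-len-dependent variant, HunYuanDynamicNTKAlphaRotaryEmbedding applies a
# fixed NTK alpha once, independent of the current sequence length:
#     base' = base * scaling_alpha ** (dim / (dim - 2))
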
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)
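
# Example (not executed): for x = [x0, x1, x2, x3] along the last dim, rotate_half returns
# [-x2, -x3, x0, x1], i.e. the two halves are swapped and the new first half negated.
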
def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.
    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offset position ids when working with a KV-cache.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
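
# Illustrative call (hedged; toy shapes chosen here, not executed):
#     q = torch.randn(1, 8, 16, 64)                 # [batch, heads, seq_len, head_dim]
#     k = torch.randn(1, 2, 16, 64)                 # fewer KV heads is fine, cos/sin broadcast over heads
#     rotary = HunYuanRotaryEmbedding(dim=64, max_position_embeddings=2048)
#     cos, sin = rotary(q, seq_len=16)              # each [16, 64]
#     position_ids = torch.arange(16).unsqueeze(0)  # [1, 16]
#     q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)  # same shapes as q and k
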
class HunYuanMLP(nn.Module):
    def __init__(self, config: HunYuanConfig, layer_idx=None, is_shared_mlp=False):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        if is_shared_mlp:
            self.intermediate_size = config.intermediate_size * config.num_shared_expert[0]
        else:
            self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]
    def forward(self, x):
        if self.config.pretraining_tp > 1:
            slice = self.intermediate_size // self.config.pretraining_tp
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)
            gate_proj = torch.cat(
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
            )
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
            down_proj = [
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
            ]
            down_proj = sum(down_proj)
        else:
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj
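
# The non-tensor-parallel branch above is a standard gated (SwiGLU-style) MLP:
#     down_proj(act_fn(gate_proj(x)) * up_proj(x))
# When is_shared_mlp=True, the intermediate size is widened by config.num_shared_expert[0], which is how
# the shared expert used in the mixed MoE/MLP layers gets its extra capacity.
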
class HunYuanTopKGate(nn.Module):
    def __init__(self, config: HunYuanConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = config.moe_topk
        self.drop_tokens = config.moe_drop_tokens
        self.min_capacity = 8
        self.random_routing_dropped_token = config.moe_random_routing_dropped_token
        self.wg = nn.Linear(config.hidden_size, config.num_experts, bias=False, dtype=torch.float32)
    def forward(self, hidden_states):
        bsz, seq_len, hidden_size = hidden_states.shape
        hidden_states = hidden_states.reshape(-1, hidden_size)
        if self.wg.weight.dtype == torch.float32:
            hidden_states = hidden_states.float()
        logits = self.wg(hidden_states)
        if self.moe_topk == 1:
            gate_output = top1gating(logits, random_routing_dropped_token=self.random_routing_dropped_token)
        else:
            gate_output = topkgating(logits, self.moe_topk[0])
        return gate_output
class HunYuanMoE(nn.Module):
    def __init__(self, config: HunYuanConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.moe_topk = config.moe_topk
        self.num_experts = config.num_experts
        if config.use_mixed_mlp_moe:
            self.shared_mlp = HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=True)
        self.gate = HunYuanTopKGate(config, layer_idx=layer_idx)
        self.experts = nn.ModuleList(
            [HunYuanMLP(config, layer_idx=layer_idx, is_shared_mlp=False) for _ in range(config.num_experts)]
        )
    def forward(self, hidden_states):
        bsz, seq_len, hidden_size = hidden_states.shape
        if self.config.use_mixed_mlp_moe:
            hidden_states_mlp = self.shared_mlp(hidden_states)
        l_moe, combine_weights, dispatch_mask, exp_counts = self.gate(hidden_states)
        reshaped_input = hidden_states.reshape(-1, hidden_size)
        dispatched_input = torch.einsum("sec,sm->ecm", dispatch_mask.type_as(hidden_states), reshaped_input)
        chunks = dispatched_input.chunk(self.num_experts, dim=0)
        expert_outputs = []
        for chunk, expert in zip(chunks, self.experts):
            expert_outputs.append(expert(chunk))
        expert_output = torch.cat(expert_outputs, dim=0)
        combined_output = torch.einsum("sec,ecm->sm", combine_weights.type_as(hidden_states), expert_output)
        combined_output = combined_output.reshape(bsz, seq_len, hidden_size)
        if self.config.use_mixed_mlp_moe:
            output = hidden_states_mlp + combined_output
        else:
            output = combined_output
        return output
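
# Shape sketch for the dispatch/combine einsums above (hedged; s = bsz * seq_len tokens, e = experts,
# c = expert capacity, m = hidden size):
#     dispatch_mask:    [s, e, c]   one-hot routing of each token into an expert's capacity slot
#     dispatched_input = einsum("sec,sm->ecm")  -> [e, c, m]   per-expert token buffers
#     expert_output:    [e, c, m]   after running each expert MLP on its buffer
#     combined_output  = einsum("sec,ecm->sm")  -> [s, m]      gathered back, weighted by router probs
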
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
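
# Example (not executed): with 2 KV heads and n_rep=4 (i.e. 8 query heads sharing 2 KV heads, as in GQA),
#     repeat_kv(torch.randn(1, 2, 16, 64), n_rep=4).shape == (1, 8, 16, 64)
# Each KV head is duplicated 4 times so the grouped keys/values line up with the query heads.
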
class HunYuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""
    def __init__(self, config: HunYuanConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        # layer_idx starts from 0
        self.attention_type = 'cross' if config.use_cla and layer_idx % config.cla_share_factor != 0 else 'self'
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a "
                "`layer_idx` when creating this class."
            )
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.use_qk_norm = config.use_qk_norm
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        if self.attention_type == 'self':
            self.k_proj = nn.Linear(
                self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias