acestep_diff

생성일 비교 결과 만료 없음
1 삭제
574
1 추가
575
# Copyright 2025 The ACESTEO Team. All rights reserved.
# Copyright 2025 The ACESTEO Team. All rights reserved.
#
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# You may obtain a copy of the License at
#
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
#
# Unless required by applicable law or agreed to in writing, software
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
import copy
import math
import math
import time
import time
from typing import Callable, List, Optional, Union
from typing import Callable, List, Optional, Union


import torch
import torch
import torch.nn.functional as F
import torch.nn.functional as F
from torch import nn
from torch import nn


from einops import rearrange
from einops import rearrange


# Transformers imports (sorted by submodule, then alphabetically)
# Transformers imports (sorted by submodule, then alphabetically)
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
from transformers.cache_utils import Cache, DynamicCache, EncoderDecoderCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_layers import GradientCheckpointingLayer
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_outputs import BaseModelOutput
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from transformers.processing_utils import Unpack
from transformers.processing_utils import Unpack
from transformers.utils import auto_docstring, can_return_tuple, logging
from transformers.utils import auto_docstring, can_return_tuple, logging
from transformers.models.qwen3.modeling_qwen3 import (
from transformers.models.qwen3.modeling_qwen3 import (
Qwen3MLP,
Qwen3MLP,
Qwen3RMSNorm,
Qwen3RMSNorm,
Qwen3RotaryEmbedding,
Qwen3RotaryEmbedding,
apply_rotary_pos_emb,
apply_rotary_pos_emb,
eager_attention_forward,
eager_attention_forward,
)
)


from vector_quantize_pytorch import ResidualFSQ
from vector_quantize_pytorch import ResidualFSQ


# Local config import with fallback
# Local config import with fallback
try:
try:
from .configuration_acestep_v15 import AceStepConfig
from .configuration_acestep_v15 import AceStepConfig
except ImportError:
except ImportError:
from configuration_acestep_v15 import AceStepConfig
from configuration_acestep_v15 import AceStepConfig




logger = logging.get_logger(__name__)
logger = logging.get_logger(__name__)




def create_4d_mask(
    seq_len: int,
    dtype: torch.dtype,
    device: torch.device,
    attention_mask: Optional[torch.Tensor] = None,  # [Batch, Seq_Len]
    sliding_window: Optional[int] = None,
    is_sliding_window: bool = False,
    is_causal: bool = True,
) -> torch.Tensor:
    """
    General 4D additive attention-mask generator compatible with CPU/Mac/SDPA and eager mode.

    Supported use cases:
        1. Causal full:           is_causal=True,  is_sliding_window=False (standard GPT)
        2. Causal sliding:        is_causal=True,  is_sliding_window=True  (Mistral/Qwen local window)
        3. Bidirectional full:    is_causal=False, is_sliding_window=False (BERT/encoder)
        4. Bidirectional sliding: is_causal=False, is_sliding_window=True  (Longformer-style local)

    The sliding-window constraint is applied only when ``is_sliding_window`` is True AND
    ``sliding_window`` is not None. The bound is inclusive: a key at distance
    ``|i - j| == sliding_window`` is still visible.

    Args:
        seq_len: Sequence length L (queries and keys share the same length).
        dtype: Floating dtype of the returned mask.
        device: Device on which the mask is built.
        attention_mask: Optional padding mask of shape [B, L]; non-zero entries mark valid
            keys, zeros mark padding. It is moved to ``device`` before use.
        sliding_window: Window size for local attention.
        is_sliding_window: Whether to apply the sliding-window constraint.
        is_causal: Whether to forbid attending to future keys (j > i).

    Returns:
        Additive mask of shape [B, 1, L, L] when ``attention_mask`` is given, otherwise
        [1, 1, L, L] (broadcastable over the batch). Visible positions hold 0.0 and masked
        positions hold ``torch.finfo(dtype).min`` — a large finite negative value rather
        than -inf, which avoids the NaNs an all -inf row would produce in softmax.
    """
    # ------------------------------------------------------
    # 1. Geometry mask [L, L]
    # ------------------------------------------------------
    # diff[i, j] = i - j  (query index minus key index)
    indices = torch.arange(seq_len, device=device)
    diff = indices.unsqueeze(1) - indices.unsqueeze(0)

    # Start with every position visible.
    valid_mask = torch.ones((seq_len, seq_len), device=device, dtype=torch.bool)

    # (A) Causality: key must not be in the future (i >= j  <=>  diff >= 0).
    if is_causal:
        valid_mask = valid_mask & (diff >= 0)

    # (B) Sliding window.
    if is_sliding_window and sliding_window is not None:
        if is_causal:
            # Past-only window; diff >= 0 is already enforced above.
            valid_mask = valid_mask & (diff <= sliding_window)
        else:
            # Symmetric window around the query position.
            valid_mask = valid_mask & (torch.abs(diff) <= sliding_window)

    # [L, L] -> [1, 1, L, L] for broadcasting against the batch/head dims.
    valid_mask = valid_mask.unsqueeze(0).unsqueeze(0)

    # ------------------------------------------------------
    # 2. Padding mask (masks out invalid key columns)
    # ------------------------------------------------------
    if attention_mask is not None:
        # Move to the target device first so the `&` below cannot hit a device mismatch;
        # `reshape` (unlike `view`) accepts any memory layout.
        padding_mask_4d = (
            attention_mask.to(device=device)
            .reshape(attention_mask.shape[0], 1, 1, seq_len)
            .to(torch.bool)
        )
        # [1, 1, L, L] & [B, 1, 1, L] -> [B, 1, L, L]
        valid_mask = valid_mask & padding_mask_4d

    # ------------------------------------------------------
    # 3. Convert boolean visibility to an additive mask
    # ------------------------------------------------------
    min_dtype = torch.finfo(dtype).min
    mask_tensor = torch.full(valid_mask.shape, min_dtype, dtype=dtype, device=device)
    mask_tensor.masked_fill_(valid_mask, 0.0)
    return mask_tensor




def pack_sequences(hidden1: torch.Tensor, hidden2: torch.Tensor, mask1: torch.Tensor, mask2: torch.Tensor):
    """
    Concatenate two masked sequences and left-pack their valid tokens.

    The inputs are joined along the sequence axis. Each batch row is then reordered so that
    every position whose mask value is 1 precedes every position whose mask value is 0.
    The sort is stable, so valid tokens keep their relative order: valid tokens from
    ``hidden1`` first, then valid tokens from ``hidden2``.

    Args:
        hidden1: Hidden states, shape [B, L1, D].
        hidden2: Hidden states, shape [B, L2, D].
        mask1: Validity mask for ``hidden1``, shape [B, L1] (1 = valid, 0 = padding).
        mask2: Validity mask for ``hidden2``, shape [B, L2] (1 = valid, 0 = padding).

    Returns:
        A tuple ``(packed_hidden_states, new_mask)``:
          - packed_hidden_states: [B, L1 + L2, D], valid tokens first.
          - new_mask: boolean [B, L1 + L2]; True for the first ``mask.sum()`` positions of
            each row.
    """
    combined_hidden = torch.cat((hidden1, hidden2), dim=1)
    combined_mask = torch.cat((mask1, mask2), dim=1)
    batch, total_len, width = combined_hidden.shape

    # Stable descending sort of the mask puts valid (1) positions in front, order preserved.
    order = combined_mask.argsort(dim=1, descending=True, stable=True)
    gather_index = order.unsqueeze(-1).expand(batch, total_len, width)
    packed = torch.gather(combined_hidden, 1, gather_index)

    # After packing, row b is valid exactly on its first `valid_counts[b]` slots.
    valid_counts = combined_mask.sum(dim=1)
    slots = torch.arange(total_len, dtype=torch.long, device=combined_hidden.device)
    new_mask = slots[None, :] < valid_counts[:, None]

    return packed, new_mask




def sample_t_r(batch_size, device, dtype, data_proportion=0.0, timestep_mu=-0.4, timestep_sigma=1.0, use_meanflow=True):
    """
    Sample a pair of flow-matching timesteps (t, r) for each batch element.

    Two independent logit-normal draws, ``sigmoid(timestep_mu + timestep_sigma * N(0, 1))``,
    are made per element. The larger becomes ``t`` and the smaller ``r``, so ``t >= r``.
    Afterwards, the first ``int(batch_size * data_proportion)`` elements have ``r``
    overwritten with ``t``.

    Args:
        batch_size: Number of (t, r) pairs to sample.
        device: Device on which the tensors are created.
        dtype: Floating dtype of the sampled tensors.
        data_proportion: Fraction (0.0 to 1.0) of leading batch elements for which r is
            set equal to t.
        timestep_mu: Mean of the normal distribution before the sigmoid.
        timestep_sigma: Standard deviation of the normal distribution before the sigmoid.
        use_meanflow: If False, ``data_proportion`` is treated as 1.0, so r == t for every
            element.

    Returns:
        Tuple ``(t, r)``, each of shape [batch_size].
    """

    def _draw():
        noise = torch.randn((batch_size,), device=device, dtype=dtype)
        return torch.sigmoid(noise * timestep_sigma + timestep_mu)

    # Draw order matters for reproducibility under a fixed seed: first sample, then second.
    first = _draw()
    second = _draw()
    t = torch.maximum(first, second)
    r = torch.minimum(first, second)

    proportion = data_proportion if use_meanflow else 1.0
    tied_count = int(batch_size * proportion)
    tie_mask = torch.arange(batch_size, device=device) < tied_count
    r = torch.where(tie_mask, t, r)
    return t, r




class TimestepEmbedding(nn.Module):
    """
    Embed diffusion timesteps for model conditioning.

    Timesteps are multiplied by ``scale``, encoded with cos/sin sinusoids of width
    ``in_channels``, and passed through Linear -> SiLU -> Linear to get ``temb``
    ([N, time_embed_dim]). ``temb`` is then passed through SiLU -> Linear to
    ``6 * time_embed_dim`` features, which are unflattened to [N, 6, time_embed_dim].
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        scale: float = 1000,
    ):
        super().__init__()
        # Plain (non-parameter) configuration.
        self.in_channels = in_channels
        self.scale = scale

        # Submodules: attribute names are state_dict keys and registration order fixes the
        # RNG consumption during init, so both are kept as-is.
        self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=True)
        self.act1 = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=True)
        self.act2 = nn.SiLU()
        self.time_proj = nn.Linear(time_embed_dim, time_embed_dim * 6)

    def timestep_embedding(self, t, dim, max_period=10000):
        """
        Build sinusoidal embeddings for (possibly fractional) timesteps.

        Args:
            t: 1-D tensor of N timesteps (one per batch element); multiplied by ``self.scale``.
            dim: Output embedding width. When odd, a trailing zero column is appended.
            max_period: Sets the lowest frequency, ``1 / max_period``.

        Returns:
            Float32 tensor of shape (N, dim): cosines in the first half, sines in the second.
        """
        # NOTE: scaling happens in t's own dtype, before the float32 upcast below.
        scaled_t = t * self.scale
        half_dim = dim // 2
        exponents = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32) / half_dim
        frequencies = torch.exp(exponents).to(device=scaled_t.device)
        phases = scaled_t[:, None].float() * frequencies[None]

        pieces = [torch.cos(phases), torch.sin(phases)]
        if dim % 2 != 0:
            # Odd width: pad with a single zero column to reach exactly `dim`.
            pieces.append(torch.zeros_like(phases[:, :1]))
        return torch.cat(pieces, dim=-1)

    def forward(self, t):
        """Return ``(temb, timestep_proj)`` with shapes [N, D] and [N, 6, D]."""
        sinusoid = self.timestep_embedding(t, self.in_channels)
        hidden = self.act1(self.linear_1(sinusoid.to(t.dtype)))
        temb = self.linear_2(hidden)
        projected = self.time_proj(self.act2(temb))
        timestep_proj = projected.unflatten(1, (6, -1))
        return temb, timestep_proj


class AceStepAttention(nn.Module):
    """
    Multi-headed attention module for the AceStep model.

    Implements scaled dot-product attention ('Attention Is All You Need') in two modes:
      * self-attention, with optional rotary position embeddings (RoPE), an optional
        sliding window, and an optional KV cache;
      * cross-attention over ``encoder_hidden_states``, which is never causal and never
        uses a sliding window.

    Queries and keys are RMS-normalized per head (over ``head_dim``) before attention;
    values are not normalized.
    """

    def __init__(self, config: AceStepConfig, layer_idx: int, is_cross_attention: bool = False, is_causal: bool = False):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
        # Query heads per key/value head (grouped-query attention); presumably consumed by
        # the attention backends — TODO confirm.
        self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config.attention_dropout
        # Cross-attention attends to a separate sequence, so causal masking never applies.
        if is_cross_attention:
            is_causal = False
        self.is_causal = is_causal
        self.is_cross_attention = is_cross_attention

        self.q_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias)
        # RMS normalization applied only over the head dimension (unlike OLMo)
        self.q_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.k_norm = Qwen3RMSNorm(self.head_dim, eps=config.rms_norm_eps)
        self.attention_type = config.layer_types[layer_idx]
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

    def _project_key_value(self, source: torch.Tensor, shape: tuple) -> tuple[torch.Tensor, torch.Tensor]:
        """Project `source` to (key, value) states of shape [B, heads, S, head_dim]; keys are RMS-normalized."""
        key_states = self.k_norm(self.k_proj(source).view(shape)).transpose(1, 2)
        value_states = self.v_proj(source).view(shape).transpose(1, 2)
        return key_states, value_states

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        position_embeddings: tuple[torch.Tensor, torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """
        Run self- or cross-attention.

        Args:
            hidden_states: Query-side input; the ``transpose(1, 2)`` below assumes
                (batch, seq_len, hidden_size).
            attention_mask: Mask forwarded unchanged to the attention backend.
            past_key_value: Optional cache. On the cross-attention path it must expose
                ``is_updated`` and ``cross_attention_cache`` (EncoderDecoderCache-style).
                On the self-attention path it must expose ``update(k, v, layer_idx, kwargs)``.
            cache_position: Forwarded to the cache update (used by static caches).
            encoder_hidden_states: Key/value source for cross-attention. Cross-attention is
                used only if this layer was built with ``is_cross_attention=True`` AND this is
                not None; otherwise the layer falls back to self-attention.
            position_embeddings: Optional ``(cos, sin)`` RoPE tables, applied on the
                self-attention path only.
            output_attentions: When True on the cross-attention path, the eager
                implementation is used instead of the configured backend.
            **kwargs: Extra arguments forwarded to the attention backend.

        Returns:
            ``(attn_output, attn_weights)``. ``attn_output`` has the leading shape of
            ``hidden_states`` with ``hidden_size`` features. ``attn_weights`` is whatever
            the backend returns (presumably None for fused kernels — verify per backend).
        """
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        # Project and normalize query states: [B, L, H*d] -> [B, H, L, d]
        query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)

        is_cross_attention = self.is_cross_attention and encoder_hidden_states is not None
        if is_cross_attention:
            # Cross-attention path: keys/values come from the encoder sequence.
            encoder_hidden_shape = (*encoder_hidden_states.shape[:-1], -1, self.head_dim)
            if past_key_value is not None:
                is_updated = past_key_value.is_updated.get(self.layer_idx)
                curr_past_key_value = past_key_value.cross_attention_cache
                if not is_updated:
                    # First call for this layer: compute encoder K/V once and store them.
                    key_states, value_states = self._project_key_value(encoder_hidden_states, encoder_hidden_shape)
                    key_states, value_states = curr_past_key_value.update(key_states, value_states, self.layer_idx)
                    past_key_value.is_updated[self.layer_idx] = True
                else:
                    # Later calls reuse the cached encoder K/V.
                    key_states = curr_past_key_value.layers[self.layer_idx].keys
                    value_states = curr_past_key_value.layers[self.layer_idx].values
            else:
                key_states, value_states = self._project_key_value(encoder_hidden_states, encoder_hidden_shape)
        else:
            # Self-attention path: keys/values come from the same sequence.
            key_states, value_states = self._project_key_value(hidden_states, hidden_shape)

            # RoPE is optional. Bind cos/sin up front so the cache update below cannot hit an
            # unbound local when no position embeddings are supplied.
            cos = sin = None
            if position_embeddings is not None:
                cos, sin = position_embeddings
                query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

            if past_key_value is not None:
                # sin/cos are specific to RoPE models; cache_position is needed for static caches
                cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
                key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # Cross-attention that must return weights stays on the eager implementation.
        attention_interface: Callable = eager_attention_forward
        if not (is_cross_attention and output_attentions) and self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=self.attention_dropout if self.training else 0.0,
            scaling=self.scaling,
            sliding_window=self.sliding_window if not self.is_cross_attention else None,
            **kwargs,
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights




class AceStepEncoderLayer(GradientCheckpointingLayer):
    """
    Pre-norm transformer encoder layer for AceStep.

    Data flow:
        x = x + SelfAttention(RMSNorm(x))   # bidirectional, no KV cache
        x = x + MLP(RMSNorm(x))
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        width = config.hidden_size
        norm_eps = config.rms_norm_eps

        self.hidden_size = width
        self.config = config
        self.layer_idx = layer_idx

        # Submodule creation order is kept unchanged (state_dict keys / seeded init).
        # Non-causal self-attention: every position may attend to every other one.
        self.self_attn = AceStepAttention(
            config=config,
            layer_idx=layer_idx,
            is_cross_attention=False,
            is_causal=False,
        )
        self.input_layernorm = Qwen3RMSNorm(width, eps=norm_eps)
        self.post_attention_layernorm = Qwen3RMSNorm(width, eps=norm_eps)
        # Feed-forward sub-layer.
        self.mlp = Qwen3MLP(config)
        self.attention_type = config.layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = False,
        **kwargs,
    ) -> tuple[
        torch.FloatTensor,
        Optional[tuple[torch.FloatTensor, torch.FloatTensor]],
    ]:
        """
        Apply the attention and MLP sub-blocks.

        Returns:
            ``(hidden_states,)``, or ``(hidden_states, self_attn_weights)`` when
            ``output_attentions`` is truthy.
        """
        # --- self-attention sub-block (residual = incoming hidden_states) ---
        attn_out, self_attn_weights = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            # The encoder never caches keys/values.
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
        hidden_states = hidden_states + attn_out

        # --- feed-forward sub-block ---
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out

        if output_attentions:
            return (hidden_states, self_attn_weights)
        return (hidden_states,)




class AceStepDiTLayer(GradientCheckpointingLayer):
    """
    DiT (Diffusion Transformer) layer for AceStep model.

    Implements a transformer layer with three main components:
        1. Self-attention with adaptive layer norm (AdaLN)
        2. Cross-attention (optional) for conditioning on encoder outputs
        3. Feed-forward MLP with adaptive layer norm

    Uses scale-shift modulation from timestep embeddings for adaptive normalization.

    Args:
        config: Model configuration. ``hidden_size``, ``rms_norm_eps`` and
            ``layer_types`` are read here; the attention and MLP sub-modules
            receive the whole config.
        layer_idx: Index of this layer. Selects ``config.layer_types[layer_idx]``
            and is forwarded to the attention modules.
        use_cross_attention: If True, a cross-attention sub-layer over
            ``encoder_hidden_states`` runs between self-attention and the MLP.
    """

    def __init__(self, config: AceStepConfig, layer_idx: int, use_cross_attention: bool = True):
        super().__init__()

        # 1. Self-attention sub-layer with adaptive normalization
        self.self_attn_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.self_attn = AceStepAttention(config=config, layer_idx=layer_idx)

        # 2. Cross-attention sub-layer (optional, for encoder conditioning)
        self.use_cross_attention = use_cross_attention
        if self.use_cross_attention:
            self.cross_attn_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
            self.cross_attn = AceStepAttention(config=config, layer_idx=layer_idx, is_cross_attention=True)

        # 3. Feed-forward MLP sub-layer with adaptive normalization
        self.mlp_norm = Qwen3RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp = Qwen3MLP(config)

        # Scale-shift table for adaptive layer norm modulation
        # (6 values: shift/scale/gate for self-attn, then shift/scale/gate for MLP).
        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, config.hidden_size) / config.hidden_size**0.5)
        self.attention_type = config.layer_types[layer_idx]

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: tuple[torch.Tensor, torch.Tensor],
        temb: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[EncoderDecoderCache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> tuple[torch.Tensor, ...]:
        """
        Run one DiT block: AdaLN self-attention, optional cross-attention, AdaLN MLP.

        Args:
            hidden_states: Input sequence for this layer.
            position_embeddings: Passed straight through to self-attention.
            temb: Timestep embedding. It is added to ``scale_shift_table``
                (shape ``(1, 6, hidden_size)``) and split into 6 chunks along
                dim 1. NOTE(review): presumably ``(batch, 6, hidden_size)`` — confirm.
            attention_mask: Mask for self-attention.
            position_ids: Passed through to self-attention.
            past_key_value: Cache used only by the cross-attention sub-layer;
                self-attention always runs with ``use_cache=False``.
            output_attentions: If True, attention weights are appended to the output.
            use_cache: Forwarded to cross-attention only.
            cache_position: Accepted for interface compatibility; not used by this layer.
            encoder_hidden_states: Conditioning sequence for cross-attention.
            encoder_attention_mask: Mask for cross-attention.
            **kwargs: Forwarded to both attention sub-layers.

        Returns:
            ``(hidden_states,)``, or ``(hidden_states, self_attn_weights,
            cross_attn_weights)`` when ``output_attentions`` is True.
            ``cross_attn_weights`` is ``None`` when the layer has no cross-attention.
        """
        # Extract scale-shift parameters for adaptive layer norm from timestep embeddings
        # 6 values: (shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa)
        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
            self.scale_shift_table + temb
        ).chunk(6, dim=1)

        # Step 1: Self-attention with adaptive layer norm (AdaLN)
        # Apply adaptive normalization: norm(x) * (1 + scale) + shift
        norm_hidden_states = (self.self_attn_norm(hidden_states) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
        attn_output, self_attn_weights = self.self_attn(
            hidden_states=norm_hidden_states,
            position_embeddings=position_embeddings,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            use_cache=False,
            past_key_value=None,
            **kwargs,
        )
        # Apply gated residual connection: x = x + attn_output * gate
        hidden_states = (hidden_states + attn_output * gate_msa).type_as(hidden_states)

        # Step 2: Cross-attention (if enabled) for conditioning on encoder outputs.
        # Bound up-front so `output_attentions=True` also works on layers built
        # with `use_cross_attention=False` (previously an UnboundLocalError).
        cross_attn_weights = None
        if self.use_cross_attention:
            norm_hidden_states = self.cross_attn_norm(hidden_states).type_as(hidden_states)
            attn_output, cross_attn_weights = self.cross_attn(
                hidden_states=norm_hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
                **kwargs,
            )
            # Plain (ungated, unmodulated) residual connection for cross-attention
            hidden_states = hidden_states + attn_output

        # Step 3: Feed-forward (MLP) with adaptive layer norm
        # Apply adaptive normalization for MLP: norm(x) * (1 + scale) + shift
        norm_hidden_states = (self.mlp_norm(hidden_states) * (1 + c_scale_msa) + c_shift_msa).type_as(hidden_states)
        ff_output = self.mlp(norm_hidden_states)
        # Apply gated residual connection: x = x + mlp_output * gate
        hidden_states = (hidden_states + ff_output * c_gate_msa).type_as(hidden_states)

        outputs = (hidden_states,)
        if output_attentions:
            outputs += (self_attn_weights, cross_attn_weights)

        return outputs




@auto_docstring
class AceStepPreTrainedModel(PreTrainedModel):
    # Shared base for AceStep models: declares the config class, weight prefix,
    # and the attention/cache backends the transformers machinery may select.
    config_class = AceStepConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    # Layers that must stay on a single device when the model is sharded.
    _no_split_modules = ["AceStepEncoderLayer", "AceStepDiTLayer"]
    _skip_keys_device_placement = ["past_key_values"]

    # Attention backend capability flags.
    _supports_flash_attn_3 = True
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_flex_attn = True

    # Cache capability flags.
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True
    _supports_attention_backend = True

    def _init_weights(self, module):
        """
        Initialize the parameters of ``module`` in place.

        - ``nn.Linear``: weight drawn from N(0, initializer_range); bias zeroed if present.
        - ``nn.Embedding``: weight drawn from N(0, initializer_range); the
          ``padding_idx`` row zeroed if a padding index is set.
        - ``Qwen3RMSNorm``: weight filled with 1.0.

        Any other module type is left unchanged.

        TODO: Support separate initialization for encoders and decoders.
        """
        init_std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            linear_weight = module.weight.data
            linear_weight.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            table = module.weight.data
            table.normal_(mean=0.0, std=init_std)
            pad_row = module.padding_idx
            if pad_row is not None:
                # Keep the padding embedding at exactly zero.
                table[pad_row].zero_()
            return

        if isinstance(module, Qwen3RMSNorm):
            module.weight.data.fill_(1.0)




class AceStepLyricEncoder(Ac
class AceStepLyr