# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Any, Dict, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ..attention import FeedForward
from ..attention_processor import Attention
from ..cache_utils import CacheMixin
from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import FP32LayerNorm


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class SkyReelsV2AttnProcessor2_0:
    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "SkyReelsV2AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
            )
        self._flag_ar_attention = False

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        encoder_hidden_states_img = None
        if attn.add_k_proj is not None:
            # 512 is the context length of the text encoder, hardcoded for now
            image_context_length = encoder_hidden_states.shape[1] - 512
            encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
            encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        if attn.norm_q is not None:
            query = attn.norm_q(query)
        if attn.norm_k is not None:
            key = attn.norm_k(key)

        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
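
        # Apply rotary position embeddings to the query and key projections. The frequencies are the
        # complex-valued, 3D-factorized RoPE produced by `SkyReelsV2RotaryPosEmbed`.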
        if rotary_emb is not None:

            def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
                x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
                x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
                return x_out.type_as(hidden_states)

            query = apply_rotary_emb(query, rotary_emb)
            key = apply_rotary_emb(key, rotary_emb)

        # I2V task
        hidden_states_img = None
        if encoder_hidden_states_img is not None:
            key_img = attn.add_k_proj(encoder_hidden_states_img)
            key_img = attn.norm_added_k(key_img)
            value_img = attn.add_v_proj(encoder_hidden_states_img)

            key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
            value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)

            hidden_states_img = F.scaled_dot_product_attention(
                query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
            hidden_states_img = hidden_states_img.type_as(query)
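
        # When autoregressive (Diffusion Forcing) attention has been enabled via `set_ar_attention`,
        # self-attention runs in bfloat16 and the caller is expected to pass a block-causal `attention_mask`.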
        if self._flag_ar_attention:
            is_self_attention = encoder_hidden_states is hidden_states
            hidden_states = F.scaled_dot_product_attention(
                query.to(torch.bfloat16) if is_self_attention else query,
                key.to(torch.bfloat16) if is_self_attention else key,
                value.to(torch.bfloat16) if is_self_attention else value,
                attn_mask=attention_mask,
                dropout_p=0.0,
                is_causal=False,
            )
        else:
            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
        hidden_states = hidden_states.type_as(query)

        if hidden_states_img is not None:
            hidden_states = hidden_states + hidden_states_img

        hidden_states = attn.to_out[0](hidden_states)
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states

    def set_ar_attention(self):
        self._flag_ar_attention = True


class SkyReelsV2ImageEmbedding(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
        super().__init__()

        self.norm1 = FP32LayerNorm(in_features)
        self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
        self.norm2 = FP32LayerNorm(out_features)
        if pos_embed_seq_len is not None:
            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
        else:
            self.pos_embed = None

    def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
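        # `pos_embed` is only used when two conditioning images are stacked along the batch dimension;
        # they are folded back into a single sequence of length 2 * seq_len before the embedding is added.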
        if self.pos_embed is not None:
            batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
            encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
            encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed

        hidden_states = self.norm1(encoder_hidden_states_image)
        hidden_states = self.ff(hidden_states)
        hidden_states = self.norm2(hidden_states)
        return hidden_states


class SkyReelsV2TimeTextImageEmbedding(nn.Module):
    def __init__(
        self,
        dim: int,
        time_freq_dim: int,
        time_proj_dim: int,
        text_embed_dim: int,
        image_embed_dim: Optional[int] = None,
        pos_embed_seq_len: Optional[int] = None,
    ):
        super().__init__()

        self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
        self.act_fn = nn.SiLU()
        self.time_proj = nn.Linear(dim, time_proj_dim)
        self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")

        self.image_embedder = None
        if image_embed_dim is not None:
            self.image_embedder = SkyReelsV2ImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)

    def forward(
        self,
        timestep: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
    ):
        timestep = self.timesteps_proj(timestep)

        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
            timestep = timestep.to(time_embedder_dtype)
        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))

        encoder_hidden_states = self.text_embedder(encoder_hidden_states)
        if encoder_hidden_states_image is not None:
            encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)

        return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image


class SkyReelsV2RotaryPosEmbed(nn.Module):
    def __init__(
        self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
    ):
        super().__init__()

        self.attention_head_dim = attention_head_dim
        self.patch_size = patch_size
        self.max_seq_len = max_seq_len

        h_dim = w_dim = 2 * (attention_head_dim // 6)
        t_dim = attention_head_dim - h_dim - w_dim

        freqs = []
        for dim in [t_dim, h_dim, w_dim]:
            freq = get_1d_rotary_pos_embed(
                dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
            )
            freqs.append(freq)
        self.freqs = torch.cat(freqs, dim=1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
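        # `self.freqs` concatenates three 1D rotary tables (temporal, height, and width bands of the head
        # dim). Below they are split back apart, truncated to the post-patch grid size, tiled over the 3D
        # grid, and flattened into one complex frequency per token of shape (1, 1, f * h * w, head_dim // 2).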
        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.patch_size
        ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

        freqs = self.freqs.to(hidden_states.device)
        freqs = freqs.split_with_sizes(
            [
                self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
                self.attention_head_dim // 6,
                self.attention_head_dim // 6,
            ],
            dim=1,
        )

        freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
        freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
        freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
        return freqs


class SkyReelsV2TransformerBlock(nn.Module):
    def __init__(
        self,
        dim: int,
        ffn_dim: int,
        num_heads: int,
        qk_norm: str = "rms_norm",
        cross_attn_norm: bool = False,
        eps: float = 1e-6,
        added_kv_proj_dim: Optional[int] = None,
    ):
        super().__init__()

        # 1. Self-attention
        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
        self.attn1 = Attention(
            query_dim=dim,
            heads=num_heads,
            kv_heads=num_heads,
            dim_head=dim // num_heads,
            qk_norm=qk_norm,
            eps=eps,
            bias=True,
            cross_attention_dim=None,
            out_bias=True,
            processor=SkyReelsV2AttnProcessor2_0(),
        )

        # 2. Cross-attention
        self.attn2 = Attention(
            query_dim=dim,
            heads=num_heads,
            kv_heads=num_heads,
            dim_head=dim // num_heads,
            qk_norm=qk_norm,
            eps=eps,
            bias=True,
            cross_attention_dim=None,
            out_bias=True,
            added_kv_proj_dim=added_kv_proj_dim,
            added_proj_bias=True,
            processor=SkyReelsV2AttnProcessor2_0(),
        )
        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()

        # 3. Feed-forward
        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)

        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        temb: torch.Tensor,
        rotary_emb: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
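        # `temb` is (batch, 6, dim) when a single timestep modulates all tokens, and
        # (batch, 6, seq_len, dim) when every frame carries its own timestep (Diffusion Forcing).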
        if temb.dim() == 3:
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
                self.scale_shift_table + temb.float()
            ).chunk(6, dim=1)
        elif temb.dim() == 4:
            e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1)
            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e]

        # 1. Self-attention
        norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
        attn_output = self.attn1(
            hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask
        )
        hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)

        # 2. Cross-attention
        norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
        hidden_states = hidden_states + attn_output

        # 3. Feed-forward
        norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
            hidden_states
        )
        ff_output = self.ffn(norm_hidden_states)
        hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)

        return hidden_states  # TODO: check .to(torch.bfloat16)

    def set_ar_attention(self):
        self.attn1.processor.set_ar_attention()


class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
    r"""
    A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.

    Args:
        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
        num_attention_heads (`int`, defaults to `16`):
            The number of attention heads.
        attention_head_dim (`int`, defaults to `128`):
            The number of channels in each head.
        in_channels (`int`, defaults to `16`):
            The number of channels in the input.
        out_channels (`int`, defaults to `16`):
            The number of channels in the output.
        text_dim (`int`, defaults to `4096`):
            Input dimension for text embeddings.
        freq_dim (`int`, defaults to `256`):
            Dimension for sinusoidal time embeddings.
        ffn_dim (`int`, defaults to `8192`):
            Intermediate dimension in the feed-forward network.
        num_layers (`int`, defaults to `32`):
            The number of layers of transformer blocks to use.
        window_size (`Tuple[int]`, defaults to `(-1, -1)`):
            Window size for local attention (-1 indicates global attention).
        cross_attn_norm (`bool`, defaults to `True`):
            Enable cross-attention normalization.
        qk_norm (`str`, *optional*, defaults to `"rms_norm"`):
            The normalization to apply to the query and key projections.
        eps (`float`, defaults to `1e-6`):
            Epsilon value for normalization layers.
        inject_sample_info (`bool`, defaults to `False`):
            Whether to inject sample information (such as FPS) into the model.
        image_dim (`int`, *optional*):
            The dimension of the image embeddings.
        added_kv_proj_dim (`int`, *optional*):
            The dimension of the added key/value projection.
        rope_max_seq_len (`int`, defaults to `1024`):
            The maximum sequence length for the rotary embeddings.
        pos_embed_seq_len (`int`, *optional*):
            The sequence length for the positional embeddings.
    """

    _supports_gradient_checkpointing = True
    _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
    _no_split_modules = ["SkyReelsV2TransformerBlock"]
    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]

    @register_to_config
    def __init__(
        self,
        patch_size: Tuple[int] = (1, 2, 2),
        num_attention_heads: int = 16,
        attention_head_dim: int = 128,
        in_channels: int = 16,
        out_channels: int = 16,
        text_dim: int = 4096,
        freq_dim: int = 256,
        ffn_dim: int = 8192,
        num_layers: int = 32,
        cross_attn_norm: bool = True,
        qk_norm: Optional[str] = "rms_norm",
        eps: float = 1e-6,
        image_dim: Optional[int] = None,
        added_kv_proj_dim: Optional[int] = None,
        rope_max_seq_len: int = 1024,
        pos_embed_seq_len: Optional[int] = None,
        inject_sample_info: bool = False,
    ) -> None:
        super().__init__()

        inner_dim = num_attention_heads * attention_head_dim
        out_channels = out_channels or in_channels
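
        # Runtime flags for the SkyReels-V2 sampling modes: `num_frame_per_block` and
        # `flag_causal_attention` control block-causal (Diffusion Forcing) attention, and
        # `enable_teacache` turns on the TeaCache block-skipping heuristic used in `forward`.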
        self.num_frame_per_block = 1
        self.flag_causal_attention = False
        self.enable_teacache = False

        # 1. Patch & position embedding
        self.rope = SkyReelsV2RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)

        # 2. Condition embeddings
        # image_embedding_dim=1280 for I2V model
        self.condition_embedder = SkyReelsV2TimeTextImageEmbedding(
            dim=inner_dim,
            time_freq_dim=freq_dim,
            time_proj_dim=inner_dim * 6,
            text_embed_dim=text_dim,
            image_embed_dim=image_dim,
            pos_embed_seq_len=pos_embed_seq_len,
        )

        # 3. Transformer blocks
        self.blocks = nn.ModuleList(
            [
                SkyReelsV2TransformerBlock(
                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
                )
                for _ in range(num_layers)
            ]
        )

        # 4. Output norm & projection
        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)

        self.gradient_checkpointing = False
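
        # Optional FPS conditioning: a two-entry FPS embedding is projected to the six per-block
        # modulation vectors and added to the timestep projection in `forward`.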
        if inject_sample_info:
            self.fps_embedding = nn.Embedding(2, inner_dim)
            self.fps_projection = nn.Sequential(
                nn.Linear(inner_dim, inner_dim), nn.SiLU(), nn.Linear(inner_dim, inner_dim * 6)
            )

    def forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.LongTensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_image: Optional[torch.Tensor] = None,
        fps: Optional[torch.Tensor] = None,
        return_dict: bool = True,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            lora_scale = attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                )

        batch_size, num_channels, num_frames, height, width = hidden_states.shape
        p_t, p_h, p_w = self.config.patch_size
        post_patch_num_frames = num_frames // p_t
        post_patch_height = height // p_h
        post_patch_width = width // p_w

        rotary_emb = self.rope(hidden_states)

        hidden_states = self.patch_embedding(hidden_states)
        grid_sizes = torch.tensor(hidden_states.shape[2:], dtype=torch.long)
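
        # Build a block-causal attention mask for Diffusion Forcing: frames are grouped into blocks of
        # `num_frame_per_block`, and every token may attend to tokens in its own block and in earlier blocks.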
        causal_mask = None
        if self.flag_causal_attention:
            frame_num, height, width = grid_sizes
            block_num = frame_num // self.num_frame_per_block
            range_tensor = torch.arange(block_num, device=hidden_states.device).view(-1, 1)
            range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten()
            causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)  # f, f
            causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1)
            causal_mask = causal_mask.repeat(1, height, width, 1, height, width)
            causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width)
            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)

        hidden_states = hidden_states.flatten(2).transpose(1, 2)

        # TODO: check here
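        # A 2D `timestep` of shape (batch, num_frames) indicates Diffusion Forcing: each latent frame
        # carries its own noise level, and the time embeddings are expanded per frame further below.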
        if timestep.dim() == 2:
            b, f = timestep.shape
            _flag_df = True
        else:
            _flag_df = False

        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
            timestep, encoder_hidden_states, encoder_hidden_states_image
        )
        timestep_proj = timestep_proj.unflatten(1, (6, -1))

        if encoder_hidden_states_image is not None:
            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)

        # 4. Transformer blocks
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            for block in self.blocks:
                hidden_states = self._gradient_checkpointing_func(
                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask
                )

        if self.inject_sample_info:
            fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device)
            fps_emb = self.fps_embedding(fps).float()
            if _flag_df:
                timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(
                    timestep.shape[1], 1, 1
                )
            else:
                timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))

        if _flag_df:
            temb = temb.view(b, f, 1, 1, self.dim)
            timestep_proj = timestep_proj.view(b, f, 1, 1, 6, self.dim)
            temb = temb.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3)
            timestep_proj = timestep_proj.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3)
            timestep_proj = timestep_proj.transpose(1, 2).contiguous()
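
        # TeaCache: estimate how much the modulated input changed since the previous step; if the
        # accumulated relative L1 change stays under `teacache_thresh`, the transformer blocks are
        # skipped for this step and the cached residual is reused (handled further below).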
        if self.enable_teacache:
            modulated_inp = timestep_proj if self.use_ref_steps else temb
            # teacache
            if self.cnt % 2 == 0:  # even -> condition
                self.is_even = True
                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
                    should_calc_even = True
                    self.accumulated_rel_l1_distance_even = 0
                else:
                    rescale_func = np.poly1d(self.coefficients)
                    self.accumulated_rel_l1_distance_even += rescale_func(
                        ((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean())
                        .cpu()
                        .item()
                    )
                    if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
                        should_calc_even = False
                    else:
                        should_calc_even = True
                        self.accumulated_rel_l1_distance_even = 0
                self.previous_e0_even = modulated_inp.clone()
            else:  # odd -> unconditional
                self.is_even = False
                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
                    should_calc_odd = True
                    self.accumulated_rel_l1_distance_odd = 0
                else:
                    rescale_func = np.poly1d(self.coefficients)
                    self.accumulated_rel_l1_distance_odd += rescale_func(
                        ((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean())
                        .cpu()
                        .item()
                    )
                    if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
                        should_calc_odd = False
                    else:
                        should_calc_odd = True
                        self.accumulated_rel_l1_distance_odd = 0
                self.previous_e0_odd = modulated_inp.clone()

        if self.enable_teacache:
            if self.is_even:
                if not should_calc_even:
                    hidden_states += self.previous_residual_even
                else:
                    ori_hidden_states = hidden_states.clone()
                    for block in self.blocks:
                        hidden_states = block(
                            hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_ma