Diff: Wan transformer → SkyReels-V2 transformer
Left: 484 lines, 66 removals · Right: 587 lines, 176 additions
-# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The SkyReels-V2 Team, The Wan Team and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import math
 from typing import Any, Dict, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import PixArtAlphaTextProjection, TimestepEmbedding, Timesteps, get_1d_rotary_pos_embed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import FP32LayerNorm
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-class WanAttnProcessor2_0:
+class SkyReelsV2AttnProcessor2_0:
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+            raise ImportError(
+                "SkyReelsV2AttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0."
+            )
+
+        self._flag_ar_attention = False
 
     def __call__(
         self,
         attn: Attention,
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         rotary_emb: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         encoder_hidden_states_img = None
         if attn.add_k_proj is not None:
             # 512 is the context length of the text encoder, hardcoded for now
             image_context_length = encoder_hidden_states.shape[1] - 512
             encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
             encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
 
         query = attn.to_q(hidden_states)
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
         if attn.norm_q is not None:
             query = attn.norm_q(query)
         if attn.norm_k is not None:
             key = attn.norm_k(key)
 
         query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
         key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
         value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
 
         if rotary_emb is not None:
 
             def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
                 x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
                 x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
                 return x_out.type_as(hidden_states)
 
             query = apply_rotary_emb(query, rotary_emb)
             key = apply_rotary_emb(key, rotary_emb)
 
         # I2V task
         hidden_states_img = None
         if encoder_hidden_states_img is not None:
             key_img = attn.add_k_proj(encoder_hidden_states_img)
             key_img = attn.norm_added_k(key_img)
             value_img = attn.add_v_proj(encoder_hidden_states_img)
 
             key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
             value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
 
             hidden_states_img = F.scaled_dot_product_attention(
                 query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
             )
             hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
             hidden_states_img = hidden_states_img.type_as(query)
 
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
+        if self._flag_ar_attention:
+            is_self_attention = encoder_hidden_states == hidden_states
+            hidden_states = F.scaled_dot_product_attention(
+                query.to(torch.bfloat16) if is_self_attention else query,
+                key.to(torch.bfloat16) if is_self_attention else key,
+                value.to(torch.bfloat16) if is_self_attention else value,
+                attn_mask=attention_mask,
+                dropout_p=0.0,
+                is_causal=False,
+            )
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
         hidden_states = hidden_states.type_as(query)
 
         if hidden_states_img is not None:
             hidden_states = hidden_states + hidden_states_img
 
         hidden_states = attn.to_out[0](hidden_states)
         hidden_states = attn.to_out[1](hidden_states)
         return hidden_states
 
+    def set_ar_attention(self):
+        self._flag_ar_attention = True
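A note on the `_flag_ar_attention` switch added above: after `set_ar_attention()` is called, self-attention still runs through `F.scaled_dot_product_attention`, only now with an explicit boolean `attn_mask` (the block-causal mask the model builds later in `forward()`). A minimal standalone sketch of that masking pattern, with made-up toy shapes; nothing here beyond the PyTorch call itself is taken from the diff:

```python
import torch
import torch.nn.functional as F

# Toy sizes: batch 1, 2 heads, 6 tokens (3 frame-blocks of 2 tokens each), head_dim 8.
q = k = v = torch.randn(1, 2, 6, 8)

# Block-causal mask, mirroring `range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)`
# in the model's forward(): query i may attend key j iff j's frame-block <= i's frame-block.
block_index = torch.arange(3).repeat_interleave(2)                 # tensor([0, 0, 1, 1, 2, 2])
attn_mask = block_index.unsqueeze(0) <= block_index.unsqueeze(1)   # (6, 6) bool, True = attend

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 2, 6, 8])
```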


-class WanImageEmbedding(torch.nn.Module):
+class SkyReelsV2ImageEmbedding(torch.nn.Module):
     def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
         super().__init__()
 
         self.norm1 = FP32LayerNorm(in_features)
         self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
         self.norm2 = FP32LayerNorm(out_features)
         if pos_embed_seq_len is not None:
             self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
         else:
             self.pos_embed = None
 
     def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
         if self.pos_embed is not None:
             batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
             encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
             encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
 
         hidden_states = self.norm1(encoder_hidden_states_image)
         hidden_states = self.ff(hidden_states)
         hidden_states = self.norm2(hidden_states)
         return hidden_states
 
 
-class WanTimeTextImageEmbedding(nn.Module):
+class SkyReelsV2TimeTextImageEmbedding(nn.Module):
     def __init__(
         self,
         dim: int,
         time_freq_dim: int,
         time_proj_dim: int,
         text_embed_dim: int,
         image_embed_dim: Optional[int] = None,
         pos_embed_seq_len: Optional[int] = None,
     ):
         super().__init__()
 
         self.timesteps_proj = Timesteps(num_channels=time_freq_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.time_embedder = TimestepEmbedding(in_channels=time_freq_dim, time_embed_dim=dim)
         self.act_fn = nn.SiLU()
         self.time_proj = nn.Linear(dim, time_proj_dim)
         self.text_embedder = PixArtAlphaTextProjection(text_embed_dim, dim, act_fn="gelu_tanh")
 
         self.image_embedder = None
         if image_embed_dim is not None:
-            self.image_embedder = WanImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
+            self.image_embedder = SkyReelsV2ImageEmbedding(image_embed_dim, dim, pos_embed_seq_len=pos_embed_seq_len)
 
     def forward(
         self,
         timestep: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         encoder_hidden_states_image: Optional[torch.Tensor] = None,
     ):
         timestep = self.timesteps_proj(timestep)
 
         time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
         if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
             timestep = timestep.to(time_embedder_dtype)
         temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
         timestep_proj = self.time_proj(self.act_fn(temb))
 
         encoder_hidden_states = self.text_embedder(encoder_hidden_states)
         if encoder_hidden_states_image is not None:
             encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
 
         return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
 
 
-class WanRotaryPosEmbed(nn.Module):
+class SkyReelsV2RotaryPosEmbed(nn.Module):
     def __init__(
         self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0
     ):
         super().__init__()
 
         self.attention_head_dim = attention_head_dim
         self.patch_size = patch_size
         self.max_seq_len = max_seq_len
 
         h_dim = w_dim = 2 * (attention_head_dim // 6)
         t_dim = attention_head_dim - h_dim - w_dim
 
         freqs = []
         for dim in [t_dim, h_dim, w_dim]:
             freq = get_1d_rotary_pos_embed(
                 dim, max_seq_len, theta, use_real=False, repeat_interleave_real=False, freqs_dtype=torch.float64
             )
             freqs.append(freq)
         self.freqs = torch.cat(freqs, dim=1)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
         freqs = self.freqs.to(hidden_states.device)
         freqs = freqs.split_with_sizes(
             [
                 self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
                 self.attention_head_dim // 6,
                 self.attention_head_dim // 6,
             ],
             dim=1,
         )
 
         freqs_f = freqs[0][:ppf].view(ppf, 1, 1, -1).expand(ppf, pph, ppw, -1)
         freqs_h = freqs[1][:pph].view(1, pph, 1, -1).expand(ppf, pph, ppw, -1)
         freqs_w = freqs[2][:ppw].view(1, 1, ppw, -1).expand(ppf, pph, ppw, -1)
         freqs = torch.cat([freqs_f, freqs_h, freqs_w], dim=-1).reshape(1, 1, ppf * pph * ppw, -1)
         return freqs
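The rotary embedding class above is unchanged apart from the rename. As a quick sanity check on the split it relies on, the arithmetic for the default `attention_head_dim=128` works out to 44/42/42 real dimensions for time/height/width, i.e. 22/21/21 complex frequencies; a small illustrative check, not part of the diff:

```python
attention_head_dim = 128

# Real-valued dims per axis: 44 for time, 42 each for height and width.
h_dim = w_dim = 2 * (attention_head_dim // 6)   # 42
t_dim = attention_head_dim - h_dim - w_dim      # 44

# Complex frequencies per axis, matching the split_with_sizes call above: [22, 21, 21].
split_sizes = [
    attention_head_dim // 2 - 2 * (attention_head_dim // 6),
    attention_head_dim // 6,
    attention_head_dim // 6,
]
assert (t_dim, h_dim, w_dim) == (44, 42, 42)
assert split_sizes == [22, 21, 21] and sum(split_sizes) == attention_head_dim // 2
```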




-class WanTransformerBlock(nn.Module):
+class SkyReelsV2TransformerBlock(nn.Module):
     def __init__(
         self,
         dim: int,
         ffn_dim: int,
         num_heads: int,
-        qk_norm: str = "rms_norm_across_heads",
+        qk_norm: str = "rms_norm",
         cross_attn_norm: bool = False,
         eps: float = 1e-6,
         added_kv_proj_dim: Optional[int] = None,
     ):
         super().__init__()
 
         # 1. Self-attention
         self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
         self.attn1 = Attention(
             query_dim=dim,
             heads=num_heads,
             kv_heads=num_heads,
             dim_head=dim // num_heads,
             qk_norm=qk_norm,
             eps=eps,
             bias=True,
             cross_attention_dim=None,
             out_bias=True,
-            processor=WanAttnProcessor2_0(),
+            processor=SkyReelsV2AttnProcessor2_0(),
         )
 
         # 2. Cross-attention
         self.attn2 = Attention(
             query_dim=dim,
             heads=num_heads,
             kv_heads=num_heads,
             dim_head=dim // num_heads,
             qk_norm=qk_norm,
             eps=eps,
             bias=True,
             cross_attention_dim=None,
             out_bias=True,
             added_kv_proj_dim=added_kv_proj_dim,
             added_proj_bias=True,
-            processor=WanAttnProcessor2_0(),
+            processor=SkyReelsV2AttnProcessor2_0(),
         )
         self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
 
         # 3. Feed-forward
         self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
         self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
 
         self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         temb: torch.Tensor,
         rotary_emb: torch.Tensor,
+        attention_mask: torch.Tensor,
     ) -> torch.Tensor:
-        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
-            self.scale_shift_table + temb.float()
-        ).chunk(6, dim=1)
+        if temb.dim() == 3:
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+                self.scale_shift_table + temb.float()
+            ).chunk(6, dim=1)
+        elif temb.dim() == 4:
+            e = (self.scale_shift_table.unsqueeze(2) + temb.float()).chunk(6, dim=1)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in e]
 
         # 1. Self-attention
         norm_hidden_states = (self.norm1(hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(hidden_states)
-        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+        attn_output = self.attn1(
+            hidden_states=norm_hidden_states, rotary_emb=rotary_emb, attention_mask=attention_mask
+        )
         hidden_states = (hidden_states.float() + attn_output * gate_msa).type_as(hidden_states)
 
         # 2. Cross-attention
         norm_hidden_states = self.norm2(hidden_states.float()).type_as(hidden_states)
         attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
         hidden_states = hidden_states + attn_output
 
         # 3. Feed-forward
         norm_hidden_states = (self.norm3(hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
             hidden_states
         )
         ff_output = self.ffn(norm_hidden_states)
         hidden_states = (hidden_states.float() + ff_output.float() * c_gate_msa).type_as(hidden_states)
 
-        return hidden_states
+        return hidden_states  # TODO: check .to(torch.bfloat16)
+
+    def set_ar_attention(self):
+        self.attn1.processor.set_ar_attention()
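The `temb.dim() == 4` branch added to the block above handles per-token modulation, where the conditioning carries a separate set of six modulation vectors for every token (e.g. per-frame timesteps in the diffusion-forcing path further down) instead of one set per sample. A shape-only sketch of how the two branches broadcast, with illustrative sizes:

```python
import torch

dim, seq_len = 64, 12
scale_shift_table = torch.randn(1, 6, dim)

# temb.dim() == 3: one set of six modulation vectors per sample.
temb3 = torch.randn(2, 6, dim)
shift_msa, *_ = (scale_shift_table + temb3).chunk(6, dim=1)
print(shift_msa.shape)  # torch.Size([2, 1, 64]) -> broadcasts over the sequence dimension

# temb.dim() == 4: six modulation vectors for every token.
temb4 = torch.randn(2, 6, seq_len, dim)
e = (scale_shift_table.unsqueeze(2) + temb4).chunk(6, dim=1)
shift_msa = e[0].squeeze(1)
print(shift_msa.shape)  # torch.Size([2, 12, 64]) -> one modulation vector per token
```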




-class WanTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+class SkyReelsV2Transformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
     r"""
-    A Transformer model for video-like data used in the Wan model.
+    A Transformer model for video-like data used in the Wan-based SkyReels-V2 model.
 
     Args:
         patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
             3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
-        num_attention_heads (`int`, defaults to `40`):
+        num_attention_heads (`int`, defaults to `16`):
             Fixed length for text embeddings.
         attention_head_dim (`int`, defaults to `128`):
             The number of channels in each head.
         in_channels (`int`, defaults to `16`):
             The number of channels in the input.
         out_channels (`int`, defaults to `16`):
             The number of channels in the output.
-        text_dim (`int`, defaults to `512`):
+        text_dim (`int`, defaults to `4096`):
             Input dimension for text embeddings.
         freq_dim (`int`, defaults to `256`):
             Dimension for sinusoidal time embeddings.
-        ffn_dim (`int`, defaults to `13824`):
+        ffn_dim (`int`, defaults to `8192`):
             Intermediate dimension in feed-forward network.
-        num_layers (`int`, defaults to `40`):
+        num_layers (`int`, defaults to `32`):
             The number of layers of transformer blocks to use.
         window_size (`Tuple[int]`, defaults to `(-1, -1)`):
             Window size for local attention (-1 indicates global attention).
         cross_attn_norm (`bool`, defaults to `True`):
             Enable cross-attention normalization.
-        qk_norm (`bool`, defaults to `True`):
+        qk_norm (`str`, *optional*, defaults to `"rms_norm"`):
             Enable query/key normalization.
         eps (`float`, defaults to `1e-6`):
             Epsilon value for normalization layers.
-        add_img_emb (`bool`, defaults to `False`):
-            Whether to use img_emb.
-        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
-            The number of channels to use for the added key and value projections. If `None`, no projection is used.
+        inject_sample_info (`bool`, defaults to `False`):
+            Whether to inject sample information into the model.
+        image_dim (`int`, *optional*):
+            The dimension of the image embeddings.
+        added_kv_proj_dim (`int`, *optional*):
+            The dimension of the added key/value projection.
+        rope_max_seq_len (`int`, defaults to `1024`):
+            The maximum sequence length for the rotary embeddings.
+        pos_embed_seq_len (`int`, *optional*):
+            The sequence length for the positional embeddings.
     """
 
     _supports_gradient_checkpointing = True
     _skip_layerwise_casting_patterns = ["patch_embedding", "condition_embedder", "norm"]
     _no_split_modules = ["WanTransformerBlock"]
     _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
     _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
 
     @register_to_config
     def __init__(
         self,
         patch_size: Tuple[int] = (1, 2, 2),
-        num_attention_heads: int = 40,
+        num_attention_heads: int = 16,
         attention_head_dim: int = 128,
         in_channels: int = 16,
         out_channels: int = 16,
         text_dim: int = 4096,
         freq_dim: int = 256,
-        ffn_dim: int = 13824,
+        ffn_dim: int = 8192,
-        num_layers: int = 40,
+        num_layers: int = 32,
         cross_attn_norm: bool = True,
-        qk_norm: Optional[str] = "rms_norm_across_heads",
+        qk_norm: Optional[str] = "rms_norm",
         eps: float = 1e-6,
         image_dim: Optional[int] = None,
         added_kv_proj_dim: Optional[int] = None,
         rope_max_seq_len: int = 1024,
         pos_embed_seq_len: Optional[int] = None,
+        inject_sample_info: bool = False,
     ) -> None:
         super().__init__()
 
         inner_dim = num_attention_heads * attention_head_dim
         out_channels = out_channels or in_channels
 
+        self.num_frame_per_block = 1
+        self.flag_causal_attention = False
+        self.enable_teacache = False
+
         # 1. Patch & position embedding
-        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+        self.rope = SkyReelsV2RotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
         self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
 
         # 2. Condition embeddings
         # image_embedding_dim=1280 for I2V model
-        self.condition_embedder = WanTimeTextImageEmbedding(
+        self.condition_embedder = SkyReelsV2TimeTextImageEmbedding(
             dim=inner_dim,
             time_freq_dim=freq_dim,
             time_proj_dim=inner_dim * 6,
             text_embed_dim=text_dim,
             image_embed_dim=image_dim,
             pos_embed_seq_len=pos_embed_seq_len,
         )
 
         # 3. Transformer blocks
         self.blocks = nn.ModuleList(
             [
-                WanTransformerBlock(
-                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
+                SkyReelsV2TransformerBlock(
+                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim=inner_dim
                 )
                 for _ in range(num_layers)
             ]
         )
 
         # 4. Output norm & projection
         self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
         self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
         self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
 
         self.gradient_checkpointing = False
 
+        if inject_sample_info:
+            self.fps_embedding = nn.Embedding(2, inner_dim)
+            self.fps_projection = nn.Sequential(
+                nn.Linear(inner_dim, inner_dim), nn.SiLU(), nn.Linear(inner_dim, inner_dim * 6)
+            )
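For orientation, a minimal construction sketch using the `+`-side defaults registered above (16 heads of 128 dims, FFN 8192). The top-level `diffusers` import path is an assumption, and `num_layers` is shrunk purely to keep the example light:

```python
import torch

# Assumed import path once the class is exported like the other Wan-family models.
from diffusers import SkyReelsV2Transformer3DModel

model = SkyReelsV2Transformer3DModel(
    num_attention_heads=16,
    attention_head_dim=128,
    in_channels=16,
    out_channels=16,
    text_dim=4096,
    freq_dim=256,
    ffn_dim=8192,
    num_layers=2,  # the registered default is 32; reduced here for a quick smoke test
    inject_sample_info=True,
)

# inner_dim = 16 * 128 = 2048; with inject_sample_info=True the fps embedding/projection exist.
assert model.patch_embedding.out_channels == 2048
assert isinstance(model.fps_embedding, torch.nn.Embedding)
print(sum(p.numel() for p in model.parameters()) / 1e6, "M parameters")
```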

     def forward(
         self,
         hidden_states: torch.Tensor,
         timestep: torch.LongTensor,
         encoder_hidden_states: torch.Tensor,
         encoder_hidden_states_image: Optional[torch.Tensor] = None,
+        fps: Optional[torch.Tensor] = None,
         return_dict: bool = True,
         attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
         else:
             lora_scale = 1.0
 
         if USE_PEFT_BACKEND:
             # weight the lora layers by setting `lora_scale` for each PEFT layer
             scale_lora_layers(self, lora_scale)
         else:
             if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
                 logger.warning(
                     "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                 )
 
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.config.patch_size
         post_patch_num_frames = num_frames // p_t
         post_patch_height = height // p_h
         post_patch_width = width // p_w
 
         rotary_emb = self.rope(hidden_states)
 
         hidden_states = self.patch_embedding(hidden_states)
+        grid_sizes = torch.tensor(hidden_states.shape[2:], dtype=torch.long)
+
+        if self.flag_causal_attention:
+            frame_num, height, width = grid_sizes
+            block_num = frame_num // self.num_frame_per_block
+            range_tensor = torch.arange(block_num, device=hidden_states.device).view(-1, 1)
+            range_tensor = range_tensor.repeat(1, self.num_frame_per_block).flatten()
+            causal_mask = range_tensor.unsqueeze(0) <= range_tensor.unsqueeze(1)  # f, f
+            causal_mask = causal_mask.view(frame_num, 1, 1, frame_num, 1, 1)
+            causal_mask = causal_mask.repeat(1, height, width, 1, height, width)
+            causal_mask = causal_mask.reshape(frame_num * height * width, frame_num * height * width)
+            causal_mask = causal_mask.unsqueeze(0).unsqueeze(0)
+
         hidden_states = hidden_states.flatten(2).transpose(1, 2)
 
+        # TODO: check here
+        if timestep.dim() == 2:
+            b, f = timestep.shape
+            _flag_df = True
+        else:
+            _flag_df = False
+
         temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
             timestep, encoder_hidden_states, encoder_hidden_states_image
         )
         timestep_proj = timestep_proj.unflatten(1, (6, -1))
 
         if encoder_hidden_states_image is not None:
             encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
 
         # 4. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
             for block in self.blocks:
                 hidden_states = self._gradient_checkpointing_func(
-                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask
                 )
-        else:
-            for block in self.blocks:
-                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
-
-        # 5. Output norm, projection & unpatchify
-        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
-
-        # Move the shift and scale tensors to the same device as hidden_states.
-        # When using multi-GPU inference via accelerate these will be on the
-        # first device rather than the last device, which hidden_states ends up
-        # on.
-        shift = shift.to(hidden_states.device)
-        scale = scale.to(hidden_states.device)
-
-        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
-        hidden_states = self.proj_out(hidden_states)
-
-        hidden_states = hidden_states.reshape(
-            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
-        )
-        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
-        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
-
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
-        if not return_dict:
-            return (output,)
-
-        return Transformer2DModelOutput(sample=output)
+
+        if self.inject_sample_info:
+            fps = torch.tensor(fps, dtype=torch.long, device=hidden_states.device)
+
+            fps_emb = self.fps_embedding(fps).float()
+            if _flag_df:
+                timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, self.dim)).repeat(
+                    timestep.shape[1], 1, 1
+                )
+            else:
+                timestep_proj = timestep_proj + self.fps_projection(fps_emb).unflatten(1, (6, self.dim))
+
+        if _flag_df:
+            temb = temb.view(b, f, 1, 1, self.dim)
+            timestep_proj = timestep_proj.view(b, f, 1, 1, 6, self.dim)
+            temb = temb.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1).flatten(1, 3)
+            timestep_proj = timestep_proj.repeat(1, 1, grid_sizes[1], grid_sizes[2], 1, 1).flatten(1, 3)
+            timestep_proj = timestep_proj.transpose(1, 2).contiguous()
+
+        if self.enable_teacache:
+            modulated_inp = timestep_proj if self.use_ref_steps else temb
+            # teacache
+            if self.cnt % 2 == 0:  # even -> condition
+                self.is_even = True
+                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+                    should_calc_even = True
+                    self.accumulated_rel_l1_distance_even = 0
+                else:
+                    rescale_func = np.poly1d(self.coefficients)
+                    self.accumulated_rel_l1_distance_even += rescale_func(
+                        ((modulated_inp - self.previous_e0_even).abs().mean() / self.previous_e0_even.abs().mean())
+                        .cpu()
+                        .item()
+                    )
+                    if self.accumulated_rel_l1_distance_even < self.teacache_thresh:
+                        should_calc_even = False
+                    else:
+                        should_calc_even = True
+                        self.accumulated_rel_l1_distance_even = 0
+                self.previous_e0_even = modulated_inp.clone()
+
+            else:  # odd -> unconditon
+                self.is_even = False
+                if self.cnt < self.ret_steps or self.cnt >= self.cutoff_steps:
+                    should_calc_odd = True
+                    self.accumulated_rel_l1_distance_odd = 0
+                else:
+                    rescale_func = np.poly1d(self.coefficients)
+                    self.accumulated_rel_l1_distance_odd += rescale_func(
+                        ((modulated_inp - self.previous_e0_odd).abs().mean() / self.previous_e0_odd.abs().mean())
+                        .cpu()
+                        .item()
+                    )
+                    if self.accumulated_rel_l1_distance_odd < self.teacache_thresh:
+                        should_calc_odd = False
+                    else:
+                        should_calc_odd = True
+                        self.accumulated_rel_l1_distance_odd = 0
+                self.previous_e0_odd = modulated_inp.clone()
+
+        if self.enable_teacache:
+            if self.is_even:
+                if not should_calc_even:
+                    hidden_states += self.previous_residual_even
+                else:
+                    ori_hidden_states = hidden_states.clone()
+                    for block in self.blocks:
+                        hidden_states = block(
+                            hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, causal_mask