-
Notifications
You must be signed in to change notification settings - Fork 108
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: Add support for Tensor Parallelism to the Step-Video-T2V model (#…
…454) * feat: Add support for Tensor Parallelism to the Step-Video-T2V model * feat: adjust directory structure * feat: adjust import directory --------- Co-authored-by: fengliaoyuan <fengliaoyuan@bytedance.com>
- Loading branch information
Showing
7 changed files
with
988 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
62 changes: 62 additions & 0 deletions
62
xfuser/model_executor/models/customized/step_video_t2v/attentions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
import torch | ||
import torch.nn as nn | ||
from einops import rearrange | ||
|
||
try: | ||
from xfuser.core.long_ctx_attention import xFuserLongContextAttention | ||
except ImportError: | ||
xFuserLongContextAttention = None | ||
|
||
|
||
class Attention(nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
|
||
def attn_processor(self, attn_type): | ||
if attn_type == 'torch': | ||
return self.torch_attn_func | ||
elif attn_type == 'parallel': | ||
return self.parallel_attn_func | ||
else: | ||
raise Exception('Not supported attention type...') | ||
|
||
def torch_attn_func( | ||
self, | ||
q, | ||
k, | ||
v, | ||
attn_mask=None, | ||
causal=False, | ||
drop_rate=0.0, | ||
**kwargs | ||
): | ||
|
||
if attn_mask is not None and attn_mask.dtype != torch.bool: | ||
attn_mask = attn_mask.to(q.dtype) | ||
|
||
if attn_mask is not None and attn_mask.ndim == 3: | ||
n_heads = q.shape[2] | ||
attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1) | ||
|
||
q, k, v = map(lambda x: rearrange(x, 'b s h d -> b h s d'), (q, k, v)) | ||
x = torch.nn.functional.scaled_dot_product_attention( | ||
q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal | ||
) | ||
x = rearrange(x, 'b h s d -> b s h d') | ||
return x | ||
|
||
def parallel_attn_func( | ||
self, | ||
q, | ||
k, | ||
v, | ||
causal=False, | ||
**kwargs | ||
): | ||
assert xFuserLongContextAttention is not None; | ||
'to use sequence parallel attention, xFuserLongContextAttention should be imported...' | ||
hybrid_seq_parallel_attn = xFuserLongContextAttention() | ||
x = hybrid_seq_parallel_attn( | ||
None, q, k, v, causal=causal | ||
) | ||
return x |
321 changes: 321 additions & 0 deletions
321
xfuser/model_executor/models/customized/step_video_t2v/blocks.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,321 @@ | ||
# Copyright 2025 StepFun Inc. All Rights Reserved. | ||
# | ||
# Permission is hereby granted, free of charge, to any person obtaining a copy | ||
# of this software and associated documentation files (the "Software"), to deal | ||
# in the Software without restriction, including without limitation the rights | ||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
# copies of the Software, and to permit persons to whom the Software is | ||
# furnished to do so, subject to the following conditions: | ||
# | ||
# The above copyright notice and this permission notice shall be included in all | ||
# copies or substantial portions of the Software. | ||
# ============================================================================== | ||
import torch | ||
import torch.nn as nn | ||
from typing import Optional | ||
from einops import rearrange | ||
|
||
from xfuser.model_executor.models.customized.step_video_t2v.attentions import Attention | ||
from xfuser.model_executor.models.customized.step_video_t2v.normalization import RMSNorm | ||
from xfuser.model_executor.models.customized.step_video_t2v.rope import RoPE3D | ||
|
||
|
||
class SelfAttention(Attention): | ||
def __init__(self, hidden_dim, head_dim, bias=False, with_rope=True, with_qk_norm=True, attn_type='torch'): | ||
super().__init__() | ||
self.head_dim = head_dim | ||
self.n_heads = hidden_dim // head_dim | ||
self.n_heads_per_tp = self.n_heads | ||
|
||
self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=bias) | ||
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias) | ||
|
||
self.with_rope = with_rope | ||
self.with_qk_norm = with_qk_norm | ||
if self.with_qk_norm: | ||
self.q_norm = RMSNorm(head_dim, elementwise_affine=True) | ||
self.k_norm = RMSNorm(head_dim, elementwise_affine=True) | ||
|
||
if self.with_rope: | ||
self.rope_3d = RoPE3D(freq=1e4, F0=1.0, scaling_factor=1.0) | ||
self.rope_ch_split = [64, 32, 32] | ||
|
||
self.core_attention = self.attn_processor(attn_type=attn_type) | ||
self.parallel = attn_type == 'parallel' | ||
|
||
def apply_rope3d(self, x, fhw_positions, rope_ch_split, parallel=True): | ||
x = self.rope_3d(x, fhw_positions, rope_ch_split, parallel) | ||
return x | ||
|
||
def forward( | ||
self, | ||
x, | ||
cu_seqlens=None, | ||
max_seqlen=None, | ||
rope_positions=None, | ||
attn_mask=None | ||
): | ||
xqkv = self.wqkv(x) | ||
xqkv = xqkv.view( | ||
*x.shape[:-1], | ||
self.n_heads_per_tp, | ||
3 * self.head_dim | ||
) | ||
xq, xk, xv = torch.split(xqkv, [self.head_dim] * 3, dim=-1) ## seq_len, n, dim | ||
|
||
if self.with_qk_norm: | ||
xq = self.q_norm(xq) | ||
xk = self.k_norm(xk) | ||
|
||
if self.with_rope: | ||
xq = self.apply_rope3d(xq, rope_positions, self.rope_ch_split, parallel=self.parallel) | ||
xk = self.apply_rope3d(xk, rope_positions, self.rope_ch_split, parallel=self.parallel) | ||
|
||
output = self.core_attention( | ||
xq, | ||
xk, | ||
xv, | ||
cu_seqlens=cu_seqlens, | ||
max_seqlen=max_seqlen, | ||
attn_mask=attn_mask | ||
) | ||
output = rearrange(output, 'b s h d -> b s (h d)') | ||
output = self.wo(output) | ||
|
||
return output | ||
|
||
|
||
class CrossAttention(Attention): | ||
def __init__(self, hidden_dim, head_dim, bias=False, with_qk_norm=True, attn_type='torch'): | ||
super().__init__() | ||
self.head_dim = head_dim | ||
self.n_heads = hidden_dim // head_dim | ||
self.n_heads_per_tp = self.n_heads | ||
|
||
self.wq = nn.Linear(hidden_dim, hidden_dim, bias=bias) | ||
self.wkv = nn.Linear(hidden_dim, hidden_dim * 2, bias=bias) | ||
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=bias) | ||
|
||
self.with_qk_norm = with_qk_norm | ||
if self.with_qk_norm: | ||
self.q_norm = RMSNorm(head_dim, elementwise_affine=True) | ||
self.k_norm = RMSNorm(head_dim, elementwise_affine=True) | ||
|
||
self.core_attention = self.attn_processor(attn_type=attn_type) | ||
|
||
def forward( | ||
self, | ||
x: torch.Tensor, | ||
encoder_hidden_states: torch.Tensor, | ||
attn_mask=None | ||
): | ||
xq = self.wq(x) | ||
xq = xq.view(*xq.shape[:-1], self.n_heads_per_tp, self.head_dim) | ||
|
||
xkv = self.wkv(encoder_hidden_states) | ||
xkv = xkv.view(*xkv.shape[:-1], self.n_heads_per_tp, 2 * self.head_dim) | ||
|
||
xk, xv = torch.split(xkv, [self.head_dim] * 2, dim=-1) | ||
|
||
if self.with_qk_norm: | ||
xq = self.q_norm(xq) | ||
xk = self.k_norm(xk) | ||
|
||
output = self.core_attention( | ||
xq, | ||
xk, | ||
xv, | ||
attn_mask=attn_mask | ||
) | ||
|
||
output = rearrange(output, 'b s h d -> b s (h d)') | ||
output = self.wo(output) | ||
|
||
return output | ||
|
||
|
||
class GELU(nn.Module): | ||
r""" | ||
GELU activation function with tanh approximation support with `approximate="tanh"`. | ||
Parameters: | ||
dim_in (`int`): The number of channels in the input. | ||
dim_out (`int`): The number of channels in the output. | ||
approximate (`str`, *optional*, defaults to `"none"`): If `"tanh"`, use tanh approximation. | ||
bias (`bool`, defaults to True): Whether to use a bias in the linear layer. | ||
""" | ||
|
||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True): | ||
super().__init__() | ||
self.proj = nn.Linear(dim_in, dim_out, bias=bias) | ||
self.approximate = approximate | ||
|
||
def gelu(self, gate: torch.Tensor) -> torch.Tensor: | ||
return torch.nn.functional.gelu(gate, approximate=self.approximate) | ||
|
||
def forward(self, hidden_states): | ||
hidden_states = self.proj(hidden_states) | ||
hidden_states = self.gelu(hidden_states) | ||
return hidden_states | ||
|
||
|
||
class FeedForward(nn.Module): | ||
def __init__( | ||
self, | ||
dim: int, | ||
inner_dim: Optional[int] = None, | ||
dim_out: Optional[int] = None, | ||
mult: int = 4, | ||
bias: bool = False, | ||
): | ||
super().__init__() | ||
inner_dim = dim * mult if inner_dim is None else inner_dim | ||
dim_out = dim if dim_out is None else dim_out | ||
self.net = nn.ModuleList([ | ||
GELU(dim, inner_dim, approximate="tanh", bias=bias), | ||
nn.Identity(), | ||
nn.Linear(inner_dim, dim_out, bias=bias) | ||
]) | ||
|
||
def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor: | ||
for module in self.net: | ||
hidden_states = module(hidden_states) | ||
return hidden_states | ||
|
||
|
||
def modulate(x, scale, shift): | ||
x = x * (1 + scale) + shift | ||
return x | ||
|
||
|
||
def gate(x, gate): | ||
x = gate * x | ||
return x | ||
|
||
|
||
class StepVideoTransformerBlock(nn.Module): | ||
r""" | ||
A basic Transformer block. | ||
Parameters: | ||
dim (`int`): The number of channels in the input and output. | ||
num_attention_heads (`int`): The number of heads to use for multi-head attention. | ||
attention_head_dim (`int`): The number of channels in each head. | ||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. | ||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. | ||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. | ||
num_embeds_ada_norm (: | ||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. | ||
attention_bias (: | ||
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. | ||
only_cross_attention (`bool`, *optional*): | ||
Whether to use only cross-attention layers. In this case two cross attention layers are used. | ||
double_self_attention (`bool`, *optional*): | ||
Whether to use two self-attention layers. In this case no cross attention layers are used. | ||
upcast_attention (`bool`, *optional*): | ||
Whether to upcast the attention computation to float32. This is useful for mixed precision training. | ||
norm_elementwise_affine (`bool`, *optional*, defaults to `True`): | ||
Whether to use learnable elementwise affine parameters for normalization. | ||
norm_type (`str`, *optional*, defaults to `"layer_norm"`): | ||
The normalization layer to use. Can be `"layer_norm"`, `"ada_norm"` or `"ada_norm_zero"`. | ||
final_dropout (`bool` *optional*, defaults to False): | ||
Whether to apply a final dropout after the last feed-forward layer. | ||
attention_type (`str`, *optional*, defaults to `"default"`): | ||
The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. | ||
positional_embeddings (`str`, *optional*, defaults to `None`): | ||
The type of positional embeddings to apply to. | ||
num_positional_embeddings (`int`, *optional*, defaults to `None`): | ||
The maximum number of positional embeddings to apply. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
dim: int, | ||
attention_head_dim: int, | ||
norm_eps: float = 1e-5, | ||
ff_inner_dim: Optional[int] = None, | ||
ff_bias: bool = False, | ||
attention_type: str = 'parallel' | ||
): | ||
super().__init__() | ||
self.dim = dim | ||
self.norm1 = nn.LayerNorm(dim, eps=norm_eps) | ||
self.attn1 = SelfAttention(dim, attention_head_dim, bias=False, with_rope=True, with_qk_norm=True, | ||
attn_type=attention_type) | ||
|
||
self.norm2 = nn.LayerNorm(dim, eps=norm_eps) | ||
self.attn2 = CrossAttention(dim, attention_head_dim, bias=False, with_qk_norm=True, attn_type='torch') | ||
|
||
self.ff = FeedForward(dim=dim, inner_dim=ff_inner_dim, dim_out=dim, bias=ff_bias) | ||
|
||
self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim ** 0.5) | ||
|
||
@torch.no_grad() | ||
def forward( | ||
self, | ||
q: torch.Tensor, | ||
kv: Optional[torch.Tensor] = None, | ||
timestep: Optional[torch.LongTensor] = None, | ||
attn_mask=None, | ||
rope_positions: list = None, | ||
) -> torch.Tensor: | ||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( | ||
torch.clone(chunk) for chunk in | ||
(self.scale_shift_table[None] + timestep.reshape(-1, 6, self.dim)).chunk(6, dim=1) | ||
) | ||
|
||
scale_shift_q = modulate(self.norm1(q), scale_msa, shift_msa) | ||
|
||
attn_q = self.attn1( | ||
scale_shift_q, | ||
rope_positions=rope_positions | ||
) | ||
|
||
q = gate(attn_q, gate_msa) + q | ||
|
||
attn_q = self.attn2( | ||
q, | ||
kv, | ||
attn_mask | ||
) | ||
|
||
q = attn_q + q | ||
|
||
scale_shift_q = modulate(self.norm2(q), scale_mlp, shift_mlp) | ||
|
||
ff_output = self.ff(scale_shift_q) | ||
|
||
q = gate(ff_output, gate_mlp) + q | ||
|
||
return q | ||
|
||
|
||
class PatchEmbed(nn.Module): | ||
"""2D Image to Patch Embedding""" | ||
|
||
def __init__( | ||
self, | ||
patch_size=64, | ||
in_channels=3, | ||
embed_dim=768, | ||
layer_norm=False, | ||
flatten=True, | ||
bias=True, | ||
): | ||
super().__init__() | ||
|
||
self.flatten = flatten | ||
self.layer_norm = layer_norm | ||
|
||
self.proj = nn.Conv2d( | ||
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias | ||
) | ||
|
||
def forward(self, latent): | ||
latent = self.proj(latent).to(latent.dtype) | ||
if self.flatten: | ||
latent = latent.flatten(2).transpose(1, 2) | ||
if self.layer_norm: | ||
latent = self.norm(latent) | ||
|
||
return latent |
Oops, something went wrong.