Wan (万象) is an open-source family of video generation models. It introduces a new spatio-temporal variational autoencoder (VAE), scalable pre-training strategies, large-scale data curation, and automated evaluation metrics, which together improve performance and generality. The 14B model is trained on billions of images and videos and exhibits scaling laws in both data volume and model size. It covers a range of downstream tasks, including image-to-video and instruction-guided video editing, and can handle Chinese. The 1.3B model needs only 8.19 GB of VRAM. This post focuses solely on the DiT architecture and its forward pass.
Preparing the model inputs
Prompt encoding: the output defaults to shape [B, L_text, 4096] (umT5, vocabulary size 256,384).
python
text_emb = get_umt5_embedding(checkpoint_path=args.text_encoder_path,
                              prompts=args.prompt).to(dtype=torch.bfloat16).cuda()
First-frame image encoding:
python
frames_to_encode = torch.cat(  # wrap into a video tensor of length F
    [
        image_tensor.unsqueeze(2),  # [B, 3, H_img, W_img] -> [B, 3, 1, H_img, W_img]
        torch.zeros(1, 3, F - 1, h, w, device=image_tensor.device)
    ], dim=2
)  # -> [B, 3, F, H_img, W_img]
# VAE encode; in a typical setup the temporal compression is 4 (T=(F-1)/4+1), the spatial compression is 8 (H=H_img/8, W=W_img/8), and the channel count is C=16
encoded_latents = tokenizer.encode(frames_to_encode)  # -> [B, C, T, H, W]
msk = torch.zeros(1, 4, lat_t, lat_h, lat_w, device=tensor_kwargs["device"], dtype=tensor_kwargs["dtype"])  # dim 0 is hard-coded to 1, so the batch size of the input latents is forced to 1 (looks like a small bug)
msk[:, :, 0, :, :] = 1.0  # mark the first latent frame as the conditioning frame
y = torch.cat([msk, encoded_latents.to(**tensor_kwargs)], dim=1)
# concatenate the mask and the latents along the channel dim to form y -> [1, 4+C, T, H, W]
The outputs of the two components above serve as the conditioning and are fed into the model together with a latent of shape [B, 16, T, H, W]:
python
condition = {"crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples)}  # the t2v conditioning, for comparison
condition = {"crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples), "y_B_C_T_H_W": y}
# every model call takes the latent, the timestep, and the conditioning
v_pred = net(x_B_C_T_H_W=x.to(**tensor_kwargs),
             timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs), **condition
             ).to(torch.float64)
Inside WanVideo
Input patchify: the input has shape [B, C, T, H, W], where the latent channel count defaults to C = 16 and is determined by the VAE encoder. It is concatenated along the channel dim with y of shape [B, 20, T, H, W], giving an x of shape [B, 36, T, H, W]. With the default patch size patch_size=(1, 2, 2), x is rearranged to [B, T·(H/2)·(W/2), 36·1·2·2] = [B, L, 144], and a linear patch-embedding layer maps [B, L, 144] to [B, L, 5120].
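A minimal sketch of this patchify step (the toy spatial sizes here are illustrative; only C=36, patch_size and dim come from the model config further below):
python
import torch
from einops import rearrange

B, C, T, H, W = 1, 36, 5, 16, 16    # illustrative sizes after concatenating y (C = 36 = in_dim)
kt, kh, kw = 1, 2, 2                # patch_size
dim = 5120

x = torch.randn(B, C, T, H, W)
x = rearrange(x, "b c (t kt) (h kh) (w kw) -> b (t h w) (c kt kh kw)", kt=kt, kh=kh, kw=kw)
patch_embedding = torch.nn.Linear(C * kt * kh * kw, dim)  # 144 -> 5120
tokens = patch_embedding(x)         # [B, L, 5120], L = (T/1) * (H/2) * (W/2)
print(tokens.shape)                 # torch.Size([1, 320, 5120])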
Timestep encoding: the timestep t of shape [B] is turned into a 256-dimensional vector by a 1D sinusoidal encoding. Concretely, frequencies are built over half the dimension, freqs_i = 10000^(-i/128) for i = 0, ..., 127; the outer product with the timestep sequence gives a 2D tensor fs of shape [B, 128] with fs[b, i] = t_b · freqs_i; concatenating the cosine and sine values ([cos(fs), sin(fs)]) yields the [B, 256] embedding. A time-embedding layer (two linear layers around a SiLU) then maps [B, 256] to [B, 5120], and a time-projection layer (a SiLU followed by a linear layer) maps that to 6d, reshaped to [B, 6, 5120].
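A minimal sketch of this 1D sinusoidal encoding (the name sinusoidal_embedding_1d appears later in the forward pass; this body is a reconstruction, and the base 10000 is an assumption):
python
import torch

def sinusoidal_embedding_1d(dim, position):
    # position: [B] -> returns [B, dim]; frequencies 10000^(-i/half), i = 0..half-1
    half = dim // 2
    freqs = torch.pow(10000.0, -torch.arange(half).float() / half)
    fs = torch.outer(position.float(), freqs)                   # [B, half]
    return torch.cat([torch.cos(fs), torch.sin(fs)], dim=1)     # [B, dim]

emb = sinusoidal_embedding_1d(256, torch.tensor([1000.0]))      # [1, 256]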
Prompt embedding: a text-embedding layer (two linear layers around a GELU) maps [B, L_text, 4096] to [B, L_text, 5120].
Positional encoding: how does the model get positional information? It is not added to the input directly; instead, 3D rotary position embedding (RoPE) applies position-dependent rotations to Q/K. Concretely, the hidden dimension of each head (d_head = 128) is split into three parts d_h, d_w, d_t that encode height, width and frame index respectively, each kept a multiple of 2. For each part a frequency table is built, and the outer product of positions with frequencies yields 2D tensors freqs_t, freqs_h, freqs_w of shapes [T, d_t/2], [H, d_h/2] and [W, d_w/2]. These are broadcast to four dimensions and concatenated back along the last dim into [T, H, W, d_head/2], so that each (t, h, w) gets its own set of RoPE phases:
python
freqs_T_H_W_D = torch.cat(
    [
        repeat(freqs_t, "t d -> t h w d", h=H, w=W),
        repeat(freqs_h, "h d -> t h w d", t=T, w=W),
        repeat(freqs_w, "w d -> t h w d", t=T, h=H),
    ],
    dim=-1,
)
This 4D frequency tensor is then flattened into a 2D phase tensor of shape [T·H·W, d_head/2].
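A minimal sketch of how the per-axis frequency tables could be built; the rotation base 10000 and the exact split of d_head into (d_t, d_h, d_w) are assumptions, only the resulting shapes are stated in the post:
python
import torch

def rope_1d_params(max_pos, dim, theta=10000.0):
    # returns [max_pos, dim // 2] phase angles: outer product of positions and inverse frequencies
    assert dim % 2 == 0
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    return torch.outer(torch.arange(max_pos).float(), inv_freq)

d_head = 128
d_h = d_w = 2 * (d_head // 6)          # assumed split: height/width parts
d_t = d_head - d_h - d_w               # the remainder goes to the temporal part
freqs_t = rope_1d_params(32, d_t)      # [T_max, d_t // 2]
freqs_h = rope_1d_params(128, d_h)     # [H_max, d_h // 2]
freqs_w = rope_1d_params(128, d_w)     # [W_max, d_w // 2]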
Output head: one final round of modulation, then a linear layer maps [B, L, 5120] to [B, L, prod(patch_size)·out_dim]:
python
class Head(nn.Module):
    def __init__(self, dim, out_dim, patch_size, eps=1e-6):
        ...
        out_dim = math.prod(patch_size) * out_dim
        self.norm = WanLayerNorm(dim, eps)
        self.head = nn.Linear(dim, out_dim)
        # In DiT the modulation layer takes the sum of the timestep and label embeddings;
        # there is no label here, so an extra learnable modulation parameter is used instead
        self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)

    def init_weights(self):
        self.norm.reset_parameters()
        std = 1.0 / math.sqrt(self.dim)
        torch.nn.init.trunc_normal_(self.modulation, std=std)
        torch.nn.init.trunc_normal_(self.head.weight, std=std)
        self.head.bias.data.zero_()

    def forward(self, x, e):
        assert e.dtype == torch.float32
        with amp.autocast("cuda", dtype=torch.float32):
            e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)  # hmm, what gets passed in here is not the timestep embedding?
            x = self.head(self.norm(x) * (1 + e[1]) + e[0])
        return x
How a DiT block processes its inputs
How does each DiT block process these inputs: the latent tokens of shape [L, 5120], the timestep modulation of shape [6, 5120], the prompt embedding of shape [L_text, 5120], and the positional-encoding frequencies of shape [L, 64]? (The batch dimension is ignored here.)
First, the definitions of a couple of normalization layers:
python
class WanLayerNorm(nn.LayerNorm):
    # elementwise_affine enables learnable scale/shift parameters
    def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)

    def forward(self, x):
        with amp.autocast("cuda", dtype=torch.float32):
            return super().forward(x.float()).type_as(x)

class WanRMSNorm(nn.Module):
    ...
    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self._norm(x.float()).type_as(x) * self.weight
A unified entry point for attention: pick the fastest implementation available for the current GPU and software environment; the focus here is dispatch and backend selection rather than the attention math itself. Hopper (H100, H20, H200, ...): FlashAttention-3 → cuDNN attention; Ampere (A100, A40 at SM80, Ampere RTX): prefer FlashAttention-2; Ada (e.g. RTX 4090) and Blackwell RTX: flash attention → cuDNN → xFormers; Blackwell data-center cards (B200, GB200 at SM100) prefer cuDNN. Here SM (compute capability × 10) acts as a hardware feature tag.
python
# GPU and the corresponding compute capability:
#
# | GPU / category | Arch |
# |================|=======|
# | A100 | SM80 |
# | A40 | SM80 |
# | Ampere RTX | SM86 |
# |----------------|-------|
# | Ada Lovelace | SM89 |
# |----------------|-------|
# | H20 | SM90 |
# | H100 | SM90 |
# | H200 | SM90 |
# |----------------|-------|
# | B200 | SM100 |
# | Blackwell RTX | SM103 |
# |----------------|-------|
#
try:  # try to import FlashAttention-3
    from flash_attn_3.flash_attn_interface import flash_attn_func
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

def get_device_cc(device) -> int:  # get the compute capability (SM)
    if torch.cuda.is_available() and torch.version.cuda and device.type == "cuda":
        major, minor = torch.cuda.get_device_capability(device)
        return major * 10 + minor
    return 0

def attention(q, k, v, dropout_p=0.0, softmax_scale=None, q_scale=None, causal=False, deterministic=False):
    assert q.dtype == k.dtype == v.dtype
    dtype = q.dtype
    supported_dtypes = [torch.bfloat16, torch.float16, torch.float32]
    is_half = dtype in [torch.bfloat16, torch.float16]
    compute_cap = get_device_cc(q.device)
    if dtype not in supported_dtypes:
        raise NotImplementedError(f"{dtype=} is not supported.")
    if q_scale is not None:
        q = q * q_scale
    if compute_cap == 90 and FLASH_ATTN_3_AVAILABLE and is_half:  # prefer FlashAttention-3
        return flash_attn_func(q=q, k=k, v=v, softmax_scale=softmax_scale, causal=causal, deterministic=deterministic)[0]
    else:
        if compute_cap in [90, 100] and is_half:
            SDPA_BACKENDS = [
                SDPBackend.CUDNN_ATTENTION,  # prefer cuDNN
                SDPBackend.FLASH_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION,
            ]
            BEST_SDPA_BACKEND = SDPBackend.CUDNN_ATTENTION
        elif is_half:
            SDPA_BACKENDS = [
                SDPBackend.FLASH_ATTENTION,  # prefer flash attention
                SDPBackend.CUDNN_ATTENTION,
                SDPBackend.EFFICIENT_ATTENTION,
            ]
            BEST_SDPA_BACKEND = SDPBackend.FLASH_ATTENTION if compute_cap >= 80 else SDPBackend.EFFICIENT_ATTENTION
        else:
            assert dtype == torch.float32, f"Unrecognized {dtype=}."
            SDPA_BACKENDS = [SDPBackend.EFFICIENT_ATTENTION]
            BEST_SDPA_BACKEND = SDPBackend.EFFICIENT_ATTENTION
        if deterministic:
            raise NotImplementedError("Deterministic mode in attention is only supported when Flash Attention 3 is available.")
        try:
            sdpa_kernel(backends=SDPA_BACKENDS, set_priority_order=True)
            sdpa_kernel_ = partial(sdpa_kernel, set_priority_order=True)
        except TypeError:
            sdpa_kernel_ = sdpa_kernel
            SDPA_BACKENDS = [BEST_SDPA_BACKEND]
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)
        with sdpa_kernel_(backends=SDPA_BACKENDS):
            out = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=causal, dropout_p=dropout_p, scale=softmax_scale)
        out = out.transpose(1, 2).contiguous()
        return out
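A hypothetical smoke test of this entry point, with q/k/v in the [B, S, num_heads, head_dim] layout the callers below use (requires a CUDA device; the sizes are illustrative):
python
import torch

q = torch.randn(1, 1024, 40, 128, dtype=torch.bfloat16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
out = attention(q, k, v)   # dispatched to FA3 / cuDNN / flash / efficient depending on the GPU
print(out.shape)           # torch.Size([1, 1024, 40, 128])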
Sequence-parallel attention: each GPU goes from "local sequence & all heads" → "full sequence & local heads" → back to "local sequence & all heads".
python
class MinimalA2AAttnOp(DistributedAttention):
    ...
    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
        results = super().forward(query, key, value, *args, **kwargs)
        return rearrange(results, "b ... h l -> b ... (h l)")  # also merge the head dim back

class DistributedAttention(torch.nn.Module):
    ...
    def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
        if self.pg is None:
            return self.local_attn(query, key, value, *args, **kwargs)
        pg_size = dist.get_world_size(self.pg)
        if pg_size < 2:
            return self.local_attn(query, key, value, *args, **kwargs)
        # context parallelism is enabled
        query_layer, key_layer, value_layer = _SeqAllToAllQKV.apply(self.pg, query, key, value, pg_size, self.stream, True)  # turn q/k/v from "local sequence, all heads" into "full sequence, local heads"
        context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)  # compute attention locally
        output = _SeqAllToAll.apply(self.pg, context_layer, False)  # turn the output from "full sequence, local heads" back into "local sequence, all heads"
        return output

    # switch to enable/disable context parallelism
    def set_context_parallel_group(self, group, stream):
        self.pg = group
        self.stream = stream
########## forward/backward of the QKV all-to-all split ##########
class _SeqAllToAllQKV(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup,
                q: Tensor, k: Tensor, v: Tensor,
                cp_size: int, cp_stream: torch.cuda.Stream,
                local_seq_2_local_head: bool,
                ) -> Tuple[Tensor, Tensor, Tensor]:
        ctx.group = group
        ctx.cp_size = cp_size
        ctx.cp_stream = cp_stream
        ctx.local_seq_2_local_head = local_seq_2_local_head
        q, k, v = async_a2a_communicate([q, k, v], cp_size, group, cp_stream, local_seq_2_local_head)
        return q, k, v

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, Tensor, Tensor, None, None, None]:
        q_grad, k_grad, v_grad = _SeqAllToAllQKV.apply(ctx.group, *grad_output, ctx.cp_size, ctx.cp_stream, not ctx.local_seq_2_local_head)
        return (None, q_grad, k_grad, v_grad, None, None, None)
# asynchronous, pipelined implementation of the all-to-all communication
def async_a2a_communicate(
    a2a_inputs: Union[torch.Tensor, List[torch.Tensor]],  # [q, k, v]
    cp_size: int,  # sequence-parallel size
    cp_group: ProcessGroup,  # distributed communication group
    cp_stream: torch.cuda.Stream,  # CUDA stream used for the communication wait and post-processing
    local_seq_2_local_head: bool,  # local sequence -> local heads
) -> Union[torch.Tensor, List[torch.Tensor]]:
    a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs
    a2a_outputs, a2a_reqs, a2a_post_fns = [None] * len(a2a_inputs), [None] * len(a2a_inputs), [None] * len(a2a_inputs)  # communication outputs, async request handles, post-communication reshape functions
    # 3-stage pipelined schedule
    if local_seq_2_local_head:
        for i in range(len(a2a_inputs) + 2):
            # 2. run the all-to-all, receiving the full sequence for a subset of heads
            if 0 < i < len(a2a_inputs) + 1:
                a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
                a2a_reqs[i - 1] = dist.all_to_all_single(a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True)
                a2a_post_fns[i - 1] = post_all2all(local_seq_2_local_head, cp_size)
            # 3. reassemble the sequence
            if i > 1:
                with torch.cuda.stream(cp_stream):
                    a2a_reqs[i - 2].wait()
                    a2a_outputs[i - 2] = a2a_post_fns[i - 2](a2a_outputs[i - 2])
            # 1. split the heads
            if i < len(a2a_inputs):
                a2a_inputs[i] = rearrange(a2a_inputs[i], "bs seq_len (w h) d -> w bs seq_len h d", w=cp_size).contiguous()
    else:
        for i in range(len(a2a_inputs) + 2):
            # 2. run the all-to-all, receiving all heads for a subset of the sequence
            if 0 < i < len(a2a_inputs) + 1:
                ...
            # 1. split the sequence
            if i < len(a2a_inputs):
                a2a_inputs[i] = rearrange(a2a_inputs[i], "bs (w s) h d -> w bs s h d", w=cp_size).contiguous()
            # 3. reassemble the heads
            if i > 1:
                with torch.cuda.stream(cp_stream):
                    a2a_reqs[i - 2].wait()
                    a2a_outputs[i - 2] = a2a_post_fns[i - 2](a2a_outputs[i - 2])
    torch.cuda.current_stream().wait_stream(cp_stream)
    return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs
########## forward/backward of the output all-to-all split ##########
class _SeqAllToAll(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, local_seq_2_local_head: bool) -> Tensor:
        ctx.group = group
        ctx.local_seq_2_local_head = local_seq_2_local_head
        res = single_all_to_all(input, local_seq_2_local_head, group, False)
        return res

    @staticmethod
    def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, None]:
        return (None, _SeqAllToAll.apply(ctx.group, *grad_output, not ctx.local_seq_2_local_head), None)

# a simple (single-shot) implementation of the all-to-all communication
def single_all_to_all(input, local_seq_2_local_head, group, async_op=False):
    seq_world_size = dist.get_world_size(group)
    post_all2all_fun = post_all2all(local_seq_2_local_head, seq_world_size)
    if local_seq_2_local_head:
        bs, local_seq_len, total_head_num, head_dim = input.shape
        assert total_head_num % seq_world_size == 0, f"Number of heads ({total_head_num}) must be divisible by the sequence parallel size ({seq_world_size})!"
        # 1. split the heads
        input = rearrange(input, "bs seq_len (w h) d -> w bs seq_len h d", w=seq_world_size, h=total_head_num // seq_world_size).contiguous()
    else:
        bs, global_seq_len, local_head_num, head_dim = input.shape
        assert global_seq_len % seq_world_size == 0, f"Length of sequence ({global_seq_len}) must be divisible by the sequence parallel size ({seq_world_size})!"
        # 1. split the sequence
        input = rearrange(input, "bs (w s) h d -> w bs s h d", w=seq_world_size, s=global_seq_len // seq_world_size).contiguous()
    output = torch.empty_like(input)
    # 2. run the all-to-all
    dist.all_to_all_single(output, input, group=group, async_op=async_op)
    # 3. reassemble
    res = post_all2all_fun(output)
    return res
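To make the layout changes concrete, here is a local, single-process sketch of the shape bookkeeping; the all-to-all itself is only described in comments, and all sizes are toy values:
python
import torch
from einops import rearrange

cp_size, bs, local_seq, heads, d = 2, 1, 4, 4, 16
q = torch.randn(bs, local_seq, heads, d)                  # per-rank input: local sequence, all heads

# 1. split the heads into cp_size groups -> [w, bs, local_seq, heads/cp, d]
q_split = rearrange(q, "bs s (w h) d -> w bs s h d", w=cp_size)
# 2. dist.all_to_all_single would exchange the leading w chunks across ranks; the shape stays the same
# 3. post_all2all then folds the w chunks back into the sequence dim: "local heads, full sequence"
q_out = rearrange(q_split, "w bs s h d -> bs (w s) h d")
print(q_out.shape)                                         # torch.Size([1, 8, 2, 16])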
Self-attention layer:
python
class WanSelfAttention(nn.Module):
    def __init__(self, dim, num_heads, qk_norm=True, eps=1e-6):
        assert dim % num_heads == 0
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // self.num_heads
        self.qk_norm = qk_norm
        self.eps = eps
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.o = nn.Linear(dim, dim)
        self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
        self.attn_op = MinimalA2AAttnOp()

    def init_weights(self):
        # Xavier scales the variance by the layer's fan-in/fan-out; here the weights are instead
        # sampled from a truncated normal with std = 1/sqrt(dim)
        std = 1.0 / math.sqrt(self.dim)
        torch.nn.init.trunc_normal_(self.q.weight, std=std)
        ...
        # zero out bias
        self.q.bias.data.zero_()
        ...
        # reset norm weights
        if self.qk_norm:
            self.norm_q.reset_parameters()
            self.norm_k.reset_parameters()

    def forward(self, x, freqs):
        b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim

        def qkv_fn(x):
            q = self.norm_q(self.q(x)).view(b, s, n, d)
            k = self.norm_k(self.k(x)).view(b, s, n, d)
            v = self.v(x).view(b, s, n, d)
            return q, k, v

        q, k, v = qkv_fn(x)
        q = rope_apply(q, freqs)
        k = rope_apply(k, freqs)
        x = self.attn_op(q, k, v)
        x = x.flatten(2)  # flatten dim 2 and everything after it into a single dim
        x = self.o(x)
        return x

    def set_context_parallel_group(self, process_group, ranks, stream):
        self.attn_op.set_context_parallel_group(process_group, ranks, stream)
How are the positional-encoding frequencies applied to q and k of shape [B, L, num_heads, head_dim]?
python
def rope_apply(x, freqs):
    b, s, n, d = x.shape  # batch, sequence length, number of heads, per-head dim
    freqs = freqs.view(s, d // 2)
    cos = torch.cos(freqs).to(torch.float32)
    sin = torch.sin(freqs).to(torch.float32)
    rotated = flash_apply_rotary_emb(x.to(torch.float32), cos, sin, interleaved=True, inplace=False)
    return rotated.to(x.dtype)
Rotary position embedding as a Triton kernel: Triton, from OpenAI, is a Python package widely used for high-performance GPU kernel development. Compared with the lower-level and more involved CUDA, its interface is more Python- and tensor-flavored, which suits deep-learning operator development. Triton's type annotations are incomplete and it relies on many dynamic attributes, so a "# type: ignore" comment silences static type checkers. Triton kernels operate directly on pointers (ptr).
python
import triton  # type: ignore
import triton.language as tl  # type: ignore

def apply_rotary_embedding(
    x: torch.Tensor,  # [b, s, n, d]
    cos: torch.Tensor, sin: torch.Tensor,  # [HWT, d/2]
    interleaved: bool = False
) -> torch.Tensor:
    output = torch.empty_like(x)
    if x.dim() > 3:
        b, s, n, d = x.shape
    else:
        s, n, d = x.shape
        b = 1
    assert d % 2 == 0, "head dim must be divisible by 2"
    x_2d = x.view(-1, d)
    out_2d = output.view(-1, d)
    grid = (b * s * n,)
    if interleaved and cos.shape[-1] == d:
        cos = cos[..., ::2].contiguous()
        sin = sin[..., ::2].contiguous()
    else:  # the last dim should already be only half of d
        cos = cos.contiguous()
        sin = sin.contiguous()
    # grid determines how many parallel program instances (programs) this kernel launches
    _rotary_embedding_kernel[grid](
        out_2d, x_2d, cos, sin,  # tensors passed into a Triton kernel automatically become typed pointers, e.g. a *fp32 pointer
        n, d, s,
        x_2d.stride(0), cos.stride(0), sin.stride(0),
        # stride: how many elements the address advances when moving one step along dim 0, i.e. the per-"row" offset
    )
    return output
# several candidate configs; on the first run they are all tried, the best one is recorded,
# and later calls under the same conditions (specified by `key`) reuse it directly
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_HS_HALF": 32}, num_warps=2),
        triton.Config({"BLOCK_HS_HALF": 64}, num_warps=4),
        triton.Config({"BLOCK_HS_HALF": 128}, num_warps=4),
        triton.Config({"BLOCK_HS_HALF": 256}, num_warps=8),
    ], key=["head_size", "interleaved"],
)
@triton.jit  # marks this as a Triton kernel that gets compiled to GPU code
def _rotary_embedding_kernel(
    out_ptr, x_ptr, cos_ptr, sin_ptr,  # [bsn, d] [HWT, d//2]
    n, d, s, stride_x_row, stride_cos_row, stride_sin_row,
    BLOCK_HS_HALF: tl.constexpr,
):
    row_idx = tl.program_id(0)  # ID of the current program (similar to a CUDA block); each program handles one row
    token_idx = (row_idx // n) % s  # index along the token dim
    x_row_ptr = x_ptr + row_idx * stride_x_row
    cos_row_ptr = cos_ptr + token_idx * stride_cos_row
    sin_row_ptr = sin_ptr + token_idx * stride_sin_row
    out_row_ptr = out_ptr + row_idx * stride_x_row
    d_half = d // 2
    for block_start in range(0, d_half, BLOCK_HS_HALF):  # load one block at a time
        offset_half = block_start + tl.arange(0, BLOCK_HS_HALF)
        mask = offset_half < d_half  # the last block may be partial
        offset_x1 = 2 * offset_half
        offset_x2 = 2 * offset_half + 1
        # load only valid (masked) positions; fill the rest with 0
        cos_vals = tl.load(cos_row_ptr + offset_half, mask=mask, other=0.0).to(tl.float32)
        sin_vals = tl.load(sin_row_ptr + offset_half, mask=mask, other=0.0).to(tl.float32)
        x1_vals = tl.load(x_row_ptr + offset_x1, mask=mask, other=0.0).to(tl.float32)
        x2_vals = tl.load(x_row_ptr + offset_x2, mask=mask, other=0.0).to(tl.float32)
        # RoPE is essentially a 2D rotation: x1' = x1 * cos - x2 * sin; x2' = x1 * sin + x2 * cos
        # tl.fma(a, b, c) = a * b + c, a fused multiply-add
        o1_vals = tl.fma(-x2_vals, sin_vals, x1_vals * cos_vals)
        o2_vals = tl.fma(x1_vals, sin_vals, x2_vals * cos_vals)
        # write back
        tl.store(out_row_ptr + offset_x1, o1_vals.to(x1_vals.dtype), mask=mask)
        tl.store(out_row_ptr + offset_x2, o2_vals.to(x2_vals.dtype), mask=mask)
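A small PyTorch reference of the same interleaved rotation, handy for checking the kernel numerically (a sketch; the function name here is not from the original code):
python
import torch

def rope_reference(x: torch.Tensor, phases: torch.Tensor) -> torch.Tensor:
    # x: [b, s, n, d] with interleaved pairs (x1, x2); phases: [s, d//2]
    cos = torch.cos(phases)[None, :, None, :]   # broadcast to [1, s, 1, d//2]
    sin = torch.sin(phases)[None, :, None, :]
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out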
Cross-attention layer:
python
class WanCrossAttention(WanSelfAttention):
    def forward(self, x, context):
        b, n, d = x.size(0), self.num_heads, self.head_dim
        q = self.norm_q(self.q(x)).view(b, -1, n, d)
        k = self.norm_k(self.k(context)).view(b, -1, n, d)
        v = self.v(context).view(b, -1, n, d)
        x = self.attn_op(q, k, v)
        x = x.flatten(2)
        x = self.o(x)
        return x
The DiT block with timestep modulation:
python
class WanAttentionBlock(nn.Module):
    def __init__(self, dim, ffn_dim, num_heads, qk_norm=True, cross_attn_norm=False, eps=1e-6):
        super().__init__()
        ...
        self.norm1 = WanLayerNorm(dim, eps)
        self.self_attn = WanSelfAttention(dim, num_heads, qk_norm, eps)
        self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
        self.cross_attn = WanCrossAttention(dim, num_heads, qk_norm, eps)
        self.norm2 = WanLayerNorm(dim, eps)
        self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim))
        self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
        ...

    def forward(self, x, e, freqs, context):
        assert e.dtype == torch.float32
        with amp.autocast("cuda", dtype=torch.float32):
            e = (self.modulation + e).chunk(6, dim=1)  # split into 6 chunks along dim 1
        assert e[0].dtype == torch.float32
        # self-attention
        y = self.self_attn(
            (self.norm1(x).float() * (1 + e[1]) + e[0]).type_as(x),  # y = x * (1 + scale) + shift
            freqs
        )
        with amp.autocast("cuda", dtype=torch.float32):
            x = x + y * e[2].type_as(x)  # x + gate * y
        # cross-attention
        x = x + self.cross_attn(self.norm3(x), context)
        # ffn
        y = self.ffn(
            (self.norm2(x).float() * (1 + e[4]) + e[3]).type_as(x),  # y = x * (1 + scale) + shift
        )
        with amp.autocast("cuda", dtype=torch.float32):
            x = x + y * e[5].type_as(x)  # x + gate * y
        return x
The full model
Class initialization
python
class WanModel(nn.Module):
    def __init__(self,
                 model_type="i2v",
                 patch_size=(1, 2, 2),
                 in_dim=36,  # latent + image + mask channel count
                 text_len=512,
                 text_dim=4096,  # prompt
                 freq_dim=256,  # timestep encoding
                 dim=5120,
                 ffn_dim=13824,
                 out_dim=16,  # latent channel count
                 num_heads=40,
                 num_layers=40,
                 qk_norm=True,
                 cross_attn_norm=True,
                 eps=1e-6,
                 sac_config: SACConfig = SACConfig(),  # whether activations are stored or recomputed during training
                 ):
        ...
        self.patch_embedding = nn.Linear(in_dim * patch_size[0] * patch_size[1] * patch_size[2], dim)
        self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim))
        self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
        self.blocks = nn.ModuleList(
            [WanAttentionBlock(dim, ffn_dim, num_heads, qk_norm, cross_attn_norm, eps) for _ in range(num_layers)]
        )
        self.head = Head(dim, out_dim, patch_size, eps)
        assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
        d = dim // num_heads
        self.rope_position_embedding = VideoRopePosition3DEmb(head_dim=d, len_h=128, len_w=128, len_t=32)  # the configured maximum lengths
        self.init_weights()
        self.enable_selective_checkpoint(sac_config)
Enabling and disabling sequence parallelism
python
def enable_context_parallel(self, process_group: Optional[ProcessGroup] = None):
    cp_ranks = get_process_group_ranks(process_group)
    for block in self.blocks:
        block.self_attn.set_context_parallel_group(process_group=process_group, ranks=cp_ranks, stream=torch.cuda.Stream())
    self._is_context_parallel_enabled = True
    self._cp_group = process_group

def disable_context_parallel(self):
    for block in self.blocks:
        block.self_attn.set_context_parallel_group(process_group=None, ranks=None, stream=torch.cuda.Stream())
    self._is_context_parallel_enabled = False
    self._cp_group = None
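A hypothetical usage sketch (assumes the script is launched with torchrun and that `model` is an already-built WanModel; this is not from the original code):
python
import torch
import torch.distributed as dist

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())
cp_group = dist.new_group(ranks=list(range(dist.get_world_size())))
model.enable_context_parallel(cp_group)    # every block's self-attention now takes the A2A path
# ... run the denoising loop; each rank holds only its slice of the token sequence ...
model.disable_context_parallel()           # back to purely local attention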
Forward pass
python
def forward(self, x_B_C_T_H_W, timesteps_B_T, crossattn_emb, y_B_C_T_H_W=None, **kwargs):
    cp_group = getattr(self, "_cp_group", None)
    cp_enabled = (cp_group is not None) and (cp_group.size() > 1)
    if cp_enabled:
        x_B_C_T_H_W = broadcast(x_B_C_T_H_W, cp_group)
        ...  # a few more broadcasts to keep the inputs consistent across ranks
    assert timesteps_B_T.shape[1] == 1
    t_B = timesteps_B_T[:, 0]
    del kwargs
    if self.model_type == "i2v":
        assert y_B_C_T_H_W is not None
        x_B_C_T_H_W = torch.cat([x_B_C_T_H_W, y_B_C_T_H_W], dim=1)
    # patchify [B, C, T, H, W] -> [B, L, d_in] -> [B, L, d]
    kt, kh, kw = self.patch_size
    B, _, T_in, H_in, W_in = x_B_C_T_H_W.shape
    assert (T_in % kt) == 0 and (H_in % kh) == 0 and (W_in % kw) == 0
    T, H, W = T_in // kt, H_in // kh, W_in // kw
    L = T * H * W
    x_B_L_Din = rearrange(
        x_B_C_T_H_W,
        "b c (t kt) (h kh) (w kw) -> b (t h w) (c kt kh kw)",
        kt=kt, kh=kh, kw=kw,
    ).contiguous()
    if cp_enabled:
        assert (L % cp_group.size()) == 0, f"L=T*H*W must be divisible by cp_size. Got L={L}, cp={cp_group.size()}."
        x_B_L_Din = split_inputs_cp(x_B_L_Din, seq_dim=1, cp_group=cp_group)
    x_B_L_D = self.patch_embedding(x_B_L_Din)
    seq_lens = torch.tensor([u.size(0) for u in x_B_L_D], dtype=torch.long)
    # inputs of different sizes lead to different sequence lengths after downsampling and patchifying;
    # this may be unused later in the code, but it is kept here as a reminder
    # timestep [B,] -> [B, d_freq] -> [B, d] -> [B, 6d] -> [B, 6, d]
    with amp.autocast("cuda", dtype=torch.float32):
        e_B_D = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t_B).float())
        e0_B_6_D = self.time_projection(e_B_D).unflatten(1, (6, self.dim))
        assert e_B_D.dtype == torch.float32 and e0_B_6_D.dtype == torch.float32
    # prompt [B, L_text, d_text] -> [B, L_text, d]
    context_B_L_D = self.text_embedding(crossattn_emb)
    # positional encoding [L, d_head//2]
    freqs = self.rope_position_embedding.generate_embeddings(torch.Size([B, T, H, W, self.dim])).contiguous()
    if cp_enabled:
        freqs = split_inputs_cp(freqs, seq_dim=self.rope_position_embedding.seq_dim, cp_group=cp_group)
    kwargs = dict(
        e=e0_B_6_D,
        freqs=freqs,
        context=context_B_L_D,
    )
    for block_idx, block in enumerate(self.blocks):
        x_B_L_D = block(x_B_L_D, **kwargs)
    x_B_L_Dout = self.head(x_B_L_D, e_B_D)
    if cp_enabled:
        if torch.is_grad_enabled():
            x_B_L_Dout = cat_outputs_cp_with_grad(x_B_L_Dout, seq_dim=1, cp_group=cp_group)
        else:
            x_B_L_Dout = cat_outputs_cp(x_B_L_Dout, seq_dim=1, cp_group=cp_group)
    x_B_C_T_H_W = rearrange(
        x_B_L_Dout,
        "b (t h w) (kt kh kw d) -> b d (t kt) (h kh) (w kw)",
        kt=kt, kh=kh, kw=kw, t=T, h=H, w=W, d=self.out_dim
    )
    return x_B_C_T_H_W
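As a sanity check on the bookkeeping above, a small standalone sketch of the i2v shape flow; the concrete latent grid is an illustrative assumption, only the channel counts, patch size and hidden sizes come from the config:
python
# Illustrative shape walk-through of the forward pass (grid sizes are assumptions).
B, C_lat, T, H, W = 1, 16, 21, 60, 104      # noisy latent from the VAE
C_y = 4 + C_lat                              # mask + conditioning latent
C_in = C_lat + C_y                           # 36, matches in_dim
kt, kh, kw = 1, 2, 2                         # patch_size
L = (T // kt) * (H // kh) * (W // kw)        # number of tokens
d_in = C_in * kt * kh * kw                   # 144, input width of patch_embedding
dim, num_heads = 5120, 40
d_head = dim // num_heads                    # 128, so the RoPE phase table is [L, 64]
print(L, d_in, d_head)                       # 32760 144 128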
Odds and ends
Some decorators encountered along the way:
@dataclass defines a class for storing data and auto-generates __init__, __repr__, and so on.
@register_to_config automatically stores the __init__ arguments in self.config, which makes saving and loading easier.
@property exposes a method of a class as an attribute.
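A tiny self-contained illustration of the first and last decorators (the class names here are made up; @register_to_config is typically provided by the library's config mixin, e.g. diffusers, and is not reproduced):
python
from dataclasses import dataclass

@dataclass
class RunConfig:            # __init__ and __repr__ are generated automatically
    num_samples: int = 1
    guidance: float = 5.0

class LatentShape:
    def __init__(self, t, h, w):
        self.t, self.h, self.w = t, h, w

    @property
    def num_tokens(self):   # accessed as shape.num_tokens, without parentheses
        return self.t * (self.h // 2) * (self.w // 2)

print(RunConfig(), LatentShape(21, 60, 104).num_tokens)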
The broadcast used to keep inputs consistent across ranks:
python
# broadcast a single tensor
def robust_broadcast(
    tensor: torch.Tensor,
    src: int, pg: ProcessGroup,
    is_check_shape: bool = False
) -> torch.Tensor:
    ...
    # first broadcast the tensor's shape
    if distributed.get_rank() == src:
        shape = torch.tensor(tensor.shape).cuda()
    else:
        shape = torch.empty(tensor.dim(), dtype=torch.long).cuda()
    if is_check_shape: ...
    distributed.broadcast(shape, src, group=pg)
    # non-source ranks resize their tensor to match
    if distributed.get_rank() != src:
        tensor = tensor.new_empty(shape.tolist()).type_as(tensor)
    # broadcast the tensor data
    distributed.broadcast(tensor, src, group=pg)
    return tensor

def broadcast(
    item: torch.Tensor | str | None,
    process_group: ProcessGroup | None = None
) -> torch.Tensor | str | None:
    if process_group is None:
        return item
    min_rank = min(get_process_group_ranks(process_group))  # broadcast from the lowest-numbered rank
    if isinstance(item, torch.Tensor):
        if item.device.type != "cuda":
            item = item.cuda()
        item = robust_broadcast(item, min_rank, process_group)
    elif item is not None:
        broadcastable_list = [item]
        broadcast_object_list(broadcastable_list, min_rank, group=process_group)
        item = broadcastable_list[0]
    return item
Sequence-parallel helpers:
python
def split_inputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor:
    if x.device.type != "cuda":
        x = x.cuda()
    cp_ranks = get_process_group_ranks(cp_group)
    cp_size = len(cp_ranks)
    assert x.shape[seq_dim] % cp_size == 0, f"{x.shape[seq_dim]} is not divisible by cp_size {cp_size}"
    x = x.view(*x.shape[:seq_dim], cp_size, x.shape[seq_dim] // cp_size, *x.shape[(seq_dim + 1):])
    seq_idx = torch.tensor([cp_group.rank()], device=x.device)
    x = x.index_select(seq_dim, seq_idx)
    x = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim + 2):])
    return x

# concatenate the sequence-parallel outputs
def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor:
    world_size = get_world_size(cp_group)
    gathered_tensors = [torch.zeros_like(x) for _ in range(world_size)]
    try:
        all_gather(gathered_tensors, x, group=cp_group)
    except RuntimeError as e:
        raise RuntimeError(f"Failed to gather tensors: {e}")
    return torch.cat(gathered_tensors, dim=seq_dim)

# same, but keeping the gradient
def cat_outputs_cp_with_grad(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor:
    ...
    rank = cp_group.rank()
    gathered_tensors[rank] = x
    # all_gather creates new tensors and copies data into them, which detaches the computation graph;
    # the local tensor is swapped back in so that its original graph is preserved
    return torch.cat(gathered_tensors, dim=seq_dim)