wan2.2-i2v-a14b 模型架构

万象是开源的一系列视频生成模型,提出新的时空变分自编码器(VAE)、可扩展的预训练策略、大数据监管、自动化评测指标,提升模型性能和通用性。14B模型在数十亿图片和视频上训练,在数据量和模型大小上都展现出scaling law。覆盖多样下游任务,包括图生视频、指令引导的视频编辑等等,能接受中文。1.3B模型只需要8.19GB VRAM。这里只关注DiT 的模型架构和前向过程。

准备模型输入

提示词编码:输出形状默认 (词汇表大小256384)

python 复制代码
text_emb = get_umt5_embedding(checkpoint_path=args.text_encoder_path,\
prompts=args.prompt).to(dtype=torch.bfloat16).cuda()

首帧图编码:

python 复制代码
frames_to_encode = torch.cat(   # 包装成一个长为 F 的视频张量
  [
    image_tensor.unsqueeze(2), # [B, 3, H_img, W_img] -> [B, C, 1, H_img, W_img]
    torch.zeros(1, 3, F - 1, h, w, device=image_tensor.device)
  ], dim=2
)  # -> [B, 3, F, H_img, W_img]
# vae编码,一些设置中,时间压缩率为 4(T=(F-1)/4+1),空间压缩了为 8(H=H_img/8),通道数C=16
encoded_latents = tokenizer.encode(frames_to_encode)  # -> [B, C, T, H, W]

msk = torch.zeros(1, 4, lat_t, lat_h, lat_w, device=tensor_kwargs["device"], dtype=tensor_kwargs["dtype"])	# 0维硬编码成1,输入的潜向量的批量大小也只能为 1 了【像是一个小 bug】
msk[:, :, 0, :, :] = 1.0	# 高亮潜向量中的第一帧为条件帧
y = torch.cat([msk, encoded_latents.to(**tensor_kwargs)], dim=1)
# 将掩码和潜向量在通道维度进行拼接作为y -> [1, 4+C, T, H, W]

将上面两个组件处理后的结果作为条件,和形状为的latent一起输入给模型

python 复制代码
condition = {"crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples)}	# 和t2v 的条件进行比对
condition = {"crossattn_emb": repeat(text_emb.to(**tensor_kwargs), "b l d -> (k b) l d", k=args.num_samples), "y_B_C_T_H_W": y}

# 模型调用都是输入 latent timestep condition
v_pred = net(x_B_C_T_H_W=x.to(**tensor_kwargs),
	timesteps_B_T=(t_cur.float() * ones * 1000).to(**tensor_kwargs), **condition
).to(torch.float64)

在Wanvideo中

输入分片:输入形状为,其中潜空间通道数默认是, 由VAE编码器决定,将它和形状为的y在通道维度拼接,得到形状为的x。默认分片大小patch_size=(1, 2, 2),调整x形状变成,经由一个线性分片嵌入层,把变成

时间步编码:对时间步t[B],通过1D正余弦编码变成一个维的向量,具体地,根据一半的维度构造频率,和时间步序列做外积,得到形状的二维张量fs,其中,将频率的余弦正弦值拼接(),再通过一个(两个线性层夹一个SiLU)时间嵌入层,把变成,再由一个(一个SiLU配一个线性层)时间投影层,变成6d,形状调整为

提示词嵌入:通过一个(两个线性层夹一个GELU)文本嵌入层,把变成

位置编码:如何为编码位置信息呢?不是直接加在输入上,而是通过3D旋转位置编码把位置相关的旋转作用在Q/K上。具体地,将每个头的隐藏维度划分为3个部分,其中分别为高、宽、帧数进行编码,都保证是2的倍数。分别构造频率, ,让位置和频率做外积,得到形状分别为的二维张量,通过扩展成四维,并在最后一维拼回,每个(t, h, w)对应一组 RoPE 相位,

python 复制代码
freqs_T_H_W_D = torch.cat(
  [
    repeat(freqs_t, "t d -> t h w d", h=H, w=W),
    repeat(freqs_h, "h d -> t h w d", t=T, w=W),
    repeat(freqs_w, "w d -> t h w d", t=T, h=H),
  ],
  dim=-1,
)

将这个四维的频率张量展平,变成形状为的二维相位张量。

末尾层:最后调制一下,线性层变化形状

python 复制代码
class Head(nn.Module):
  def __init__(self, dim, out_dim, patch_size, eps=1e-6):
    ...
    out_dim = math.prod(patch_size) * out_dim
    self.norm = WanLayerNorm(dim, eps)
    self.head = nn.Linear(dim, out_dim)
    
    # 在DiT中,调制层的输入是时间步和标签的嵌入向量加和,当前场景下没有 label,多一个可学习的调制参数吧
    self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
  
  def init_weights(self):
    self.norm.reset_parameters()
    std = 1.0 / math.sqrt(self.dim)
    torch.nn.init.trunc_normal_(self.modulation, std=std)
    torch.nn.init.trunc_normal_(self.head.weight, std=std)
    self.head.bias.data.zero_()
    
  def forward(self, x, e):
    assert e.dtype == torch.float32
    with amp.autocast("cuda", dtype=torch.float32):
      e = (self.modulation + e.unsqueeze(1)).chunk(2, dim=1)	# 这个输入的不是时间步嵌入诶?
      x = self.head(self.norm(x) * (1 + e[1]) + e[0]
		return x

DiT块的处理

每一个DiT块是如何处理(的潜向量、的时间步、的提示词、的位置编码频率)这些输入的呢?(忽略批量维度)

先放一些归一化层的定义:

python 复制代码
class WanLayerNorm(nn.LayerNorm):
  # elementwise_affine 带可学习的缩放/偏移参数
  def __init__(self, dim, eps=1e-6, elementwise_affine=False):
        super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
  def forward(self, x):
    with amp.autocast("cuda", dtype=torch.float32):
      return super().forward(x.float()).type_as(x)
    
class WanRMSNorm(nn.Module):
  ...
  def _norm(self, x):
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
  def forward(self, x):
    return self._norm(x.float()).type_as(x) * self.weight

通用attention的统一入口:根据不同GPU和软件环境,尽量选择最快的实现,更关注调度和后端选择,而非 attention数学推导本身。Hopper(H100, H20, H200......):flash attention3 --》cuDNN attention;Ampere(A100, A40 SM80、Ampere RTX)优先flash attention2;Ada(如RTX 4090)和Blackwell RTX:flash attention --》cuDNN --》xformers;对于Blackwell数据中心卡(B200, GB200 SM100)则优先cuDNN。其中SM(计算能力*10)是一个硬件特征标签

python 复制代码
# 格式 GPU 对应的计算能力:
#
# | GPU / category | Arch  |
# |================|=======|
# | A100           | SM80  |
# | A40            | SM80  |
# | Ampere RTX     | SM86  |
# |----------------|-------|
# | Ada Lovelace   | SM89  |
# |----------------|-------|
# | H20            | SM90  |
# | H100           | SM90  |
# | H200           | SM90  |
# |----------------|-------|
# | B200           | SM100 |
# | Blackwell RTX  | SM103 |
# |----------------|-------|
#

try:	# 尝试导入 flashAttention3
	from flash_attn_3.flash_attn_interface import flash_attn_func
	FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
	FLASH_ATTN_3_AVAILABLE = False

def get_device_cc(device) -> int:	# 获取SM
  if torch.cuda.is_available() abd torch.version.cuda and device.type=="cuda":
    major, minor = torch.cuda.get_device_capability(device)
    return major * 10 + minor
return 0

def attention(q, k, v, dropout_p=0.0, softmax_scale=None, q_scale=None, causal=False, deterministic=False):
  assert q.dtype == k.dtype == v.dtype
  dtype = q.dtype
  supported_dtypes = [torch.bfloat16, torch.float16, torch.float32]
  is_half = dtype in [torch.bfloat16, torch.float16]
  compute_cap = get_device_cc(q.device)
  
  if dtype not in supported_dtypes:
  	raise NotImplementedError(f"{dtype=} is not supported.")

	if q_scale is not None:
		q = q * q_scale
    
  if compute_cap == 90 and FLASH_ATTN_3_AVAILABLE and is_half:	# 优先 flashAttention3
    return flash_attn_func(q=q, k=k, v=v, softmax_scale=softmax_scale, causal=causal, deterministic=deterministic)[0]
  else:
    if compute_cap in [90, 100] and is_half:
      SDPA_BACKENDS = [
        SDPBackend.CUDNN_ATTENTION,	# 优先 cuDNN
        SDPBackend.FLASH_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION
      ]
      BEST_SDPA_BACKEND = SDPBackend.CUDNN_ATTENTION
    elif is_half:
      SDPA_BACKENDS = [
        SDPBackend.FLASH_ATTENTION,	# 优先 flash attention
        SDPBackend.CUDNN_ATTENTION,
        SDPBackend.EFFICIENT_ATTENTION,
      ]
      BEST_SDPA_BACKEND = SDPBackend.FLASH_ATTENTION if compute_cap >= 80 else SDPBackend.EFFICIENT_ATTENTION
    else:
      assert dtype == torch.float32, f"Unrecognized {dtype=}."
      SDPA_BACKENDS = [SDPBackend.EFFICIENT_ATTENTION]
      BEST_SDPA_BACKEND = SDPBackend.EFFICIENT_ATTENTION

    if deterministic:
      raise NotImplementedError("Deterministic mode in attention is only supported when Flash Attention 3 is available.")

    try:
      sdpa_kernel(backends=SDPA_BACKENDS, set_priority_order=True)
      sdpa_kernel_ = partial(sdpa_kernel, set_priority_order=True)
      except TypeError:
        sdpa_kernel_ = sdpa_kernel
        SDPA_BACKENDS = [BEST_SDPA_BACKEND]

    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    
    with sdpa_kernel_(backends=SDPA_BACKENDS):
      out = torch.nn.functional.sclaed_dot_product_attention(q, k, v, is_causal=causal, dropout_p=dropout_p, scale=softmax_scale)
    out = out.transpose(1, 2).continuous()
    return out

序列并行版注意力计算:每张卡从"局部序列&全部头" -》"全部序列&局部头"-》"局部序列&全部头"

python 复制代码
class MinimalA2AAttnOp(DistributedAttention):
  ...
  def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
    results = super().forward(query, key, value, *args, **kwargs)
    return rearrange(results, "b ... h l -> b ... (h l)")	# 把头也聚合了
  
class DistributedAttention(torch.nn.Module):
  ...
  def forward(self, query: Tensor, key: Tensor, value: Tensor, *args: Any, **kwargs) -> Tensor:
    if self.pg is None:
      return self.local_attn(query, key, value, *args, **kwargs)
    
    pg_size = dist.get_world_size(self.pg)
    if pg_size < 2:
      return self.local_attn(query, key, value, *args, **kwargs)
    
    # 启用上下文并行
    query_layer, key_layer, value_layer = _SeqAllToAllQKV.apply(self.pg, query, key, value, pg_size, self.stream, True)	# 将 qkv 从"局部序列完整头"变成"完整序列局部头"
    context_layer = self.local_attn(query_layer, key_layer, value_layer, *args, **kwargs)	# 局部计算注意力
    output = _SeqAllToAll.apply(self.pg, context_layer, False) # 将输出从"完整序列局部头"变成"局部序列完整头"
    return output

# 启用开关设置
def set_context_parallel_group(self, group, stream):
  self.pg = group
  self.stream = stream

########## 切分QKV的前反向实现 ##########
class _SeqAllToAllQKV(torch.autograd.Function):
  @staticmethod
  def forward(ctx: Any, group: dist.ProcessGroup, 
		q: Tensor, k: Tensor, v: Tensor, 
		cp_size: int, cp_stream: torch.cuda.Stream,
		local_seq_2_local_head: bool,
	) -> Tuple[Tensor, Tensor, Tensor]:
    ctx.group = group
    ctx.cp_size = cp_size
    ctx.cp_stream = cp_stream
    ctx.local_seq_2_local_head = local_seq_2_local_head
    q, k, v = async_a2a_communicate([q, k, v], cp_size, group, cp_stream, local_seq_2_local_head)
    return q, k, v
  
  @staticmethod
  def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor, Tensor, Tensor, None, None, None]:
     q_grad, k_grad, v_grad = _SeqAllToAllQKV.apply(ctx.group, *grad_output, ctx.cp_size, ctx.cp_stream, not ctx.local_seq_2_local_head)
     return (None, q_grad, k_grad, v_grad, None, None, None)

# all-to-all 通信的异步流水线实现
def async_a2a_communicate(
  a2a_inputs: Union[torch.Tensor, List[torch.Tensor]],	# [q, k, v]
	cp_size: int, # 序列并行数
	cp_group: ProcessGroup, # 分布式通信组
  cp_stream: torch.cuda.Stream, # 用于执行通信等待和后处理的 CUDA stream
	local_seq_2_local_head: bool,	# 序列局部 -》 头局部
) -> Union[torch.Tensor, List[torch.Tensor]]:

  a2a_inputs = [a2a_inputs] if not isinstance(a2a_inputs, list) else a2a_inputs
  a2a_outputs, a2a_reqs, a2a_post_fns = [None] * len(a2a_inputs), [None] * len(a2a_inputs), [None] * len(a2a_inputs)	# 通信输出张量、异步请求句柄、通信后 reshape 函数
  
  # 3阶段流水线调度
  if ocal_seq_2_local_head:
    for i in range(len(a2a_inputs) + 2):
      
      # 2. 进行 all2all 通信,拿到部分头的完整序列
      if 0 < i < len(a2a_inputs) + 1:
        a2a_outputs[i - 1] = torch.empty_like(a2a_inputs[i - 1])
        a2a_reqs[i - 1] = dist.all_to_all_single(a2a_outputs[i - 1], a2a_inputs[i - 1], group=cp_group, async_op=True)
        a2a_post_fns[i - 1] = post_all2all(local_seq_2_local_head, cp_size)
      
      # 3. 聚合序列
      if i>1:
        with torch.cuda.stream(cp_stream):
          a2a_reqs[i - 2].wait()
          a2a_outputs[i - 2] = a2a_post_fns[i - 2](a2a_outputs[i - 2])
      
      # 1. 分头
      if i < len(a2a_inputs):
        a2a_inputs[i] = rearrange(a2a_inputs[i], "bs seq_len (w h) d -> w bs seq_len h d", w=cp_size).contiguous()
        
	else:
    for i in range(len(a2a_inputs) + 2):
      
      # 2. 进行 all2all 通信,拿到部分序列的所有头
      if 0 < i < len(a2a_inputs) + 1:
        ...
       
      # 1. 分序列
      if i < len(a2a_inputs):
        a2a_inputs[i] = rearrange(a2a_inputs[i], "bs (w s) h d -> w bs s h d", w=cp_size).contiguous()
        
      # 3. 聚合头
      if i > 1:
        with torch.cuda.stream(cp_stream):
          a2a_reqs[i - 2].wait()
          a2a_outputs[i - 2] = a2a_post_fns[i - 2](a2a_outputs[i - 2])
          
  torch.cuda.current_stream().wait_stream(cp_stream)
  return a2a_outputs[0] if len(a2a_inputs) == 1 else a2a_outputs

########## 切分输出的前反向实现 ##########
class _SeqAllToAll(torch.autograd.Function):
  @staticmethod
  def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor, local_seq_2_local_head: bool) -> Tensor:
    ctx.group = group
    ctx.local_seq_2_local_head = local_seq_2_local_head
    res = single_all_to_all(input, local_seq_2_local_head, group, False)
    return res
  
  @staticmethod
  def backward(ctx:Any, *grad_output: Tensor) -> Tuple[None, Tensor, None]:
    return (None, _SeqAllToAll.apply(ctx.group, *grad_output, not ctx.local_seq_2_local_head), None)
  
# all-to-all 通信的简单实现
def single_all_to_all(input, local_seq_2_local_head, group, async_op=False):
  seq_world_size = dist.get_world_size(group)
  post_all2all_fun = post_all2all(local_seq_2_local_head, seq_world_size)
  output = torch.empty_like(input_t)
  
  if local_seq_2_local_head:
    bs, local_seq_len, total_head_num, head_dim = input.shape
    assert total_head_num % seq_world_size == 0, f"Number of heads ({total_head_num}) must be divisible by the sequence parallel size ({seq_world_size})!"
    # 1. 分头
    input = rearrange(input, "bs seq_len (w h) d -> w bs seq_len h d", w=seq_world_size, h=total_head_num // seq_world_size).contiguous()
  else:
    bs, global_seq_len, local_head_num, head_dim = input.shape
    assert global_seq_len % seq_world_size == 0, f"Length of sequence ({global_seq_len}) must be divisible by the sequence parallel size ({seq_world_size})!"
    # 1. 分序列
    input = rearrange(input, "bs (w s) h d -> w bs s h d", w=seq_world_size, s=global_seq_len // seq_world_size).contiguous()
  
  # 2. 进行 all2all 通信
  dist.all_to_all_single(output, input, group=group, async_op=async_op)
  
  # 3. 聚合
  res = post_all2all_fun(output)
  
  return res

自注意力层:

python 复制代码
class WanSelfAttention(nn.Module):
  def __init__(self, dim, num_heads, qk_norm=True, eps=1e-6):
    assert dim % num_heads == 0
    super.__init__()
    self.dim = dim
    self.num_heads = num_heads
    self.head_dim = dim // self.num_heads
    self.qk_norm = qk_norm
    self.eps = eps
    
    self.q = nn.Linear(dim, dim)
    self.k = nn.Linear(dim, dim)
    self.v = nn.Linear(dim, dim)
    self.o = nn.Linear(dim, dim)
    self.norm_q = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
    self.norm_k = WanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
    self.attn_op = MinimalA2AAttnOp()
  
  def init_weights(self):
    # xavier 是一种"按层输入输出维度缩放方差"的初始化策略,而这里采用"从截断正态分布中采样"
    std = 1.0 / math.sqrt(self.dim)
    torch.nn.init.trunc_normal_(self.q.weight, std=std)
    ...
    # zero out bias
    self.q.bias.data.zero_()
    ...
    # reset norm weights
    if self.qk_norm:
      self.norm_q.reset_parameters()
      self.norm_k.reset_parameters()
    
  def forward(self, x, freqs):
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
    
    def qkv_fn(x):
      q = self.norm_q(self.q(x)).view(b, s, n, d)
      k = self.norm_k(self.k(x)).view(b, s, n, d)
      v = self.v(x).view(b, s, n, d)
      return q, k, v
  	
    q, k, v = qkv_fn(x)
    q = rope_apply(q, freqs)
    k = rope_apply(k, freqs)
    
    x = self.attn_op(q, k, v)
    x = x.flatten(2)	# 将第2维及以后所有维度,展成一个维度
    x = self.o(x)	
    return x
  
  def set_context_parallel_group(self, process_group, ranks, stream):
    self.attn_op.set_context_parallel_group(process_group, ranks, stream)

其中位置编码频率是如何应用到形状为的q和k上面的?

python 复制代码
def rope_apply(x, freqs):
  b, s, n, d = x.shape	# 序列长度 头数 每个头的维度
  freqs = frqs.view(s, d//2)
  cos = torch.cos(freqs).to(torch.float32)
  sin = torch.sin(freqs).to(torch.float32)
  rotated = flash_apply_rotary_emb(x.to(torch.float32), cos, sin, interleaved=True, inplace=False)
  return rotated.to(x.dtype)

Triton核实现的旋转位置编码:Triton是来自OpenAI,广泛用于高性能GPU kernel开发的python包,相较于更底层复杂的CUDA,Triton的接口更偏python、张量风格,更适合深度学习算子开发。Triton的类型注解不完整、动态属性多,通过标注"# type: ignore "可以让静态类型检查器的报错被忽略掉。Triton kernel直接操作指针(ptr)

python 复制代码
import triton  # type: ignore
import triton.language as tl  # type: ignore

def apply_rotary_embedding(
  x: torch.Tensor, # [b, s, n, d]
  cos: torch.Tensor, sin: torch.Tensor, # [HWT, d/2]
  interleaved: bool = False
) -> torch.Tensor:
  output = torch.empty_like(x)
  if x.dim() > 3:
    b, s, n, d = x.shape
  else:
    s, n, d = x.shape
    b = 1
	assert b % 2 == 0, "head dim must be divisible by 2"
	x = x.view(-1, d)
  output = output.view(-1, d)
  grid = (b * s * n, )
  
  if interleaved ans cos.shape[-1] == d:
    cos = cos[..., ::2].contiguous()
    sin = sin[..., ::2].continuous()
  else:	# 维度应该是只有一半的
    cos = cos.contiguous()
    sin = sin.contiguous()
  
	# grid 决定这个 kernel 会启动多少个并行程序实例(programs)  
  _rotary_embedding_kernel[grid](
    output, x, cos, sin,	# 传入Triton核时会自动变成带元素类型的指针,比如*fp32 pointer
    n, d, s, 
    x.stride(0), cos.stride(0), sin.stride(0),
    # 在第0维移到下个邻位时,内存地址要跳过多少个元素,即隔"行"元素偏移数
  )
  return output
    
# 提供多种配置,在第一次运行时都尝试一下,记录"最优配置",后续在相同条件(由 key 指定)下直接用
@triton.autotune(
  configs=[
    triton.Config({"BLOCK_HS_HALF": 32}, num_warps=2),
    triton.Config({"BLOCK_HS_HALF": 64}, num_warps=4),
    triton.Config({"BLOCK_HS_HALF": 128}, num_warps=4),
    triton.Config({"BLOCK_HS_HALF": 256}, num_warps=8),
  ], key=["head_size", "interleaved"],
)

@triton.jit	# 表示这是一个Triton kernel,会被编译成GPU代码执行
def _rotary_embedding_kernel(
  out_ptr, x_ptr, cos_ptr, sin_ptr,	# [bsn, d] [HWT, d//2]
	n, d, s, stride_x_row, stride_cos_row, stride_sin_row,
	BLOCK_HS_HALF: tl.constexpr,
):
  row_idx = tl.program_id(0)	# 当前program(类似 CUDA block)的ID,每个 program 处理一行数据
  token_idx = (row_idx // n) % s	# token维度
  
  x_row_ptr = x_ptr + row_idx * stride_x_row
  cos_row_ptr = cos_ptr + token_idx * stride_cos_row
  sin_row_ptr = sin_ptr + token_idx * stride_sin_row
  out_row_ptr = out_ptr + row_idx * stride_x_row
  
  d_half = d // 2
  
  for block_start in range(0, d_half, BLOCK_HS_HALF):	# 每次加载一块
    offset_half = block_start + tl.arnge(0, BLOCK_HS_HALF)
    mask = offset_half < d_half	# 最后一块可能不满
    
    offset_x1 = 2 * offset_half
    offset_x2 = 2 * offset_half + 1
    
    # 只加载有效(mask)位置,无效位置填0
    cos_vals = tl.load(cos_row_ptr + offset_half, mask=mask, other=0.0).to(tl.float32)
    sin_vals = tl.load(sin_row_ptr + offset_half, mask=mask, other=0.0).to(tl.float32)
    x1_vals = tl.load(x_row_ptr + offset_x1, mask=mask, other=0.0).to(tl.float32)
    x2_vals = tl.load(x_row_ptr + offset_x2, mask=mask, other=0.0).to(tl.float32)
    
    # RoPE 本质是二维旋转:x1' = x1 * cos - x2 * sin; x2' = x1 * sin + x2 * cos
    # tl.fma(a, b, c) = a * b + c 融合乘加
    o1_vals = tl.fma(-x2_vals, sin_vals, x1_vals * cos_vals)
    o2_vals = tl.fma(x1_vals, sin_vals, x2_vals * cos_vals)
    
    # 写回
    tl.store(out_ptr + offset_x1, o1_vals.to(x1_vals.dtype), mask=mask)
    tl.store(out_ptr + offset_x2, o2_vals.to(x2_vals.dtype), mask=mask)

交叉注意力层:

python 复制代码
class WanCrossAttention(WanSelfAttention):
  def forward(self, x, context, ):  
    b, n, d = x.size(0), self.num_heads, self.head_dim
    q = self.norm_q(self.q(x)).view(b, -1, n, d)
    k = self.norm_k(self.k(context).view(b, -1, n, d)
    v = self.v(context).view(b, -1, n, d)
		x = self.attn_op(q, k, v)
		x = x.flatten(2)
		x = self.o(x)
		return x

带时间步调制的DiT块:

python 复制代码
class WanAttention(nn.Module):
  def __init__(self, dim, ffn_dim, num_heads, qk_norm=True, cross_attn_norm=False, eps=1e-6):
    super().__init__()
    ...
    self.norm1 = WanLayerNorm(dim, eps)
    self.self_attn = WanSelfAttention(dim, num_heads, qk_norm, eps)
    self.norm3 = WanLayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
    self.cross_sttn = WanCrossAttention(dim, num_heads, qk_norm, eps)
    self.norm2 = WanLayerNorm(dim, eps)
    self.ffn = nn.Sequential(nn.Linear(dim, ffn_dim), nn.GELU(approximate="tanh"), nn.Linear(ffn_dim, dim))
    
    self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
    
  ...
  def forward(self, x, e, freqs, context):
    assert e.dtype == torch.float32
    with amp.autocast("cuda", dtype=torch.float32)
    	e = (self.modulation + e).chunk(6, dim=1)	# 按照第1维切成 6 份
    assert e[0].dtype == torch.float32
    
    # self-attention
    y = self.self_attn(
      (self.norm1(x).float() * (1 + e[1]) + e[0]).type_as(x), # y = x * (1+scale) + shift
      freqs
    )
    with amp.autocast("cuda", dtype=torch.float32):
      x = x + y * e[2].type_as(x)	# x + gate * y
      
    # cross-attention
    x = x + self.cross_attn(self.norm3(x), context)
    
    # ffn
    y = self.ffn(
      (self.norm2(x).float() * (1 + e[4]) + e[3]).type_as(x),	# y = x * (1+scale) + shift
    )
    with amp.autocast("cuda", dtype=torch.float32):
      x = x + y * e[2].type_as(x)	# x + gate * y
      
    return x

完整模型

类初始化

python 复制代码
class WanModel(nn.Module):
  def __init__(self,
               model_type="i2v",
               
             	 patch_size=(1, 2, 2),
               in_dim=36,			# 潜向量+图片+mask 通道数
               
               text_len=512,
               text_dim=4096,	# 提示词
               
               freq_dim=256,	# 时间步编码
               
               dim=5120,
               ffn_dim=13824,
               out_dim=16,	# 潜向量通道数
               num_heads=40,
               num_layers=40,
               qk_norm=True,
               cross_attn_norm=True,
               eps=1e-6,
               sac_config: SACCofig = SACCofig(),# 训练过程中激活者选择保存还是重算
               )
  ...
  self.patch_embedding = nn.Linear(in_dim * patch_size[0] * patch_size[1] * patch_size[2], dim)
  self.text_embedding = nn.Sequential(nn.Linear(text_dim, dim), nn.GELU(approximate="tanh"), nn.Linear(dim, dim))
  self.time_embedding = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
  self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6)

	self.blocks = nn.ModuleList(
    [WanAttentionBlock(dim, ffn_dim, num_heads, qk_norm, cross_attn_norm, eps) for _ in range(num_layers)]
  )
                                       
  self.head = Head(dim, out_dim, patch_size, eps)
	
	assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
	d = dim // num_heads

  self.rope_position_embedding = VideoRopePosition3DEmb(head_dim=d, len_h=128, len_w=128, len_t=32)	# 设置的最大长度

	self.init_weights()
 	self.enable_selective_checkpoint(sac_config)

序列并行开关使能

python 复制代码
def enable_context_parallel(self, process_group: Optional[ProcessGroup] = None):
  cp_ranks = get_process_group_ranks(process_group)
  for block in self.blocks:
    block.self_attn.set_context_parallel_group(process_group=process_group, ranks=cp_ranks, stream=torch.cuda.Stream())
  self._is_context_parallel_enabled = True
  self._cp_group = process_group
  
def disable_context_parallel(self):
  for block in self.blocks:
    block.self_attn.set_context_parallel_group(process_group=None, ranks=None, stream=torch.cuda.Stream())
  self._is_context_parallel_enabled = False
  self._cp_group = None

前向过程

python 复制代码
def forward(self, x_B_C_T_H_W, timestep_B_T, crossattn_emb, y_B_C_T_H_W=None, **kwargs,):
  cp_group = getattr(self, "_cp_group", None)
  cp_enabled = (cp_group is not None) and (cp_goup.size() > 1)
  
  if cp_enabled:
    x_B_C_T_H_W = broadcast(x_B_C_T_H_W, cp_group)
    ...	# 为了一致性做的一些广播操作
    
  assert timesteps_B_T.shape[1] == 1
  t_B = timesteps_B_T[:, 0]
  del kwargs
  if self.model_type == "i2v":
    assert y_B_C_T_H_W is not None
		x_B_C_T_H_W = torch.cat([x_B_T_H_W, y_B_C_T_H_W], dim=1)
  
  # 分片 [B, C, T, H, W] -> [B, L, d_in] -> [B, L, d]
  kt, kh, kw = self.patch_size
  B, _, T_in, H_in, W_in = x_B_T_H_W.shape
  assert (T_in % kt) == 0 and (H_in % kh) == 0 and (W_in % kw) == 0
  T, H, W = T_in // kt, H_in // kh, W_in // kw
  L = T * H * W
  x_B_L_Din = rearrange(
    x_B_C_T_H_W,
    "b c (t kt) (h kh) (w kw) -> b (t h w) (c kt kh kw)",
    kt=kt, kh=kh, kw=kw,
  ).contiguous()
  
  if cp_enabled:
    assert (L % cp_group.size()) == 0, f"L=T*H*W must be divisible by cp_size. Got L={L}, cp={cp_group.size()}."
    x_B_L_Din = split_inputs_cp(x_B_L_Din, seq_dim=1, cp_group=cp_group)
    
  x_B_L_D = self.patch_embedding(x_B_L_Din)
  seq_lens = torch.tensor([u.size(0) for u in x_B_L_D], dtype=torch.long)
  # 不同尺寸下采样、切片得到的序列长度是不一样的,也许代码中没用到,但还是写在这提示你
  
  # 时间步 [B,] -> [B, d_freq] -> [B, d] -> [B, 6d] -> [B, 6, d]
  with amp.autocast("cuda", dtype=torch.float32):
    e_B_D = self.time_embedding(sinusoidal_embedding_1d(self.freq_dim, t_B).float())
    e0_B_6_D = self.time_projection(e_B_D).unflatten(1, (6, self.dim))
    assert e_B_D.dtype == torch.float32 and e0_B_6_D.dtype == torch.float32
  
  # 提示词 [B, L_text, d_text] -> [B, L_text, d]
  context_B_L_D = self.text_embedding(crossattn_emb)
  
  # 位置编码 [L, d_head//2]
  freqs = self.repo_position_embedding.generate_embeddings(torch.Size([B, T, H, W, self.dim])).contiguous()
  if cp_enabled:
    freqs = split_inputs_cp(freqs, seq_dim=self.rope_position_embedding.seq_dim, cp_group=cp_group)
    
  kwargs = dict(
    e=e0_B_6_D,
    freqs=freqs,
    context=context_B_L_D,
  )
  for block_idx, block in enumerate(self.blocks):
    x_B_L_D = block(x_B_L_D, **kwargs)
    
  x_B_L_Dout = self.head(x_B_L_D, **kwargs)
  
  if cp_enabled:
    if torch.is_grad_enabled():
      x_B_L_Dout = cat_outputs_cp_with_grad(x_B_L_Dout, seq_dim=1, cp_group=cp_group)
    else:
      x_B_L_Dout = cat_outputs_cp(x_B_L_Dout, seq_dim=1, cp_group=cp_group)
      
  x_B_T_H_W = rearrange(
    x_B_L_Dout,
    "b (t h w) (kt kh kw d) -> b d (t kt) (h kh) (w kw)",
    kt=kt, kh=kh, kw=kw, t=T, h=H, w=W, d=self.out_dim
  )
  return x_B_T_H_W

七七八八

看到的一些装饰器:

@dataclass用来定义存储数据的类,会自动生成__init__、__repr__等

@register_to_config会自动把__init__参数保存到self.config中便于后续保存和加载

@property把类的方法当做属性来访问

为了保证一致性的广播的实现:

python 复制代码
# 广播一个张量
def robust_broadcast(
  tensor: torch.Tensor, 
  src: int, pg: ProcessGroup, 
  is_check_shape: bool = False
) -> torch.Tensor:
	...
	# 先广播张量的形状
  if distributed.get_rank() == src:
    shape = torch.tensor(tensor.shape).cuda()
	else:
		shape = torch.empty(torch.dim(), dtype=torch.long).cuda()
	if is_check_shape: ...
  dst.broadcast(shape, src, group=pg)
  
  # 非主进程调整张量形状
  if distributed.get_rank() != src:
    tensor = tensor.new_empty(shape.tolist()).type_as(tensor)
    
  # 广播张量数据
  dst.broadcast(tensor, src, group=pg)
  return tensor

def broadcast(
  item: torch.Tensor | str | None, 
  process: ProcessGroup | None = None
) -> torch.Tensor | str | None:
  if process_group is None:
    return item
  
  min_rank = min(get_process_group_rank(process_group))	# 由编号最小 rank 广播
  if isinstance(item, torch.Tensor):
    if item.device.type != "cuda":
      item = item.cuda()
    item = robust_broadcast(item, min_rank, process_group)
  elif item is not None:
    broadcastable_list = [item]
    broadcast_object_list(broadcastable_list, min_rank, group=process_group)
    item = broadcastable_list[0]
  return item

序列并行:

python 复制代码
def split_inputs_cp(x: Tensor, sq_dim: int, cp_group: ProcessGroup) -> Tensor:
  if x.device.type != "cuda":
    x = x.cuda()
    
  cp_ranks = get_process_group_ranks(cp_group)
  cp_size = len(cp_ranks)
  
  assert x.shape[seq_dim] % cp_size == 0, f"{x.shape[seq_dim]} cannot divide cp_size {cp_size}"
  x = x.view(*x.shape[:seq_dim], cp_size, x.shape[seq_dim] // cp_size, *x.shape[(seq_dim+1) :])
  seq_idx = torch.tensor([cp_group.rank()], device=x.device)
  x = x.index_select(seq_dim, seq_idx)
  x = x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_im+2) :])
  return x

# 拼接序列并行得到结果
def cat_outputs_cp(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor:
  world_size = get_world_size(cp_group)
  gathered_tensors = [torch.zero_like(x) for _ in range(world_size)]
  try:
    all_gather(gathered_tensors, x, group=cp_group)
  except RuntimeError as e:
    raise RuntimeError(f"Failed to gather tensors: {e}")
  return torch.cat(gathered_tensors, dim=seq_dim)

# 把梯度带上
def cat_output_cp_with_grad(x: Tensor, seq_dim: int, cp_group: ProcessGroup) -> Tensor:
  ...
  rank = cp_group.rank()
  gathered_tensors[rank] = x  
  # all_gather会创建一个新的 tensor 并赋值数据,计算图断开,需要替换回来以保留原始 tensor 的计算图
  return torch.cat(gathered_tensors, dim=seq_dim)
相关推荐
阿杰学AI32 分钟前
AI核心知识141—大语言模型之 对齐难题(简洁且通俗易懂版)
人工智能·安全·ai·语言模型·自然语言处理·aigc·ai对齐
qq_4609784034 分钟前
html标签怎么表示小字号文字_small标签语义说明【操作】
jvm·数据库·python
qq_4135020236 分钟前
SQL更新语句性能调优技巧_避免对索引列执行函数操作
jvm·数据库·python
2301_8176722638 分钟前
如何正确为包含浮动子元素的父容器设置完整背景色
jvm·数据库·python
2301_8038756142 分钟前
Redis如何通过永不过期策略规避击穿
jvm·数据库·python
2301_816660211 小时前
CSS中relative与absolute的区别_详解相对与绝对定位应用场景
jvm·数据库·python
qq_460978401 小时前
Golang怎么JWT设置过期时间_Golang如何在Claims中配置Token有效期【操作】
jvm·数据库·python
weixin_568996061 小时前
Cgo 中正确设置 C 结构体回调函数指针的完整方案
jvm·数据库·python
Jun6261 小时前
【RV1103】AD4115实现8通道ADC采样,MQTT数据传输,1K采样率
linux·python
LiAo_1996_Y1 小时前
mysql如何限制特定存储过程执行权限_MySQL存储过程安全访问
jvm·数据库·python