1. Model Overview
DiT comes from the paper "Scalable Diffusion Models with Transformers" and is a core building block of Sora's text-to-video model. In Stable Diffusion, the image is first compressed into latents, which are then processed by a UNet with attention layers; in DiT, the authors replace the UNet backbone with a Transformer, achieving state-of-the-art results in both sample quality and inference efficiency.
2. Model Architecture
The architecture diagram from the paper is shown below:
3. Component Breakdown
3.1. Overall Model Structure
The model consists of the following components (a quick shape sketch follows the list):
- patch_embedding: embeds each patch as a token
- t_embedding: timestep embedding
- pos_embedding: fixed sin/cos embedding that encodes positional information
- DiT Blocks: the core building blocks
- final_layer: corresponds to the final LayerNorm (plus linear projection) in the diagram
- unpatchify: corresponds to the reshape step that reassembles patches into an image
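To make the tensor shapes concrete before reading the code, here is a small illustrative sketch of how a latent flows through these components. The numbers assume the default configuration used below (input_size=32, patch_size=2, in_channels=4, hidden_size=1152, learn_sigma=True); this is an example, not part of the original implementation.

```python
# Illustrative shape walk-through for the default DiT configuration (assumed values).
N = 8                                              # batch size (arbitrary)
input_size, patch_size = 32, 2                     # 32x32 latent, 2x2 patches
in_channels, hidden_size = 4, 1152
out_channels = in_channels * 2                     # learn_sigma=True

num_patches = (input_size // patch_size) ** 2      # 16 * 16 = 256 tokens
# x_embedder : (N, 4, 32, 32) -> (N, 256, 1152)    each 2x2x4 patch becomes one token
# pos_embed  : (1, 256, 1152)                      fixed sin/cos, added to the tokens
# t_embedder : (N,)           -> (N, 1152)         timestep embedding
# y_embedder : (N,)           -> (N, 1152)         class-label embedding
# DiT blocks : (N, 256, 1152) -> (N, 256, 1152)    conditioned on c = t + y
# final_layer: (N, 256, 1152) -> (N, 256, 32)      32 = patch_size**2 * out_channels
# unpatchify : (N, 256, 32)   -> (N, 8, 32, 32)    back to an image-shaped tensor
print(num_patches, patch_size ** 2 * out_channels)  # 256 32
```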
Let's look at the overall structure of the model code:
```python
class DiT(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size=32,
patch_size=2,
in_channels=4,
hidden_size=1152,
depth=28,
num_heads=16,
mlp_ratio=4.0,
class_dropout_prob=0.1,
num_classes=1000,
learn_sigma=True,
):
super().__init__()
        # whether the model also predicts the noise variance (sigma)
self.learn_sigma = learn_sigma
        # number of input channels
self.in_channels = in_channels
        # number of output channels: 2 * in_channels if learn_sigma=True (noise + sigma), otherwise in_channels
        # learn_sigma is usually set to True
self.out_channels = in_channels * 2 if learn_sigma else in_channels
        # each patch covers patch_size x patch_size pixels
self.patch_size = patch_size
        # number of heads in multi-head attention
self.num_heads = num_heads
        # Follows the ViT patch-embedding approach: from timm.models.vision_transformer import PatchEmbed
        # The input_size x input_size image is split into (input_size / patch_size)**2 patches of size
        # patch_size x patch_size, and each patch is projected to a hidden_size-dimensional token
self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        # timestep embedder
self.t_embedder = TimestepEmbedder(hidden_size)
        # for classifier-free guidance, class labels are randomly dropped during training
self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
num_patches = self.x_embedder.num_patches
# Will use fixed sin-cos embedding:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
self.blocks = nn.ModuleList([
DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
])
        # final LayerNorm + linear projection back to patch pixels
self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
self.initialize_weights()
def initialize_weights(self):
...
def unpatchify(self, x):
"""
x: (N, T, patch_size**2 * C)
imgs: (N, H, W, C)
"""
c = self.out_channels
p = self.x_embedder.patch_size[0]
h = w = int(x.shape[1] ** 0.5)
assert h * w == x.shape[1]
x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
return imgs
def forward(self, x, t, y):
"""
Forward pass of DiT.
x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
t: (N,) tensor of diffusion timesteps
y: (N,) tensor of class labels
"""
        # patchify: split into patches and embed
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2 is the number of patches
t = self.t_embedder(t) # (N, D)
y = self.y_embedder(y, self.training) # (N, D)
        # combine the label and timestep embeddings by addition
c = t + y # (N, D)
        # pass through the DiT blocks
for block in self.blocks:
x = block(x, c) # (N, T, D)
        # final layer (adaLN-modulated LayerNorm + linear projection)
x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
        # reassemble the patches into a full image (the reshape step)
x = self.unpatchify(x) # (N, out_channels, H, W)
return x
def forward_with_cfg(self, x, t, y, cfg_scale):
"""
Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
"""
# https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
half = x[: len(x) // 2]
combined = torch.cat([half, half], dim=0)
model_out = self.forward(combined, t, y)
# For exact reproducibility reasons, we apply classifier-free guidance on only
# three channels by default. The standard approach to cfg applies it to all channels.
# This can be done by uncommenting the following line and commenting-out the line following that.
# eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
eps, rest = model_out[:, :3], model_out[:, 3:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
```
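To see the whole class in action, here is a minimal usage sketch. It assumes the DiT class above (and the PatchEmbed / TimestepEmbedder / LabelEmbedder / DiTBlock / FinalLayer modules it depends on) is in scope; the small depth and hidden_size values are chosen only to keep the example fast and are not the paper's configuration.

```python
import torch

# Smoke test for the DiT class defined above (illustrative, non-paper sizes).
model = DiT(input_size=32, patch_size=2, in_channels=4,
            hidden_size=384, depth=4, num_heads=6, num_classes=1000)
model.eval()

N = 2
x = torch.randn(N, 4, 32, 32)        # latent input (N, C, H, W)
t = torch.randint(0, 1000, (N,))     # diffusion timesteps
y = torch.randint(0, 1000, (N,))     # class labels

with torch.no_grad():
    out = model(x, t, y)
print(out.shape)                     # torch.Size([2, 8, 32, 32]) because learn_sigma=True
```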
3.2. DiT Blocks
```python
class DiTBlock(nn.Module):
"""
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
"""
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
super().__init__()
# norm1
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# attention
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
# norm2
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
mlp_hidden_dim = int(hidden_size * mlp_ratio)
approx_gelu = lambda: nn.GELU(approximate="tanh")
self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        # adaLN modulation: SiLU + Linear producing 6 * hidden_size modulation parameters
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 6 * hidden_size, bias=True)
)
def forward(self, x, c):
        # run c through adaLN_modulation, then chunk the result into 6 pieces, each of shape (N, D)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
```
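The "Zero" in adaLN-Zero refers to initialization: initialize_weights (elided above) zero-initializes the final linear layer of adaLN_modulation, so shift, scale and the two gates all start at zero and every block begins as an identity mapping. A sketch of that idea, assuming the initialization matches the paper's description:

```python
import torch.nn as nn

# Sketch of the adaLN-Zero initialization for one block (assumed to mirror initialize_weights).
def zero_init_adaln(block) -> None:
    # The last element of adaLN_modulation is the nn.Linear that emits the 6 modulation vectors.
    nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
    nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
    # With zero weights and bias, shift = scale = gate = 0, so
    # x + gate * attn(modulate(norm(x), shift, scale)) == x: the block starts as an identity.
```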
The adaptive normalization (modulate) helper:
```python
def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
```
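Note that modulate broadcasts the per-sample (N, D) shift and scale vectors over the token dimension of x, which is (N, T, D). A small illustrative example (shapes are arbitrary; the modulate definition is repeated only to keep the snippet self-contained):

```python
import torch

def modulate(x, shift, scale):                 # same definition as above
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)

N, T, D = 2, 256, 1152                         # batch, tokens, hidden size (illustrative)
x = torch.randn(N, T, D)                       # normalized token sequence
shift = torch.zeros(N, D)                      # per-sample shift from adaLN_modulation
scale = torch.zeros(N, D)                      # per-sample scale from adaLN_modulation

out = modulate(x, shift, scale)
print(out.shape)                               # torch.Size([2, 256, 1152])
print(torch.allclose(out, x))                  # True: zero shift/scale leaves x unchanged
```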