DIT详解

1. 模型简介

Dit来自论文《Scalable Diffusion Models with Transformers》,是构成Sora文生视频的核心。在Stable Diffusion中,image被压缩到latents之后,会继续被送到unet构建的attention中,然后再DIT中,作者把unet进行了替换,而使用Transformer。在推理速度和质量上达到了SOTA。

2. 模型架构

论文中的结构图如下:

3. 部件解析

3.1. 模型整体架构

模型整体包含如下架构:

  1. patch_embedding 标识每个patch
  2. t_embedding 时间embedding
  3. pos_embedding sin/cos embedding,用来标识位置信息的
  4. DiT Blocks 核心部件
  5. final_layer 对应模型的LayerNorm操作
  6. unpatchify 对应模型reshape操作

来看一下模型代码整体结构,如下所示:

ini 复制代码
class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
    ):
        super().__init__()
        # 
        self.learn_sigma = learn_sigma
        # 输入维度
        self.in_channels = in_channels
        # 输出维度,如果learning_sigma=True,则输出2倍in_channel, 否则out_channel=in_channel
        # 一般情况下设置为True
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        # 分patch_size * patch_size块
        self.patch_size = patch_size
        # 多头注意力机制的头数量
        self.num_heads = num_heads
        
        # 依照VIT的patch_embed做法,引入的是from timm.models.vision_transformer import PatchEmbed
        # 也就是说,将input_size的图像分成patch_size * patch_size个块,然后每个块用hidden_size维向量标识位置
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        # 创建时间embedder
        self.t_embedder = TimestepEmbedder(hidden_size)
        # 为了引导无分类标签,将随机扔掉label里面的值
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        # 
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        ...

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, H, W, C)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def forward(self, x, t, y):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        # 对应patchify
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2, T是分块数量
        t = self.t_embedder(t)                   # (N, D)
        y = self.y_embedder(y, self.training)    # (N, D)
        # 对应Label 和 timestep相加
        c = t + y                                # (N, D)
        # 对应送入Dit Blocks
        for block in self.blocks:
            x = block(x, c)                      # (N, T, D)
        # 对应layer_norm
        x = self.final_layer(x, c)                # (N, T, patch_size ** 2 * out_channels)
        # 将patch合并成一整个大图,对应reshape操作
        x = self.unpatchify(x)                   # (N, out_channels, H, W)
        return x

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)

3.2. DIT Blocks

python 复制代码
class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        # norm1
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # attention
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        # norm2
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        # silu + mlp  
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        # 过adaLN_module之后, chunk成6份。每一份都是(N, T, D)大小
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x

自适应归一化操作:

scss 复制代码
def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
相关推荐
Kai HVZ28 分钟前
《深度学习》——bert框架
人工智能·深度学习·bert
紫雾凌寒1 小时前
自然语言处理|金融舆情解析:智能事件抽取与风险预警之道
人工智能·深度学习·自然语言处理·金融·事件抽取·金融舆情·风险预警
进取星辰1 小时前
PyTorch 深度学习实战(30):模型压缩与量化部署
人工智能·pytorch·深度学习
小白的高手之路3 小时前
常用的卷积神经网络及Pytorch示例实现
人工智能·pytorch·python·深度学习·神经网络·cnn
神经星星3 小时前
在线教程丨YOLO系列重要创新!清华团队发布YOLOE,直击开放场景物体实时检测与分割
人工智能·深度学习·机器学习
Java中文社群4 小时前
SpringAI用嵌入模型操作向量数据库!
后端·aigc·openai
卑微小文4 小时前
惊!代理 IP 助力股海菜鸟变身赛场冠军!
爬虫·深度学习·数据分析
程序员X小鹿4 小时前
5个免费可用AI声音克隆工具,90%的人都不知道!建议收藏,早晚用得上!(附保姆级教程)
aigc
在下_诸葛4 小时前
DeepSeek的API调用 | 结合DeepSeek API文档 | Python环境 | 对话补全(二)
人工智能·python·gpt·prompt·aigc
Json_5 小时前
Vue 初识Hello word
前端·vue.js·深度学习