DiT Explained

1. Model Overview

DiT comes from the paper "Scalable Diffusion Models with Transformers" and is a core building block behind Sora's text-to-video generation. In Stable Diffusion, the image is first compressed into latents, which are then fed into an attention-equipped U-Net; in DiT, the authors replace that U-Net with a Transformer, reaching SOTA in both inference speed and sample quality.

2. Model Architecture

The architecture diagram from the paper:

3. Component Breakdown

3.1. Overall Model Structure

The model consists of the following components (a quick shape walkthrough follows the list):

  1. patch_embedding: embeds each patch as a token
  2. t_embedding: the timestep embedding
  3. pos_embedding: a fixed sin/cos embedding that encodes position information
  4. DiT Blocks: the core component
  5. final_layer: the final LayerNorm plus a linear projection back to patch space
  6. unpatchify: the reshape that reassembles patches into an image
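
To make the shapes concrete, here is a quick sanity check with the default configuration used below (input_size=32, patch_size=2, in_channels=4, hidden_size=1152, learn_sigma=True); this is illustrative arithmetic, not code from the repository:

```python
input_size, patch_size, in_channels = 32, 2, 4

# patchify: a 32x32 latent splits into (32/2) * (32/2) = 256 patches
num_patches = (input_size // patch_size) ** 2       # 256 tokens
patch_dim = patch_size * patch_size * in_channels   # 2*2*4 = 16 values per patch

# token sequence after x_embedder: (N, 256, 1152); pos_embed matches at (1, 256, 1152)
# after final_layer: (N, 256, patch_size**2 * out_channels) = (N, 256, 32)
# after unpatchify:  (N, out_channels, 32, 32) = (N, 8, 32, 32)
print(num_patches, patch_dim)  # 256 16
```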

Let's look at the overall structure of the model code:

```python
class DiT(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size=32,
        patch_size=2,
        in_channels=4,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        class_dropout_prob=0.1,
        num_classes=1000,
        learn_sigma=True,
    ):
        super().__init__()
        # whether to also predict the diffusion variance (sigma)
        self.learn_sigma = learn_sigma
        # number of input channels
        self.in_channels = in_channels
        # number of output channels: 2 * in_channels if learn_sigma=True (noise + sigma),
        # otherwise equal to in_channels; learn_sigma is usually set to True
        self.out_channels = in_channels * 2 if learn_sigma else in_channels
        # side length of each square patch (patch_size x patch_size)
        self.patch_size = patch_size
        # number of attention heads
        self.num_heads = num_heads
        
        # Patch embedding as in ViT (from timm.models.vision_transformer import PatchEmbed):
        # splits the input_size x input_size input into (input_size / patch_size)^2 patches
        # and projects each patch to a hidden_size-dimensional token
        self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
        # timestep embedder
        self.t_embedder = TimestepEmbedder(hidden_size)
        # label embedder; randomly drops labels to enable classifier-free guidance
        self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
        num_patches = self.x_embedder.num_patches
        # Will use fixed sin-cos embedding:
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
        ])
        # final LayerNorm + linear projection back to patch space
        self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
        self.initialize_weights()

    def initialize_weights(self):
        ...

    def unpatchify(self, x):
        """
        x: (N, T, patch_size**2 * C)
        imgs: (N, C, H, W)
        """
        c = self.out_channels
        p = self.x_embedder.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
        return imgs

    def forward(self, x, t, y):
        """
        Forward pass of DiT.
        x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
        t: (N,) tensor of diffusion timesteps
        y: (N,) tensor of class labels
        """
        # patchify
        x = self.x_embedder(x) + self.pos_embed  # (N, T, D), where T = H * W / patch_size ** 2 is the number of patches
        t = self.t_embedder(t)                   # (N, D)
        y = self.y_embedder(y, self.training)    # (N, D)
        # conditioning: sum of the timestep and label embeddings
        c = t + y                                # (N, D)
        # run through the stack of DiT blocks
        for block in self.blocks:
            x = block(x, c)                      # (N, T, D)
        # final LayerNorm + linear projection
        x = self.final_layer(x, c)               # (N, T, patch_size ** 2 * out_channels)
        # reassemble the patches into a full image (the reshape step)
        x = self.unpatchify(x)                   # (N, out_channels, H, W)
        return x

    def forward_with_cfg(self, x, t, y, cfg_scale):
        """
        Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
        """
        # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, y)
        # For exact reproducibility reasons, we apply classifier-free guidance on only
        # three channels by default. The standard approach to cfg applies it to all channels.
        # This can be done by uncommenting the following line and commenting-out the line following that.
        # eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
        eps, rest = model_out[:, :3], model_out[:, 3:]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)
```
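
TimestepEmbedder, LabelEmbedder, and FinalLayer are referenced above but elided from this excerpt. For reference, here is a sketch of the timestep embedder in the spirit of the official facebookresearch/DiT code: a fixed sinusoidal frequency embedding of the scalar timestep, pushed through a small MLP up to hidden_size.

```python
import math
import torch
import torch.nn as nn

class TimestepEmbedder(nn.Module):
    """Embeds scalar diffusion timesteps into hidden_size-dim vectors (sketch)."""
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        # classic sinusoidal embedding: cos/sin at geometrically spaced frequencies
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(half, dtype=torch.float32) / half
        ).to(t.device)
        args = t[:, None].float() * freqs[None]          # (N, half)
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)  # (N, dim)

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(t_freq)  # (N, hidden_size)
```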

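As a quick smoke test of the shapes (a hypothetical snippet, assuming the full official models.py with all helper classes is importable; the null-label index num_classes = 1000 for classifier-free guidance follows the official sampling script):

```python
import torch

model = DiT()  # defaults: input_size=32, patch_size=2, in_channels=4, learn_sigma=True
model.eval()

x = torch.randn(2, 4, 32, 32)      # a batch of 2 latents
t = torch.randint(0, 1000, (2,))   # diffusion timesteps
y = torch.randint(0, 1000, (2,))   # class labels
with torch.no_grad():
    out = model(x, t, y)
print(out.shape)                   # torch.Size([2, 8, 32, 32]) -> noise (4ch) + sigma (4ch)

# classifier-free guidance: duplicate the batch; the second half gets the null label
x_cfg = torch.cat([x, x], dim=0)
t_cfg = torch.cat([t, t], dim=0)
y_cfg = torch.cat([y, torch.full((2,), 1000)], dim=0)
with torch.no_grad():
    out_cfg = model.forward_with_cfg(x_cfg, t_cfg, y_cfg, cfg_scale=4.0)
print(out_cfg.shape)               # torch.Size([4, 8, 32, 32])
```
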
3.2. DiT Blocks

```python
class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    """
    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()
        # first LayerNorm: no learnable affine params; scale/shift come from adaLN below
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        # multi-head self-attention
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        # second LayerNorm, also without affine params
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
        # adaLN-Zero: SiLU + linear that regresses 6 modulation parameters from c
        # (shift/scale/gate for the attention branch and for the MLP branch)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        # pass the conditioning c through adaLN_modulation, then chunk into 6 pieces,
        # each of shape (N, D); unsqueeze(1) below broadcasts them across the T tokens
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa))
        x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        return x
```
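
To see the broadcasting at work, a tiny hypothetical shape check (N=2, T=256, D=1152):

```python
import torch

N, T, D = 2, 256, 1152
c = torch.randn(N, D)              # conditioning: timestep + label embedding
params = torch.randn(N, 6 * D)     # stand-in for adaLN_modulation(c)
shift_msa, scale_msa, gate_msa, *_ = params.chunk(6, dim=1)
print(shift_msa.shape)             # torch.Size([2, 1152]) -- (N, D), not (N, T, D)

x = torch.randn(N, T, D)
# unsqueeze(1) turns (N, D) into (N, 1, D), which broadcasts over all T tokens
y = x * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
print(y.shape)                     # torch.Size([2, 256, 1152])
```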

The adaptive layer norm modulation function:

```python
def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
```
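
The "Zero" in adaLN-Zero refers to initialization: the paper zero-initializes the final linear layer of adaLN_modulation, so all shifts, scales, and gates start at zero and every DiT block is the identity function at initialization, which the authors found helps large-model training. The elided initialize_weights above does this roughly as follows (a sketch based on the paper's description; zero_init_adaln is a hypothetical helper, and it assumes FinalLayer exposes adaLN_modulation and linear as in the official repo):

```python
import torch.nn as nn

def zero_init_adaln(model):
    # adaLN-Zero: zero the modulation linears so shift = scale = gate = 0 at init;
    # every residual branch is then gated off and each DiT block starts as the identity
    for block in model.blocks:
        nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
    # the final layer's modulation and output projection are zeroed as well,
    # so the whole network outputs zeros before training
    nn.init.constant_(model.final_layer.adaLN_modulation[-1].weight, 0)
    nn.init.constant_(model.final_layer.adaLN_modulation[-1].bias, 0)
    nn.init.constant_(model.final_layer.linear.weight, 0)
    nn.init.constant_(model.final_layer.linear.bias, 0)
```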