[Diffusion Models (8)] Stable Diffusion 3 diffusers Source Code Walkthrough 2: DiT and MMDiT Code (Part 2)

MMDiT

The Four-Layer Code Structure

  • In the figure above, (a) is layer 1 and (b) covers layers 2 and 3; the Attention used inside (b) is implemented in a separate source file (layer 4).
  • The fusion of text and image happens in layer 4 (JointAttnProcessor2_0).
  • The complete structure of layer 4 is shown further below, with the focus on the concrete implementation of Joint Attention; a compact call-chain overview follows this list.
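A compact overview of how the four layers nest, as a reading guide (class names as they appear in recent diffusers releases; the long site-packages prefix is abbreviated with ...):

```python
# Layer 1: StableDiffusion3Pipeline.__call__          -> .../pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
#   Layer 2: SD3Transformer2DModel.forward (for loop) -> .../models/transformers/transformer_sd3.py
#     Layer 3: JointTransformerBlock.forward          -> .../models/attention.py
#       Layer 4: JointAttnProcessor2_0.__call__       -> .../models/attention_processor.py
```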
Layer 1

The code corresponding to figure (a) is in /path/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py; the call noise_pred = self.transformer(...) is where execution enters the whole transformer (MM-DiT blocks 1 through d).

The snippet below is from the __call__ method of pipeline_stable_diffusion_3.py:

```python
noise_pred = self.transformer(
    hidden_states=latent_model_input,
    timestep=timestep,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled_prompt_embeds,
    joint_attention_kwargs=self.joint_attention_kwargs,
    return_dict=False,
)[0]
```
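For context, the easiest way to reach this snippet is simply to run the SD3 pipeline; the checkpoint id and settings below are illustrative assumptions, not something prescribed by this post:

```python
import torch
from diffusers import StableDiffusion3Pipeline

# Hypothetical usage: any SD3 checkpoint works, this id and the settings are just an example.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Every denoising step inside __call__ goes through the noise_pred = self.transformer(...) call above.
image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
image.save("sd3_sample.png")
```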
Layers 2 and 3

The code corresponding to (b) is inside the transformer itself (/path/lib/python3.12/site-packages/diffusers/models/transformers/transformer_sd3.py): a for loop enters each MM-DiT block (JointTransformerBlock) in turn.

Layer 2: the following snippet from the forward function of transformer_sd3.py enters that for loop; if the backbone is not being trained, execution takes the else branch into each block.

```python
for index_block, block in enumerate(self.transformer_blocks):
    if self.training and self.gradient_checkpointing:
        # Training path with gradient checkpointing: activations inside the block are
        # recomputed during backward instead of being stored, to save memory.
        def create_custom_forward(module, return_dict=None):
            def custom_forward(*inputs):
                if return_dict is not None:
                    return module(*inputs, return_dict=return_dict)
                else:
                    return module(*inputs)

            return custom_forward

        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
        encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
            create_custom_forward(block),
            hidden_states,
            encoder_hidden_states,
            temb,
            **ckpt_kwargs,
        )

    else:
        # Inference (or training without checkpointing): call the MM-DiT block directly.
        encoder_hidden_states, hidden_states = block(
            hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
        )
```
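As a side note, the create_custom_forward wrapper is just the standard torch.utils.checkpoint pattern. A minimal, self-contained sketch on a toy module (illustrative only, not diffusers code):

```python
import torch
import torch.nn as nn

# Toy stand-in for an MM-DiT block, purely for illustration.
block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

# Activations inside `block` are not stored in the forward pass; they are recomputed
# during backward, trading extra compute for lower memory use.
out = torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False)
out.sum().backward()
print(x.grad.shape)  # gradients match a regular forward/backward
```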

Layer 3: the block is implemented by the JointTransformerBlock class in /path/lib/python3.12/site-packages/diffusers/models/attention.py, where hidden_states (the noisy latent) and encoder_hidden_states (the text prompt) pass through norm1 and norm1_context respectively and then enter layer 4, self.attn.

```python
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
if self.context_pre_only:
    norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
else:
    norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
        encoder_hidden_states, emb=temb
    )

# Attention.
attn_output, context_attn_output = self.attn(
    hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
)
```
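norm1 here is an adaLN-Zero style norm: besides the normalized latent it also returns gate/shift/scale vectors conditioned on temb. Roughly, the rest of JointTransformerBlock.forward consumes them as below (a paraphrased sketch of the remaining lines of the same method; check attention.py for the exact code):

```python
# Gated residual around attention, then a shift/scale-modulated feed-forward with its own gate.
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = hidden_states + attn_output

norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output
```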
Layer 4

From the __init__ that creates self.attn, we can trace the actual implementation to the JointAttnProcessor2_0 class, i.e. /path/lib/python3.12/site-packages/diffusers/models/attention_processor.py.

Below is how self.attn is initialized in the block's __init__:

```python
if hasattr(F, "scaled_dot_product_attention"):
    processor = JointAttnProcessor2_0()
else:
    raise ValueError(
        "The current PyTorch version does not support the `scaled_dot_product_attention` function."
    )
self.attn = Attention(
    query_dim=dim,
    cross_attention_dim=None,
    added_kv_proj_dim=dim,
    dim_head=attention_head_dim // num_attention_heads,
    heads=num_attention_heads,
    out_dim=attention_head_dim,
    context_pre_only=context_pre_only,
    bias=True,
    processor=processor,
)
```
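JointAttnProcessor2_0 relies on F.scaled_dot_product_attention, which was added in PyTorch 2.0 and computes the usual softmax(QK^T / sqrt(d)) V, dispatching to a fused kernel when one is available. A quick, illustrative sanity check of that equivalence:

```python
import torch
import torch.nn.functional as F

# (batch, heads, tokens, head_dim)
q = torch.randn(1, 8, 10, 64)
k = torch.randn(1, 8, 25, 64)
v = torch.randn(1, 8, 25, 64)

out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
ref = torch.softmax(q @ k.transpose(-2, -1) / 64 ** 0.5, dim=-1) @ v
print(torch.allclose(out, ref, atol=1e-5))  # True, up to numerical tolerance
```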

The figure drawn for this part and the code below are the core of the text-image fusion. The original paper [1] describes this structure as "equivalent to two independent transformers for the two modalities, but joining the two sequences for the attention operation", and keeping that wording in mind makes the code easier to follow.

```python
class JointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # Joint attention: concatenate image (sample) and text (context) tokens along the sequence dimension.
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states, encoder_hidden_states
```
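Stripped of the linear projections, the multi-head reshaping and the 4D-input handling, the joint attention above boils down to: concatenate the image and text tokens along the sequence dimension, run a single attention call over the combined sequence, and split the result back into the two modalities. A distilled, hypothetical sketch (names and shapes are illustrative only):

```python
import torch
import torch.nn.functional as F

def joint_attention(img_qkv, txt_qkv):
    """img_qkv / txt_qkv: tuples (q, k, v), each tensor of shape (batch, tokens, dim)."""
    n_img = img_qkv[0].shape[1]
    # Concatenate the two modalities along the token (sequence) dimension.
    q, k, v = (torch.cat([i, t], dim=1) for i, t in zip(img_qkv, txt_qkv))
    out = F.scaled_dot_product_attention(q, k, v)  # one attention over the joint sequence
    return out[:, :n_img], out[:, n_img:]          # split back into image tokens, text tokens

img_qkv = tuple(torch.randn(2, 4096, 64) for _ in range(3))  # e.g. 64x64 latent patches
txt_qkv = tuple(torch.randn(2, 77, 64) for _ in range(3))    # e.g. 77 text tokens
img_out, txt_out = joint_attention(img_qkv, txt_qkv)
print(img_out.shape, txt_out.shape)  # torch.Size([2, 4096, 64]) torch.Size([2, 77, 64])
```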

  1. Scaling Rectified Flow Transformers for High-Resolution Image Synthesis ↩︎