【扩散模型(八)】Stable Diffusion 3 diffusers 源码详解2 - DiT 与 MMDiT 相关代码(下)

系列文章目录


文章目录


MMDiT

四层代码结构

  • 上图中的 (a) 为第一层,(b) 为第二层和三层,而 (b) 中 Attention 的实现是在另外一个代码文件(第四层)中。
  • 文本和图像的融合部分是在第四层 (JointAttnProcessor2_0) 中。
  • 第四层的完整结构如下所示,重点放在了 Joint Attention 的具体实现上。
第一层

图(a)对应的代码在 /path/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py , 是从 noise_pred = self.transformer 进入到整个 transformer ( MM-DiT 1 至 d )中

pipeline_stable_diffusion_3.py 中的 call 函数中的以下片段

python 复制代码
noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    pooled_projections=pooled_prompt_embeds,
                    joint_attention_kwargs=self.joint_attention_kwargs,
                    return_dict=False,
                )[0]
第二和第三层

(b)对应的代码在进入 transformer( /path/lib/python3.12/site-packages/diffusers/models/transformers/transformer_sd3.py )后,的 for 循环中,依次进入每个 MM-DiT block(JointTransformerBlock)

第二层: transformer_sd3.py 中的 forward 函数中以下片段进入 for 循环,如果不训练 backbone的话,那么就是从 else 分支进入 block 中。

python 复制代码
 for index_block, block in enumerate(self.transformer_blocks):
     if self.training and self.gradient_checkpointing:

         def create_custom_forward(module, return_dict=None):
             def custom_forward(*inputs):
                 if return_dict is not None:
                     return module(*inputs, return_dict=return_dict)
                 else:
                     return module(*inputs)

             return custom_forward

         ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
         encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
             create_custom_forward(block),
             hidden_states,
             encoder_hidden_states,
             temb,
             **ckpt_kwargs,
         )

     else:                
         encoder_hidden_states, hidden_states = block(
              hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
         )

第三层: block 的实现是在 /path/lib/python3.12/site-packages/diffusers/models/attention.py 中的 JointTransformerBlock 类,其中 hidden_states (noisy latent)和 encoder_hidden_states (text prompt) 分别通过 norm1 和 norm1_context 后,进入了第四层 self.attn

python 复制代码
	norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
	if self.context_pre_only:
	    norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
	else:
	    norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
	        encoder_hidden_states, emb=temb
	    )
	
	# Attention.
	attn_output, context_attn_output = self.attn(
	    hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
	)
第四层

从 self.attn 的 init 中,我们可以找到实际代码在 JointAttnProcessor2_0() 类,即 /path/lib/python3.12/site-packages/diffusers/models/attention_processor.py

下方为 self.attn 的 init 初始化

python 复制代码
if hasattr(F, "scaled_dot_product_attention"):
            processor = JointAttnProcessor2_0()
        else:
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim // num_attention_heads,
            heads=num_attention_heads,
            out_dim=attention_head_dim,
            context_pre_only=context_pre_only,
            bias=True,
            processor=processor,
        )

下方画出的图片和对应代码即为文图融合的核心关键,在原论文中^1^对这部分结构的解释是 "等价于两个针对文/图模态的独立的 transformers,但在 attention 操作中两种模态联合(joining)在了一起",贴出原文描述来更好理解

python 复制代码
class JointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # attention
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = hidden_states = F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states, encoder_hidden_states

  1. Scaling Rectified Flow Transformers for High-Resolution Image Synthesis ↩︎
相关推荐
-Nemophilist-15 分钟前
机器学习与深度学习-1-线性回归从零开始实现
深度学习·机器学习·线性回归
云空21 分钟前
《Python 与 SQLite:强大的数据库组合》
数据库·python·sqlite
成富1 小时前
文本转SQL(Text-to-SQL),场景介绍与 Spring AI 实现
数据库·人工智能·sql·spring·oracle
程序员X小鹿1 小时前
全部免费!6款AI对口型神器,让照片开口说话唱歌,早晚用得上,建议收藏!(附保姆级教程)
aigc
凤枭香1 小时前
Python OpenCV 傅里叶变换
开发语言·图像处理·python·opencv
CSDN云计算1 小时前
如何以开源加速AI企业落地,红帽带来新解法
人工智能·开源·openshift·红帽·instructlab
测试杂货铺1 小时前
外包干了2年,快要废了。。
自动化测试·软件测试·python·功能测试·测试工具·面试·职场和发展
艾派森1 小时前
大数据分析案例-基于随机森林算法的智能手机价格预测模型
人工智能·python·随机森林·机器学习·数据挖掘
hairenjing11231 小时前
在 Android 手机上从SD 卡恢复数据的 6 个有效应用程序
android·人工智能·windows·macos·智能手机
真忒修斯之船1 小时前
大模型分布式训练并行技术(三)流水线并行
面试·llm·aigc