[Diffusion Models (8)] Stable Diffusion 3 diffusers Source Code Walkthrough 2 - DiT and MMDiT Code (Part 2)

MMDiT

The Four-Layer Code Structure

  • In the figure above, (a) is Layer 1 and (b) covers Layers 2 and 3; the Attention used inside (b) is implemented in a separate source file (Layer 4).
  • The fusion of the text and image modalities happens in Layer 4 (JointAttnProcessor2_0).
  • The full structure of Layer 4 is shown below, with the focus on the concrete implementation of Joint Attention.
Layer 1

The code for figure (a) lives in /path/lib/python3.12/site-packages/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py; the call noise_pred = self.transformer(...) is where execution enters the full transformer (MM-DiT blocks 1 through d).

The snippet below comes from the __call__ function in pipeline_stable_diffusion_3.py:

python
noise_pred = self.transformer(
    hidden_states=latent_model_input,
    timestep=timestep,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled_prompt_embeds,
    joint_attention_kwargs=self.joint_attention_kwargs,
    return_dict=False,
)[0]
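
For orientation, here is a minimal usage sketch; the model id and generation settings are illustrative assumptions, not taken from this post. Calling the pipeline object runs __call__, and every step of its denoising loop executes the self.transformer(...) call shown above.

python
import torch
from diffusers import StableDiffusion3Pipeline

# Load the SD3 pipeline (assumed model id; requires access to the gated weights).
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# pipe(...) invokes __call__; each denoising step calls self.transformer(...) as above.
image = pipe(
    prompt="a photo of a corgi wearing sunglasses",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd3_sample.png")
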
Layers 2 and 3

The code for figure (b) is reached after entering the transformer (/path/lib/python3.12/site-packages/diffusers/models/transformers/transformer_sd3.py): a for loop walks through the MM-DiT blocks (JointTransformerBlock) one by one.

Layer 2: the snippet below, from the forward function in transformer_sd3.py, is where the for loop is entered. When the backbone is not being trained, each block is reached through the else branch; a toy sketch of the gradient-checkpointing branch follows the snippet.

python
 for index_block, block in enumerate(self.transformer_blocks):
     if self.training and self.gradient_checkpointing:

         def create_custom_forward(module, return_dict=None):
             def custom_forward(*inputs):
                 if return_dict is not None:
                     return module(*inputs, return_dict=return_dict)
                 else:
                     return module(*inputs)

             return custom_forward

         ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
         encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
             create_custom_forward(block),
             hidden_states,
             encoder_hidden_states,
             temb,
             **ckpt_kwargs,
         )

     else:                
         encoder_hidden_states, hidden_states = block(
              hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
         )
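
To make the if branch above concrete, here is a self-contained toy sketch (the module and tensor sizes are made up) of the same gradient-checkpointing pattern: create_custom_forward wraps the block into a purely positional callable, and torch.utils.checkpoint.checkpoint recomputes activations during the backward pass to save memory. Note the return order, (encoder_hidden_states, hidden_states), which mirrors the SD3 block.

python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyBlock(nn.Module):
    """Stand-in for JointTransformerBlock: two token streams plus a conditioning embedding."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.img_proj = nn.Linear(dim, dim)
        self.txt_proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, encoder_hidden_states, temb):
        # Mimic the SD3 block's return order: (text stream, image stream).
        return self.txt_proj(encoder_hidden_states + temb), self.img_proj(hidden_states + temb)


def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


block = ToyBlock()
hidden_states = torch.randn(2, 16, 64, requires_grad=True)         # toy image tokens
encoder_hidden_states = torch.randn(2, 8, 64, requires_grad=True)  # toy text tokens
temb = torch.randn(2, 1, 64)                                       # toy conditioning embedding

encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block), hidden_states, encoder_hidden_states, temb, use_reentrant=False
)
hidden_states.sum().backward()  # activations are recomputed here instead of being kept in memory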

Layer 3: the block itself is the JointTransformerBlock class in /path/lib/python3.12/site-packages/diffusers/models/attention.py. There, hidden_states (the noisy latent) and encoder_hidden_states (the text prompt) pass through norm1 and norm1_context respectively before entering Layer 4, self.attn.

python
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
if self.context_pre_only:
    norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
else:
    norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
        encoder_hidden_states, emb=temb
    )

# Attention.
attn_output, context_attn_output = self.attn(
    hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
)
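
norm1 here is an adaLN-Zero style normalization. The following is a simplified sketch (not the exact diffusers class) of what it computes: a SiLU + Linear over the conditioning embedding temb produces six chunks; the first shift/scale pair modulates the normalized hidden states, and the remaining gate/shift/scale values are returned for use on the attention and MLP residual paths.

python
import torch
import torch.nn as nn


class ToyAdaLayerNormZero(nn.Module):
    """Simplified adaLN-Zero: modulate LayerNorm output with chunks predicted from temb."""

    def __init__(self, dim: int):
        super().__init__()
        self.silu = nn.SiLU()
        self.linear = nn.Linear(dim, 6 * dim)
        self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, emb):
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.linear(
            self.silu(emb)
        ).chunk(6, dim=-1)
        x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp


norm1 = ToyAdaLayerNormZero(dim=64)
x, temb = torch.randn(2, 16, 64), torch.randn(2, 64)
norm_x, gate_msa, shift_mlp, scale_mlp, gate_mlp = norm1(x, temb)
print(norm_x.shape, gate_msa.shape)  # torch.Size([2, 16, 64]) torch.Size([2, 64])
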
Layer 4

From the __init__ that builds self.attn we can see that the actual attention computation lives in the JointAttnProcessor2_0 class, defined in /path/lib/python3.12/site-packages/diffusers/models/attention_processor.py.

Below is how self.attn is initialized in JointTransformerBlock's __init__:

python
if hasattr(F, "scaled_dot_product_attention"):
    processor = JointAttnProcessor2_0()
else:
    raise ValueError(
        "The current PyTorch version does not support the `scaled_dot_product_attention` function."
    )
self.attn = Attention(
    query_dim=dim,
    cross_attention_dim=None,
    added_kv_proj_dim=dim,
    dim_head=attention_head_dim // num_attention_heads,
    heads=num_attention_heads,
    out_dim=attention_head_dim,
    context_pre_only=context_pre_only,
    bias=True,
    processor=processor,
)
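
Note how added_kv_proj_dim=dim gives the Attention module a second set of projections (add_q_proj / add_k_proj / add_v_proj) for the text (context) stream, alongside to_q / to_k / to_v for the image stream. A quick way to verify this is to inspect a loaded SD3 transformer; the sketch below is an illustrative example that assumes a diffusers version with SD3 support and access to the gated stabilityai/stable-diffusion-3-medium-diffusers weights.

python
from diffusers import SD3Transformer2DModel

# Load only the transformer sub-model (assumed model id; gated on the Hugging Face Hub).
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", subfolder="transformer"
)
attn = transformer.transformer_blocks[0].attn
print(type(attn.processor).__name__)  # expected: JointAttnProcessor2_0
print(attn.to_q)                      # image-stream Q projection
print(attn.add_q_proj)                # text-stream Q projection, created via added_kv_proj_dim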

The figure and the code below are the heart of the text-image fusion. The original paper [1](#1) describes this structure as "equivalent to two independent transformers for the two modalities, but joining the two modalities in the attention operation"; the paper's own wording is worth reading alongside the code.

python
class JointAttnProcessor2_0:
    """Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        *args,
        **kwargs,
    ) -> torch.FloatTensor:
        residual = hidden_states

        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
        context_input_ndim = encoder_hidden_states.ndim
        if context_input_ndim == 4:
            batch_size, channel, height, width = encoder_hidden_states.shape
            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size = encoder_hidden_states.shape[0]

        # `sample` projections.
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)

        # `context` projections.
        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        # attention
        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)

        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        hidden_states = F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        hidden_states = hidden_states.to(query.dtype)

        # Split the attention outputs.
        hidden_states, encoder_hidden_states = (
            hidden_states[:, : residual.shape[1]],
            hidden_states[:, residual.shape[1] :],
        )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if not attn.context_pre_only:
            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if context_input_ndim == 4:
            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        return hidden_states, encoder_hidden_states
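
To summarize what the processor does, here is a condensed, self-contained sketch of the joint attention (random matrices stand in for the learned projections, and the token counts are illustrative): the image and text tokens are projected per modality, concatenated along the sequence axis, attended to jointly with scaled_dot_product_attention, and finally split back into the two streams.

python
import torch
import torch.nn.functional as F

batch, n_img, n_txt, heads, head_dim = 2, 1024, 333, 24, 64  # illustrative sizes
dim = heads * head_dim

img = torch.randn(batch, n_img, dim)  # hidden_states: noisy latent patch tokens
txt = torch.randn(batch, n_txt, dim)  # encoder_hidden_states: text tokens


def proj(x):
    # Stand-in for the learned to_q/to_k/to_v and add_{q,k,v}_proj linear layers.
    return x @ torch.randn(dim, dim) / dim**0.5


def split_heads(x):
    return x.view(batch, -1, heads, head_dim).transpose(1, 2)


# Concatenate the per-modality projections along the token axis: [B, n_img + n_txt, dim].
q = split_heads(torch.cat([proj(img), proj(txt)], dim=1))
k = split_heads(torch.cat([proj(img), proj(txt)], dim=1))
v = split_heads(torch.cat([proj(img), proj(txt)], dim=1))

out = F.scaled_dot_product_attention(q, k, v)  # joint attention over both modalities
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)

img_out, txt_out = out[:, :n_img], out[:, n_img:]  # split back into the two streams
print(img_out.shape, txt_out.shape)  # torch.Size([2, 1024, 1536]) torch.Size([2, 333, 1536])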

  1. Esser et al., "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis", 2024. ↩︎