对多模态扩散模型UNet架构的探索

UNet将文本嵌入作为条件信息,与图像特征信息融合,在去噪时遵循文本描述生成相关的图像。

而扩散模型的训练过程,则是一个对"噪声残差"进行预测和优化的循环过程。

这里结合Unet组件和伪代码尝试说明这一过程。

1 条件信息融合

文本到图像生成模型中,UNet不仅接收带噪图像,还必须理解文本提示。

其关键组件及文本信息的融入方式如下表所示。

1.1 时间步嵌入

告知模型当前在去噪过程中的位置,以使用正确的策略。

时间步嵌入与文本条件独立,共同作为模型的情境信息。

具体为时间步条件影响权重缩放和偏置。

1.2 条件嵌入

将文本提示词编码为模型可理解的向量序列。这里文本通过独立的文本编码器(如CLIP)转换为嵌入向量。

1.3 空间变换器

空间变换器模块负责将文本条件与图像视觉特征进行深度融合。

残差后引入,通过交叉注意力机制,图像特征作为查询,文本嵌入作为键和值,进行特征对齐。

1.4 编码器-解码器

标准的Unet型结构,用于捕获多尺度图像特征并通过跳跃连接保留细节。

在每个层级(尤其在设定的attention_resolutions层)插入空间变换器,实现多层次的条件控制。

2 处理过程伪码

这里展示上述流程的核心逻辑,首先指明文本处理核心机制,然后通过代码示例处理过程。

2.1 文本处理机制

UNet在Transformer框架下统一处理图文数据的核心机制,并非直接处理文本信息,而是将文本信息作为一组全局的、语义丰富的键值对(Key/Value),通过交叉注意力机制持续地引导和调制图像特征(Query)的生成过程。

这里展示了在扩散模型UNet中,文本条件如何被整合的过程。

2.2 处理伪码示例

为简化分析,这里使用一层Transformer,实际情况下可能为多层Transformer。

示例代码如下,细节参考注释。

复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class SpatialTransformer(nn.Module):
    """空间变换器模块,实现图像特征与文本条件的交叉注意力。"""
    def __init__(self, in_channels, context_dim):
        super().__init__()
        self.norm = nn.GroupNorm(32, in_channels)  # 归一化图像特征
        self.proj_in = nn.Conv2d(in_channels, in_channels, 1)
        # 核心:多头交叉注意力层
        self.attn = CrossAttention(query_dim=in_channels, context_dim=context_dim)
        self.proj_out = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x, context):
        """
        x: 图像特征图 [B, C, H, W]
        context: 文本条件嵌入 [B, L, D]
        """
        batch, channel, height, width = x.shape
        residual = x

        # 1. 对图像特征进行归一化和投影
        x = self.norm(x)
        x = self.proj_in(x)

        # 2. 重塑为序列以进行注意力计算
        x = x.view(batch, channel, height * width).permute(0, 2, 1)  # [B, N, C]

        # 3. 交叉注意力:图像特征为Query,文本嵌入为Key/Value
        x = self.attn(x, context)

        # 4. 重塑回图像特征图并与残差连接
        x = x.permute(0, 2, 1).view(batch, channel, height, width)
        x = self.proj_out(x)
        return x + residual

class UNetConditionalBlock(nn.Module):
    """集成了空间变换器的UNet条件残差块。"""
    def __init__(self, in_channels, time_emb_dim, context_dim):
        super().__init__()
        # 时间步信息融入
        self.time_emb_proj = nn.Linear(time_emb_dim, in_channels * 2)
        # 第一个卷积组
        self.conv1 = nn.Conv2d(in_channels, in_channels, 3, padding=1)
        # 空间变换器(条件注入点)
        self.transformer = SpatialTransformer(in_channels, context_dim)
        # 第二个卷积组
        self.conv2 = nn.Conv2d(in_channels, in_channels, 3, padding=1)

    def forward(self, x, time_emb, context):
        # 时间步条件影响权重缩放和偏置
        scale, shift = self.time_emb_proj(time_emb).chunk(2, dim=1)
        x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]

        # 主干卷积路径
        residual = x
        x = F.silu(self.conv1(x))
        # 关键步骤:将文本条件context注入图像特征
        x = self.transformer(x, context)
        x = F.silu(self.conv2(x))

        return x + residual

3 扩散训练过程

在上述统一架构下,扩散模型的训练是一个端到端的过程。

其目标是让UNet学会根据文本条件和时间步,从带噪图像中预测出噪声。

3.1 前向加噪

前向加噪即扩散过程,这是一个已知的、无需学习的固定过程。

在训练时,随机采样一个时间步 t,对原始图像 x₀ 添加高斯噪声 ε,得到带噪图像 x'

3.2 条件编码

同时,将文本提示词(如"一只松鼠")通过文本编码器(如CLIP)转换为嵌入向量,作为条件 context

3.3 噪声预测

将带噪图像 x'、时间步 t 的嵌入向量和文本条件 context 一起输入给UNet。UNet的目标是预测出添加到图像中的噪声 ε_θ(x', t, context)。其损失函数通常是最小化预测噪声与真实噪声之间的均方误差:

Loss = ||ε-ε_θ(x', t, context)||²

通过在海量的图像-文本对数据上重复这个过程,UNet逐步学会理解文本语义,并掌握在任何噪声阶段根据该语义引导去噪的能力。

reference


扩散模型数学太难?经典扩散模型DDPM手把手Pytorch代码实现,对照数学公式详解

https://www.zhuanzhi.ai/document/3d2b2838eb66aa7ccc57967b58d8ba4a

U-MixFormer开源 |UNet与Transformer高效设计,Mix-Attention+UNet让精度和参数都很美丽

https://hub.baai.ac.cn/view/33466

操作教程丨搭建MaxKB图文混合文档分析工作流,轻松分析带图片的文档

https://blog.fit2cloud.com/?p=019b31af-fc64-73b5-9825-0c6c0532d5bd

快速上手

https://openi.pcl.ac.cn/kewei/diffusers/raw/commit/0a26527169c152d91052467f7d8096c0bf4fc66c/docs/source/zh/quicktour.md

相关推荐
杜子不疼.2 小时前
AI Agent 开发指南:LangChain + 工具调用,构建自动化任务流
人工智能·langchain·自动化
西西弗Sisyphus2 小时前
神经网络的正向传播和反向传播 包括可视化源码
人工智能·深度学习·神经网络·反向传播·正向传播
芝麻开门-新起点3 小时前
第11章 线上/线下交易系统
大数据
Lethehong3 小时前
昇腾Atlas 800T平台下Qwen-14B大语言模型的SGLang适配与性能实测
人工智能·语言模型·sglang·昇腾npu
杜子不疼.3 小时前
Spring AI 与向量数据库:构建企业级 RAG 智能问答系统
数据库·人工智能·spring
ayingmeizi1633 小时前
AI CRM赋能全链路数字化如何重塑医械企业渠道竞争力?
人工智能
————A3 小时前
从 RAG 召回失败到故障链推理
人工智能·rag
Chase_______4 小时前
AI提效指南:Nano Banana 生成精美PPT与漫画
人工智能·powerpoint
雨大王5124 小时前
汽车产业供应链优化的可行策略及案例分析
人工智能·机器学习