UNet将文本嵌入作为条件信息,与图像特征信息融合,在去噪时遵循文本描述生成相关的图像。
而扩散模型的训练过程,则是一个对"噪声残差"进行预测和优化的循环过程。
这里结合Unet组件和伪代码尝试说明这一过程。
1 条件信息融合
文本到图像生成模型中,UNet不仅接收带噪图像,还必须理解文本提示。
其关键组件及文本信息的融入方式如下表所示。
1.1 时间步嵌入
告知模型当前在去噪过程中的位置,以使用正确的策略。
时间步嵌入与文本条件独立,共同作为模型的情境信息。
具体为时间步条件影响权重缩放和偏置。
1.2 条件嵌入
将文本提示词编码为模型可理解的向量序列。这里文本通过独立的文本编码器(如CLIP)转换为嵌入向量。
1.3 空间变换器
空间变换器模块负责将文本条件与图像视觉特征进行深度融合。
残差后引入,通过交叉注意力机制,图像特征作为查询,文本嵌入作为键和值,进行特征对齐。
1.4 编码器-解码器
标准的Unet型结构,用于捕获多尺度图像特征并通过跳跃连接保留细节。
在每个层级(尤其在设定的attention_resolutions层)插入空间变换器,实现多层次的条件控制。
2 处理过程伪码
这里展示上述流程的核心逻辑,首先指明文本处理核心机制,然后通过代码示例处理过程。
2.1 文本处理机制
UNet在Transformer框架下统一处理图文数据的核心机制,并非直接处理文本信息,而是将文本信息作为一组全局的、语义丰富的键值对(Key/Value),通过交叉注意力机制持续地引导和调制图像特征(Query)的生成过程。
这里展示了在扩散模型UNet中,文本条件如何被整合的过程。
2.2 处理伪码示例
为简化分析,这里使用一层Transformer,实际情况下可能为多层Transformer。
示例代码如下,细节参考注释。
import torch
import torch.nn as nn
import torch.nn.functional as F
class SpatialTransformer(nn.Module):
"""空间变换器模块,实现图像特征与文本条件的交叉注意力。"""
def __init__(self, in_channels, context_dim):
super().__init__()
self.norm = nn.GroupNorm(32, in_channels) # 归一化图像特征
self.proj_in = nn.Conv2d(in_channels, in_channels, 1)
# 核心:多头交叉注意力层
self.attn = CrossAttention(query_dim=in_channels, context_dim=context_dim)
self.proj_out = nn.Conv2d(in_channels, in_channels, 1)
def forward(self, x, context):
"""
x: 图像特征图 [B, C, H, W]
context: 文本条件嵌入 [B, L, D]
"""
batch, channel, height, width = x.shape
residual = x
# 1. 对图像特征进行归一化和投影
x = self.norm(x)
x = self.proj_in(x)
# 2. 重塑为序列以进行注意力计算
x = x.view(batch, channel, height * width).permute(0, 2, 1) # [B, N, C]
# 3. 交叉注意力:图像特征为Query,文本嵌入为Key/Value
x = self.attn(x, context)
# 4. 重塑回图像特征图并与残差连接
x = x.permute(0, 2, 1).view(batch, channel, height, width)
x = self.proj_out(x)
return x + residual
class UNetConditionalBlock(nn.Module):
"""集成了空间变换器的UNet条件残差块。"""
def __init__(self, in_channels, time_emb_dim, context_dim):
super().__init__()
# 时间步信息融入
self.time_emb_proj = nn.Linear(time_emb_dim, in_channels * 2)
# 第一个卷积组
self.conv1 = nn.Conv2d(in_channels, in_channels, 3, padding=1)
# 空间变换器(条件注入点)
self.transformer = SpatialTransformer(in_channels, context_dim)
# 第二个卷积组
self.conv2 = nn.Conv2d(in_channels, in_channels, 3, padding=1)
def forward(self, x, time_emb, context):
# 时间步条件影响权重缩放和偏置
scale, shift = self.time_emb_proj(time_emb).chunk(2, dim=1)
x = x * (1 + scale[:, :, None, None]) + shift[:, :, None, None]
# 主干卷积路径
residual = x
x = F.silu(self.conv1(x))
# 关键步骤:将文本条件context注入图像特征
x = self.transformer(x, context)
x = F.silu(self.conv2(x))
return x + residual
3 扩散训练过程
在上述统一架构下,扩散模型的训练是一个端到端的过程。
其目标是让UNet学会根据文本条件和时间步,从带噪图像中预测出噪声。
3.1 前向加噪
前向加噪即扩散过程,这是一个已知的、无需学习的固定过程。
在训练时,随机采样一个时间步 t,对原始图像 x₀ 添加高斯噪声 ε,得到带噪图像 x'。
3.2 条件编码
同时,将文本提示词(如"一只松鼠")通过文本编码器(如CLIP)转换为嵌入向量,作为条件 context。
3.3 噪声预测
将带噪图像 x'、时间步 t 的嵌入向量和文本条件 context 一起输入给UNet。UNet的目标是预测出添加到图像中的噪声 ε_θ(x', t, context)。其损失函数通常是最小化预测噪声与真实噪声之间的均方误差:
Loss = ||ε-ε_θ(x', t, context)||²
通过在海量的图像-文本对数据上重复这个过程,UNet逐步学会理解文本语义,并掌握在任何噪声阶段根据该语义引导去噪的能力。
reference
扩散模型数学太难?经典扩散模型DDPM手把手Pytorch代码实现,对照数学公式详解
https://www.zhuanzhi.ai/document/3d2b2838eb66aa7ccc57967b58d8ba4a
U-MixFormer开源 |UNet与Transformer高效设计,Mix-Attention+UNet让精度和参数都很美丽
https://hub.baai.ac.cn/view/33466
操作教程丨搭建MaxKB图文混合文档分析工作流,轻松分析带图片的文档
https://blog.fit2cloud.com/?p=019b31af-fc64-73b5-9825-0c6c0532d5bd
快速上手