Stable Diffusion核心网络结构——U-Net

本文详细详细介绍Stable Diffusion核心网络结构------U-Net,作用,架构,加噪去噪过程损失函数等。

目录

[Stable Diffusion核心网络结构](#Stable Diffusion核心网络结构)

SD模型整体架构初识

U-Net模型

【1】U-Net的核心作用

【2】U-Net模型的完整结构图

(1)ResNetBlock模块

(2)CrossAttention模块

[(3)BasicTransformer Block模块](#(3)BasicTransformer Block模块)

[(4)Spatial Transformer模块](#(4)Spatial Transformer模块)

(5)CrossAttnDownBlock/CrossAttnUpBlock/CrossAttnMidBlock模块

[(6)Stable Diffusion U-Net整体宏观角度小结](#(6)Stable Diffusion U-Net整体宏观角度小结)

[【3】Stable Diffusion中U-Net的训练过程与损失函数](#【3】Stable Diffusion中U-Net的训练过程与损失函数)

[【4】SD模型融合详解(Merge Block Weighted,MBW)](#【4】SD模型融合详解(Merge Block Weighted,MBW))

历史文章


Stable Diffusion核心网络结构

摘录来源:https://zhuanlan.zhihu.com/p/632809634

SD模型整体架构初识

Stable Diffusion模型整体上是一个End-to-End模型 ,主要由VAE变分自编码器,Variational Auto-Encoder),U-Net 以及CLIP Text Encoder三个核心组件构成。

本文主要介绍U-Net,CLIP Text Encoder和VAE请参考:

  1. Stable Diffusion核心网络结构------VAE
  2. Stable Diffusion核心网络结构------CLIP Text Encoder

在FP16精度下Stable Diffusion模型大小2G(FP32:4G),其中U-Net大小1.6G,VAE模型大小160M以及CLIP Text Encoder模型大小235M(约123M参数)。其中U-Net结构包含约860M参数,FP32精度下大小为3.4G左右。
Stable Diffusion整体架构图

U-Net模型

【1】U-Net的核心作用

在Stable Diffusion中,U-Net模型是一个关键核心部分, 能够预测噪声残差 ,并结合Sampling method(调度算法【这里其实不是调度算法,是采样算法 】:DDPM、DDIM、DPM++等)【去噪】 对输入的特征矩阵进行重构逐步将其从随机高斯噪声转化成图片的Latent Feature

【 "并结合Sampling method(调度算法:DDPM、DDIM、DPM++等)"这句话DDPM、DDIM、DPM++是Sampling method,不是调度算法。调度算法是线性调度、余弦调度等】

详情参考:Stable Diffusion的加噪和去噪详解-CSDN博客

具体来说,在前向推理过程 中【不是训练过程 】,SD模型通过反复调用 U-Net ,将预测出的噪声 残差从原噪声矩阵中去除 ,得到逐步去噪后的图像Latent Feature,再通过VAE的Decoder结构将Latent Feature重建成像素级图像。

从噪声到图片的生成过程,其中就是U-Net在不断的为大家去除噪声的过程。

在扩散模型(如Stable Diffusion)中,首先需要明确采样方法调度算法的区别和各自的作用:

1. 采样方法 (Sampling Methods),去噪过程用于推断每一步如何从当前噪声生成下一步图像 ,解决去噪过程中的不确定性。如DDIM、DPM++等。加噪过程不需要采样算法,因为噪声注入是确定性的,按照调度器规则进行。

2. 调度算法 (Schedule Methods),加噪过程中控制每个时间步向图像中注入的噪声比例 ,逐步将图像转化为纯噪声。去噪过程中控制每个时间步去除的噪声比例,确保噪声逐步减少,图像逐步恢复。如线性调度、余弦调度等。

  1. DDPM(Denoising Diffusion Probabilistic Model)

DDPM 是最基础的扩散模型,它通过在多个时间步逐步去除噪声 ,最终从一个接近随机噪声的状态生成高质量图像。这个采样过程通常是随机的, 充满了不确定性

  • 反向扩散:模型学习如何从每个时间步中的带噪声图像中去除噪声,从而逐渐恢复到原始的清晰图像。

特点:

  • 逐步采样每个时间步依赖前一步的结果,因此采样过程是逐步完成的,通常需要数百或上千步来生成图像。

  • 采样效率:DDPM的采样过程比较慢,因为需要多次迭代逐步去噪。

优点:

  • 生成质量高:通过多次迭代,DDPM能生成高质量图像。

  • 理论上稳定:每一步的去噪过程都有明确的概率分布。

缺点:

  • 速度慢:因为需要数百甚至上千个时间步,生成过程非常耗时。

  1. DDIM(Denoising Diffusion Implicit Models)

DDIM 是对DDPM 的一种改进。它的设计目标是减少生成步骤的数量 ,从而提高采样速度,同时保留高质量的生成结果。DDIM通过引入一个非马尔可夫链的 确定性采样方法 ,对时间步的改变,允许模型跳过某些时间步,实现更高效的采样。

DDPM 中,每一步的采样过程通常是随机的 ,从模型预测的噪声分布中随机采样。所以每次生成的图像会有细微差异。

DDIM通过引入一种确定性采样方法 ,它不依赖于每个时间步的随机性,而是通过显式公式一步步更新,而不是从噪声分布中随机采样。这意味着给定相同的初始噪声,DDIM可以在多次生成中输出相同的图像。它能够将原本长时间的采样过程减少为较少的时间步(例如从1000步减少到几十步)。

特点:

  • 跳跃采样:DDIM可以通过控制时间步之间的跳跃,减少采样步骤。

  • 确定性生成:与DDPM的随机采样不同,DDIM的采样过程是确定性的,即给定相同的输入,会产生相同的输出。

优点:

  • 速度更快:通过跳跃时间步,DDIM大幅减少了采样时间。

  • 可调节采样质量:采样步数可以调整,少步采样仍能生成高质量图像。

缺点:

  • 生成质量可能稍差:与DDPM相比,在减少时间步数时,生成图像的质量可能会稍有下降。

  1. DPM++(Denoising Diffusion Probabilistic Models++)

DPM++ 是对扩散采样过程的进一步优化,旨在同时提高生成的效率和质量 。DPM++在采样过程中结合了多种策略,使得生成过程可以在少量时间步中保持图像质量。

DPM++通过优化反向过程中的噪声估计,使每个时间步的去噪过程更加精确。这种方法能够在保持较少时间步的同时,生成更高质量的图像。

特点:

  • 多策略融合:结合了多个不同的采样优化策略,以提高采样速度和质量。

    • 噪声在不同时间步的强度分配

    • 减少采样步数

    • 结合了确定性采样随机性采样 的优势。在某些阶段采用确定性采样 来保持图像的生成一致性 ,而在另一些阶段采用随机性采样 ,增加生成多样性

    • 多重时间步架构 。在早期快速消除大部分噪声 ,而在后期对图像进行更精细的处理。

  • 更好的噪声估计:DPM++的噪声估计过程更加精准,反向扩散过程进行了更精确的建模和改进,减少了在推理阶段的误差累积,能够更好地去除噪声。

优点:

  • 效率与质量平衡:在较少的时间步中依然能生成高质量图像。

  • 速度更快:相比于DDPM,DPM++在大幅减少时间步的同时保持了生成质量。

缺点:

  • 复杂度较高:DPM++的采样算法更复杂,可能在某些实现中不如DDIM或DDPM直观。

总结:DDPM、DDIM、DPM++的对比

采样方法 生成速度 生成质量 工作机制 优缺点
DDPM 多时间步逐步去噪(随机性强) 质量高,但采样步骤多,生成时间长
DDIM 较快 中-高 确定性采样,可跳过时间步 速度快,质量与时间步数相关,可调节
DPM++ 多策略优化噪声估计,减少时间步 高质量与快速采样的平衡,采样过程复杂

选择使用的采样方法:

  1. 如果生成速度是关键 ,如在推理时实时生成图像,建议使用DDIMDPM++,因为它们可以减少时间步数,并且仍能保持良好的图像质量。

  2. 如果质量是第一位 且时间不敏感,DDPM的逐步采样过程可能是最好的选择,尽管它的采样过程较慢。

DPM++在实际应用中提供了最好的速度与质量平衡,特别适合对高质量生成有要求的场景。

【2】U-Net模型的完整结构图

Stable Diffusion中的U-Net,在传统深度学习时代的Encoder-Decoder结构的基础上,增加了以下的模块:

  1. ResNetBlock(包含Time Embedding)模块
  2. Spatial Transformer(SelfAttention + CrossAttention + FeedForward)模块
  3. CrossAttnDownBlock,CrossAttnUpBlock和CrossAttnMidBlock模块

那么各个模块都有什么作用呢?不着急,咱们先看看SD U-Net的整体架构(AIGC算法工程师面试核心考点)。

下图是Rocky梳理的Stable Diffusion U-Net的完整结构图

Stable Diffusion U-Net完整结构图 标题

上图中包含Stable Diffusion U-Net的十四个基本模块:

  1. GSC模块: Stable Diffusion U-Net中的最小组件之一,由GroupNorm+SiLU +Conv三者组成。【VAE用的Swish,SiLU (Sigmoid-Weighted Linear Unit)和 Swish 实际上是同一种激活函数,它们的定义完全相同,只是在不同的文献和框架中使用了不同的名称。】
  2. DownSample模块: Stable Diffusion U-Net中的下采样 组件,使用了Conv(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))进行采下采样
  3. UpSample模块: Stable Diffusion U-Net中的上采样 组件,由插值算法(nearest)+Conv组成。
  4. ResNetBlock模块: 借鉴ResNet模型的"残差结构 ",让网络能够构建的更深的同时,将Time Embedding信息嵌入模型 。【处理的主要是图像特征
  5. CrossAttention模块:文本的语义信息图像的语义信息进行Attention机制,增强输入文本Prompt对生成图片的控制。
  6. SelfAttention模块: SelfAttention模块的整体结构与CrossAttention模块相同,这是输入全部都是图像信息,不再输入文本信息。
  7. FeedForward模块: Attention机制中的经典模块,由GeGlU+Dropout +Linear组成,增强模型的表达能力
  8. BasicTransformer Block模块:LayerNorm +SelfAttention+CrossAttention+FeedForward组成,是多重Attention机制的级联,并且也借鉴ResNet模型的"残差结构"。通过加深网络和多Attention机制,大幅增强模型的学习能力与图文的匹配能力
  9. Spatial Transformer模块:GroupNorm +Conv+BasicTransformer Block +Conv构成,ResNet模型的"残差结构"依旧没有缺席。确保局部特征和全局特征的有效融合
  10. **DownBlock模块:**由两个ResNetBlock模块组成。
  11. **UpBlock_X模块:**由X个ResNetBlock模块和一个UpSample模块组成。
  12. **CrossAttnDownBlock_X模块:**是Stable Diffusion U-Net中Encoder部分的主要模块,由X个(ResNetBlock模块+Spatial Transformer模块)+DownSample模块组成。
  13. **CrossAttnUpBlock_X模块:**是Stable Diffusion U-Net中Decoder部分的主要模块,由X个(ResNetBlock模块+Spatial Transformer模块)+UpSample模块组成。
  14. **CrossAttnMidBlock模块:**是Stable Diffusion U-Net中Encoder和ecoder连接的部分,由ResNetBlock+Spatial Transformer+ResNetBlock组成。

接下来,为大家全面分析SD模型中U-Net结构的核心知识。

(1)ResNetBlock模块

借鉴ResNet模型的"残差结构 ",让网络能够构建的更深的同时,将Time Embedding信息嵌入模型 。【处理的主要是图像特征

Stable Diffusion U-Net完整结构图中展示了完整的ResNetBlock模块,其输入包括Latent Feature和 Time Embedding 。首先Latent Feature经过GSC (GroupNorm+SiLU激活函数+卷积)模块后和Time Embedding (经过SiLU激活函数+全连接层处理)做加和操作之后再经过GSC模块和Skip Connection而来的输入Latent Feature做加和操作 ,进行两次特征融合 后最终得到ResNetBlock模块的Latent Feature输出,增强SD模型的特征学习能力。

GSC模块: Stable Diffusion U-Net中的最小组件之一,由GroupNorm+SiLU +Conv三者组成。【VAE用的Swish,SiLU (Sigmoid-Weighted Linear Unit)和 Swish 实际上是同一种激活函数,它们的定义完全相同,只是在不同的文献和框架中使用了不同的名称。】

值得注意的是,Time Embedding 输入到ResNetBlock模块中,为U-Net引入了时间信息 (时间步长T,T的大小代表了噪声扰动的强度 ),模拟一个随时间变化不断增加不同强度噪声扰动的过程,让SD模型能够更好地理解时间相关性。能告诉U-Net现在是整个迭代过程的哪一步 ,并及时控制 U-Net够根据不同的输入特征和迭代阶段而预测不同的噪声残差

在迭代的早期 ,能够先生成整幅图片的轮廓与边缘特征 ,随着迭代的深入 ,再补充生成图片的高频和细节特征信息

定义Time Embedding的代码如下所示,可以看到Time Embedding的生成方式,主要通过sin和cos函数再经过Linear层进行变换:

def time_step_embedding(self, time_steps: torch.Tensor, max_period: int = 10000):
    half = self.channels // 2
    frequencies = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=time_steps.device)
    args = time_steps[:, None].float() * frequencies[None]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
(2)CrossAttention模块

CrossAttention模块是我们使用输入文本Prompt 控制SD模型图片 内容生成的关键一招。

Cross Attention模块 接受两个输入 :一个是ResNetBlock模块的输出图像特征 】,另外一个是输入文本Prompt 经过CLIP Text Encoder模型编码后的Context Embedding文本特征】。

两个输入首先经过Attention机制(将Context Embedding对应的语义信息与Latent Feature中对应的语义信息相耦合 ),输出新的Latent Feature【Q和K计算注意力得分】 ,再将 输出的Latent Feature与输入的Context Embedding再做一次Attention机制**【与V做加权】** ,从而使得SD模型学习到了文本与图片之间的特征对应关系

【通过CrossAttention 机制将 ResNetBlock 输出的图像特征和 CLIP Text Encoder 编码后的文本提示结合起来,使生成的图像与输入的文本提示相匹配。】

看CrossAttention模块的结构图,大家可能会疑惑为什么Context Embedding用来生成K和V,Latent Feature用来生成Q呢?

原因也非常简单:因为在Stable Diffusion中,主要的目的是想把文本信息注入到图像信息中 里,所以用图片token对文本信息做 Attention实现逐步的文本特征提取和耦合

补充CrossAttention模块细节内容,在Stable Diffusion中:

  1. Text Condition信息,通过Cross Attention组件嵌入,作为K Matrix和V Matrix。
  2. 图片的Latent Feature作为Q Matrix。

Text Condition是三维 的,而Latent Feature是四维的,那它们是怎么进行Attention机制的呢?

  1. 每次进行Attention机制前,需要将Latent Feature从四维转到三维[batch_size,channels,height,width]转换到[batch_size,height*width,channels],就能够和Text Condition做CrossAttention操作。
  2. 在完成CrossAttention操作后,我们再将Latent Feature从[batch_size,height*width,channels]转换到[batch_size,channels,height,width] ,这样就又重新回到原来的维度。

还有一点是Text Condition如何跟latent Feature大小保持一致呢?

因为latent embedding不同位置的H和W是不一样的,但是Text Condition是从文本中提取的,其H和W是固定的。这里在CorssAttention模块中有一个非常巧妙的点,那就是在不同特征做Attention操作前,使用Linear层将不同的特征的尺寸大小对齐

摘录于:https://zhuanlan.zhihu.com/p/643420260

(3)BasicTransformer Block模块

BasicTransformer Block模块是在CrossAttention 子模块的基础上,增加了SelfAttention 子模块和Feedforward 子模块共同组成的,并且每个子模块都是一个残差结构 ,这样除了能让文本的语义信息与图像的语义信息更好的融合之外,还能通过SelfAttention机制让模型更好的学习图像数据的特征

  • SelfAttention输入只有图像信息,主要是为了让SD模型更好的学习图像数据的整体特征。
    • 再者,SelfAttention可以将输入图像的不同部分(像素或图像Patch)进行交互,从而实现特征的整合和全局上下文的引入 ,能够让模型建立捕捉图像全局关系的能力,有助于模型理解不同位置的像素之间的依赖关系,以更好地理解图像的语义。
    • 在此基础上,SelfAttention还能减少平移不变性问题,SelfAttention模块可以在不考虑位置的情况下捕捉特征之间的关系,因此具有一定的平移不变性。
  • FeedForward模块 :Attention机制中的经典模块,由GeGlU+Dropout+Linear组成。
(4)Spatial Transformer模块

Spatial Transformer模块 :在BasicTransformer Block 模块基础上,加入GroupNorm两个卷积层。在Encoder中的CrossAttnDownBlock模块,Decoder中的CrossAttnUpBlock模块以及CrossAttnMidBlock模块都包含了大量的Spatial Transformer子模块。

在生成式模型中,GroupNorm的效果一般会比BatchNorm更好,主要有以下一些优势,让其能够成为生成式模型的标配:

  1. 对训练中不同Batch-Size的适应性:在生成式模型中,通常需要使用不同的Batch-Size进行训练和微调。这会导致 BatchNorm在训练期间的不稳定性,而GroupNorm不受Batch-Size的影响,因此更适合生成式模型。
  2. 能适应通道数变化:GroupNorm 是一种基于通道分组的归一化方法,更适应通道数的变化,而不需要大量调整。
  3. 更稳定的训练:生成式模型的训练通常更具挑战性,存在训练不稳定性的问题。GroupNorm可以减轻训练过程中的梯度问题,有助于更稳定的收敛。
  4. 能适应不同数据分布 :生成式模型通常需要处理多模态多模态多模态数据分布,GroupNorm 能够更好地适应不同的数据分布,因为它不像 Batch Normalization那样依赖于整个批量的统计信息。
(5)CrossAttnDownBlock/CrossAttnUpBlock/CrossAttnMidBlock模块

CrossAttnDownBlock:

在Stable Diffusion U-Net的Encoder部分 中,使用了三个CrossAttnDownBlock_X模块,

  • CrossAttnDownBlock_X模块X 个(ResNetBlock模块+Spatial Transformer模块 )+DownSample模块组成。
  • Downsample 是Stable Diffusion U-Net中的下采样 组件,使用了Conv(kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))进行采下采样

CrossAttnUpBlock:

Decoder部分 中,使用了三个CrossAttnUpBlock模块,

  • CrossAttnUpBlock由 X个(ResNetBlock模块+Spatial Transformer模块 )+UpSample模块组成。
  • Upsample是上采样 组件,使用插值算法+卷积来实现,插值算法将输入的Latent Feature尺寸扩大一倍,同时通过一个卷积(kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))改变Latent Feature的通道数,以便于输入后续的模块中。

CrossAttnMidBlock:

是Stable Diffusion U-Net中**EncoderDecoder连接** 的部分,由ResNetBlock+Spatial Transformer+ResNetBlock组成。

补充:

**DownBlock模块:**由两个ResNetBlock模块组成。

**UpBlock_X模块:**由X个ResNetBlock模块和一个UpSample模块组成。

DownBlock 和 CrossAttnDownBlock_X 的作用区别

  • DownBlock 专注于提取图像特征,通过下采样操作逐步压缩图像分辨率,不涉及文本信息。
  • CrossAttnDownBlock_X 在图像特征提取的同时加入了 Cross-Attention ,使得图像 特征能够与文本 提示信息结合,为后续生成提供语义线索。

UpBlock_X 和 CrossAttnUpBlock_X 的作用区别

  • UpBlock_X 仅专注于上采样和图像细节恢复,在解码器中通过逐步提高特征分辨率还原图像。
  • CrossAttnUpBlock_X 在恢复图像的同时,通过 Cross-Attention文本 信息持续注入到图像 特征中,使重构的图像细节能够响应输入文本,确保图像生成的精细度和文本一致性

DownBlock vs CrossAttnDownBlock_X:

  • DownBlock 仅执行图像特征提取和下采样 ,适合普通的【图像】特征提取过程
  • CrossAttnDownBlock_X 则通过 Cross-Attention 将文本信息注入【图像】特征提取过程,使图像特征从编码器阶段开始就受到文本的影响。

UpBlock_X vs CrossAttnUpBlock_X:

  • UpBlock_X 仅执行图像的上采样和细节恢复 ,专注于图像重构
  • CrossAttnUpBlock_X 则在恢复图像分辨率的同时注入文本信息,确保生成的图像符合输入文本的描述。
(6)Stable Diffusion U-Net整体宏观角度小结

从整体上看,不管是在训练过程还是前向推理过程,Stable Diffusion中的U-Net在每次循环迭代中Content Embedding部分始终保持不变,而Time Embedding每次都会发生变化。

和传统深度学习时代的U-Net一样,Stable Diffusion中的U-Net也是不限制输入图片的尺寸,因为这是个基于Transformer和卷积的模型结构

【3】Stable Diffusion中U-Net的训练过程与损失函数

在我们进行Stable Diffusion模型训练 时,VAE部分和CLIP部分 都是冻结 的,所以说官方在训练SD系列模型的时候,训练过程一般主要训练U-Net部分

我们之前我们已经讲过在Stable Diffusion中U-Net主要是进行噪声残差预测,在SD系列模型训练时和DDPM一样采用预测噪声残差的方法来训练U-Net,其损失函数如下所示:

​到这里,Stable Diffusion U-Net的完整核心基础知识就介绍好了。

【4】SD模型融合详解(Merge Block Weighted,MBW)

不管是传统深度学习时代,还是AIGC时代,模型融合永远都是学术界、工业界以及竞赛界的一个重要Trick。

在AI绘画领域,很多AI绘画开源社区里都有SD融合模型的身影,这些融合模型往往集成了多个SD模型的优点,同时规避了不足,让这些SD融合模型在开源社区中很受欢迎。

详细了解SD模型的模型融合过程与方法,大家可能会好奇为什么SD模型融合会在介绍SD U-Net的章节中讲到,原因是SD的模型融合方法主要作用于U-Net部分

首先,我们需要知道SD模型融合的形式,一共三种有如下所示:

  • SD模型 + SD模型 -> 新SD模型
  • SD模型 + LoRA模型 -> 新SD模型
  • LoRA模型 + LoRA模型 -> 新LoRA模型

历史文章

Stable Diffusion概要讲解-CSDN博客

Stable diffusion详细讲解-CSDN博客

Stable Diffusion的加噪和去噪详解-CSDN博客

Diffusion Model 原理-CSDN博客

Stable Diffusion核心网络结构------VAE-CSDN博客

Stable Diffusion核心网络结构------CLIP Text Encoder-CSDN博客

相关推荐
tuan_zhang14 分钟前
第17章 安全培训筑牢梦想根基
人工智能·安全·工业软件·太空探索·战略欺骗·算法攻坚
Antonio91539 分钟前
【opencv】第10章 角点检测
人工智能·opencv·计算机视觉
互联网资讯40 分钟前
详解共享WiFi小程序怎么弄!
大数据·运维·网络·人工智能·小程序·生活
helianying551 小时前
AI赋能零售:ScriptEcho如何提升效率,优化用户体验
前端·人工智能·ux·零售
积鼎科技-多相流在线2 小时前
探索国产多相流仿真技术应用,积鼎科技助力石油化工工程数字化交付
人工智能·科技·cfd·流体仿真·多相流·virtualflow
XianxinMao2 小时前
开源AI崛起:新模型逼近商业巨头
人工智能·开源
格砸2 小时前
Trae使用体验,未来已至?
人工智能·openai·trae
AI2AGI2 小时前
天天AI-20250121:全面解读 AI 实践课程:动手学大模型(含PDF课件)
大数据·人工智能·百度·ai·文心一言
滴滴哒哒答答3 小时前
《自动驾驶与机器人中的SLAM技术》ch4:基于预积分和图优化的 GINS
人工智能·机器人·自动驾驶
多森3 小时前
Cursor太贵?字节Trae可免费用Claude,10分钟带你实现全栈开发
aigc