计算机视觉语义分割------U-Net(Convolutional Networks for Biomedical Image Segmentation)
文章目录
- [计算机视觉语义分割------U-Net(Convolutional Networks for Biomedical Image Segmentation)](#计算机视觉语义分割——U-Net(Convolutional Networks for Biomedical Image Segmentation))
-
- 摘要
- Abstract
- 一、U-Net
-
- [1. 基本思想](#1. 基本思想)
- [2. 收缩路径(Contracting path)](#2. 收缩路径(Contracting path))
- [3. 扩展路径(Expanding path)](#3. 扩展路径(Expanding path))
- [4. U-Net代码实践](#4. U-Net代码实践)
-
- [4.1 数据集选择](#4.1 数据集选择)
- [4.2 加载数据集](#4.2 加载数据集)
- [4.3 U-Net模型搭建](#4.3 U-Net模型搭建)
- [4.4 训练U-Net](#4.4 训练U-Net)
- [4.5 预测](#4.5 预测)
- [5. U-Net与FCN的关系](#5. U-Net与FCN的关系)
- [6. U-Net与生成模型](#6. U-Net与生成模型)
-
- [6.1 Stable Diffusion中的U-Net](#6.1 Stable Diffusion中的U-Net)
- [6.2 U-Net在生成模型中的改进](#6.2 U-Net在生成模型中的改进)
- [6.3 U-Net在生成模型中训练目标的变化](#6.3 U-Net在生成模型中训练目标的变化)
- 总结
摘要
本周的工作聚焦于U-Net网络在语义分割任务中的研究与实现。U-Net作为一种经典的全卷积神经网络,因其编码器-解码器的对称结构和跳跃连接机制,广泛应用于生物医学图像分割。本次工作详细解析了U-Net的网络结构,包括收缩路径、扩展路径,以及其与FCN的关系。此外,还探讨了U-Net在生成模型中的改进及其在Stable Diffusion中的应用。通过搭建基于PyTorch的U-Net网络,完成了数据预处理、模型训练及预测,并在医学图像分割任务中取得了良好的效果。实验表明,U-Net可以有效减少信息丢失,提升分割精度,为后续生成模型中的应用奠定了基础。
Abstract
This week's work focused on the study and implementation of the U-Net network in semantic segmentation tasks. U-Net, as a classic fully convolutional neural network, is widely used in biomedical image segmentation due to its symmetrical encoder-decoder structure and skip connection mechanism. This report provides a detailed analysis of the U-Net architecture, including its contracting path, expanding path, and its relationship with FCN. Additionally, improvements to U-Net in generative models and its application in Stable Diffusion were explored. By building a PyTorch-based U-Net network, tasks such as data preprocessing, model training, and prediction were completed, achieving promising results in medical image segmentation. The experiments demonstrate that U-Net effectively reduces information loss and improves segmentation accuracy, laying a solid foundation for its application in generative models.
一、U-Net
1. 基本思想
U-Net由 Olaf Ronneberger 等人为生物医学图像分割而开发。该架构包含两条路径。第一条路径是收缩路径(也称为编码器),用于捕获图像中的上下文。编码器只是卷积层和最大池化层的传统堆栈。第二条路径是扩展路径(也称为解码器),用于使用转置卷积实现精确定位。因此,它是一个端到端的全卷积网络 (FCN),即它只包含卷积层,不包含任何密集层,因此它可以接受任何大小的图像。
2. 收缩路径(Contracting path)
收缩路径使用卷积层 和池化层 的组合来提取和捕获图像中的特征,同时减少其空间维度。现在让我们来看看下面收缩路径 中的5 个block中的每一个。
Block 1:
- 一张尺寸为572² 的 输入图像被输入到 U-Net 中。该输入图像仅包含1 个灰度通道。
- 然后将两个3x3 卷积 层 (无填充)应用于输入图像,每个层后跟一个ReLU 层。同时将通道数增加到**64 个,**以便捕获更高级别的特征。
- 然后应用步长为2 的2x2 最大池化 层。这会将特征图下采样为其大小的一半 ,即284²。
Block 2:
- 与Block 1 一样,两个3x3 卷积 层(无填充)应用于Block 1 的输出,每个层后又跟有一个ReLU层 。在每个新块中,特征通道 数量加倍,现在为128。
- 接下来,再次将2x2 最大池化 层应用于生成的特征图,将空间尺寸减少一半 至140²。
Block 3:
- 与Block 1和 Block 2中的步骤相似,最后将特征图下采样至68²。
Block 4:
- 与Block 1和 Block 2中的步骤相似,最后将特征图下采样至32²。
Block 5:
- 在收缩路径的最后一个块中 ,特征通道 数在每个块中翻倍,达到1024个。
- 此块还包含两个3x3 卷积 层(未填充),每个层后面都有一个ReLU 层。但是,出于对称目的,我只包含一个层,并将第二层包含在扩展路径中。
在提取出复杂的特征之后,特征图就会进入到扩展路径的运算阶段。
3. 扩展路径(Expanding path)
扩展路径使用卷积 和转置卷积操作来组合学习到的特征并对输入特征图进行上采样,直到生成分割图。扩展路径与收缩路径非常相似,下面将逐步了解每个Block。
注意:上图的灰色箭头代表"跳跃连接",跳跃连接用于将图像直接从收缩路径发送到扩展路径,而无需经过所有块。这样可以保留和学习高级和低级特征,从而减少收缩路径期间发生的任何信息丢失。
Block 5:
- 从收缩路径继续,应用第二个3x3 卷积 (未填充),其后接着一个ReLU层。
- 然后应用2x2 卷积(上卷积)层,将空间维度上采样两倍 ,并将通道数减半为512。
Block 4:
- 然后使用跳跃连接 将收缩路径中的相应特征图连接起来 ,将特征通道数 加倍至1024 。请注意,必须裁剪此连接以匹配扩展路径的尺寸。
- 应用两个3x3 卷积层(未填充) ,每个层后面都有一个ReLU 层,将通道数 减少到512。
- 之后,应用2x2 卷积(上卷积)层,将空间维度上采样两倍 ,并将通道数减半为256。
Block 3:
- 与Block 5和 Block 4中的步骤相似,最后将特征图上采样两倍至200² ,并将通道数减半为128。
Block 2:
- 与Block 5和 Block 4中的步骤相似,最后将特征图上采样两倍至392² ,并将通道数减半为64。
Block 1:
- 在扩展路径的最后一块 中,连接 跳跃连接后共有128 个通道。
- 接下来,在特征图上应用两个3x3 卷积层(未填充) ,中间的ReLU层将特征通道 数减少到64。
- 最后,使用1x1 卷积 层,然后是激活层 (二元分类中的Sigmoid函数),将通道 数减少到所需的类别数。在本例中,为 2 个类别,因为二元分类通常用于医学成像。
在扩展路径上对特征图进行上采样后,应该生成分割图,其中每个像素都被单独分类。
4. U-Net代码实践
4.1 数据集选择
收集的30张果蝇的电镜细胞图,分辨率为512x512,上图第一张为原图,第二张则为标签图。
训练集图片:
训练集标签:
4.2 加载数据集
数据加载要做哪些处理,是根据任务和数据集而决定的,对于我们的分割任务,不用做太多处理,但由于数据量很少,仅30张,我们可以使用一些数据增强方法,来扩大我们的数据集。
Pytorch 给我们提供了一个方法,方便我们加载数据,我们可以使用这个框架,去加载我们的数据。看下伪代码:
python
# ================================================================== #
# Input pipeline for custom dataset #
# ================================================================== #
# You should build your custom dataset as below.
class CustomDataset(torch.utils.data.Dataset):
def __init__(self):
# TODO
# 1. Initialize file paths or a list of file names.
pass
def __getitem__(self, index):
# TODO
# 1. Read one data from file (e.g. using numpy.fromfile, PIL.Image.open).
# 2. Preprocess the data (e.g. torchvision.Transform).
# 3. Return a data pair (e.g. image and label).
pass
def __len__(self):
# You should change 0 to the total size of your dataset.
return 0
# You can then use the prebuilt data loader.
custom_dataset = CustomDataset()
train_loader = torch.utils.data.DataLoader(dataset=custom_dataset,
batch_size=64,
shuffle=True)
这是一个标准的模板,我们就使用这个模板,来加载数据,定义标签,以及进行数据增强。
创建一个dataset.py文件,编写代码如下:
python
import torch
import cv2
import os
import glob
from torch.utils.data import Dataset
import random
class ISBI_Loader(Dataset):
def __init__(self, data_path):
# 初始化函数,读取所有data_path下的图片
self.data_path = data_path
self.imgs_path = glob.glob(os.path.join(data_path, 'image/*.png'))
def augment(self, image, flipCode):
# 使用cv2.flip进行数据增强,filpCode为1水平翻转,0垂直翻转,-1水平+垂直翻转
flip = cv2.flip(image, flipCode)
return flip
def __getitem__(self, index):
# 根据index读取图片
image_path = self.imgs_path[index]
# 根据image_path生成label_path
label_path = image_path.replace('image', 'label')
# 读取训练图片和标签图片
image = cv2.imread(image_path)
label = cv2.imread(label_path)
# 将数据转为单通道的图片
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
label = cv2.cvtColor(label, cv2.COLOR_BGR2GRAY)
image = image.reshape(1, image.shape[0], image.shape[1])
label = label.reshape(1, label.shape[0], label.shape[1])
# 处理标签,将像素值为255的改为1
if label.max() > 1:
label = label / 255
# 随机进行数据增强,为2时不做处理
flipCode = random.choice([-1, 0, 1, 2])
if flipCode != 2:
image = self.augment(image, flipCode)
label = self.augment(label, flipCode)
return image, label
def __len__(self):
# 返回训练集大小
return len(self.imgs_path)
if __name__ == "__main__":
isbi_dataset = ISBI_Loader("data/train/")
print("数据个数:", len(isbi_dataset))
train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
batch_size=2,
shuffle=True)
for image, label in train_loader:
print(image.shape)
使用dataset.py文件加载数据的输出结果为:
4.3 U-Net模型搭建
如果完全按照论文的结构,模型输出的尺寸会稍微小于图片输入的尺寸,如果使用论文的网络结构需要在结果输出后,做一个 resize 操作。为了省去这一步,我们可以修改网络,使网络的输出尺寸正好等于图片的输入尺寸。
创建unet_parts.py文件,编写如下代码:
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
else:
self.up = nn.ConvTranspose2d(in_channels // 2, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = torch.tensor([x2.size()[2] - x1.size()[2]])
diffX = torch.tensor([x2.size()[3] - x1.size()[3]])
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
创建unet_model.py文件,编写如下代码:
python
"""利用unet_parts.py中的部件组装形成完整的Unet网络"""
import torch.nn.functional as F
from .unet_parts import *
class UNet(nn.Module):
def __init__(self, n_channels, n_classes, bilinear=True):
super(UNet, self).__init__()
self.n_channels = n_channels
self.n_classes = n_classes
self.bilinear = bilinear
self.inc = DoubleConv(n_channels, 64)
self.down1 = Down(64, 128)
self.down2 = Down(128, 256)
self.down3 = Down(256, 512)
self.down4 = Down(512, 512)
self.up1 = Up(1024, 256, bilinear)
self.up2 = Up(512, 128, bilinear)
self.up3 = Up(256, 64, bilinear)
self.up4 = Up(128, 64, bilinear)
self.outc = OutConv(64, n_classes)
def forward(self, x):
x1 = self.inc(x)
x2 = self.down1(x1)
x3 = self.down2(x2)
x4 = self.down3(x3)
x5 = self.down4(x4)
x = self.up1(x5, x4)
x = self.up2(x, x3)
x = self.up3(x, x2)
x = self.up4(x, x1)
logits = self.outc(x)
return logits
if __name__ == '__main__':
net = UNet(n_channels=3, n_classes=1)
print(net)
这样调整过后,网络的输出尺寸就与图片的输入尺寸相同了。
4.4 训练U-Net
我们今天的任务,只需要分割出细胞边缘,也就是一个很简单的二分类任务,所以我们可以使用BCEWithLogitsLoss。
BCEWithLogitsLoss是 Pytorch 提供的用来计算二分类交叉熵的函数:
这个公式就是 Logistic 回归的损失函数,它利用的是 Sigmoid 函数阈值在[0,1]这个特性来进行分类的。
python
from model.unet_model import UNet
from utils.dataset import ISBI_Loader
from torch import optim
import torch.nn as nn
import torch
def train_net(net, device, data_path, epochs=40, batch_size=1, lr=0.00001):
# 加载训练集
isbi_dataset = ISBI_Loader(data_path)
train_loader = torch.utils.data.DataLoader(dataset=isbi_dataset,
batch_size=batch_size,
shuffle=True)
# 定义RMSprop算法
optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8, momentum=0.9)
# 定义Loss算法
criterion = nn.BCEWithLogitsLoss()
# best_loss统计,初始化为正无穷
best_loss = float('inf')
# 训练epochs次
for epoch in range(epochs):
# 训练模式
net.train()
# 按照batch_size开始训练
for image, label in train_loader:
optimizer.zero_grad()
# 将数据拷贝到device中
image = image.to(device=device, dtype=torch.float32)
label = label.to(device=device, dtype=torch.float32)
# 使用网络参数,输出预测结果
pred = net(image)
# 计算loss
loss = criterion(pred, label)
print('Loss/train', loss.item())
# 保存loss值最小的网络参数
if loss < best_loss:
best_loss = loss
torch.save(net.state_dict(), 'best_model.pth')
# 更新参数
loss.backward()
optimizer.step()
if __name__ == "__main__":
# 选择设备,有cuda用cuda,没有就用cpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载网络,图片单通道1,分类为1。
net = UNet(n_channels=1, n_classes=1)
# 将网络拷贝到deivce中
net.to(device=device)
# 指定训练集地址,开始训练
data_path = "data/train/"
train_net(net, device, data_path)
训练的输出LOSS为:
python
D:\Pytorch\Anaconda3\envs\d2l\python.exe E:/Unet/train.py
Loss/train 0.7607325911521912
Loss/train 0.7042417526245117
Loss/train 0.6359086036682129
Loss/train 0.5970470905303955
Loss/train 0.5956597328186035
Loss/train 0.5168007016181946
Loss/train 0.509236216545105
Loss/train 0.4616139233112335
Loss/train 0.4775354862213135
Loss/train 0.44973015785217285
Loss/train 0.4584605395793915
Loss/train 0.4683246314525604
Loss/train 0.42488470673561096
Loss/train 0.4471128582954407
Loss/train 0.4211571216583252
..............................
Loss/train 0.08856159448623657
Loss/train 0.09392605721950531
Loss/train 0.10895076394081116
Loss/train 0.10066412389278412
Loss/train 0.10104762017726898
Loss/train 0.09606537222862244
Loss/train 0.09012659639120102
Loss/train 0.10549496114253998
Loss/train 0.0879889726638794
Loss/train 0.10525484383106232
Loss/train 0.10839355736970901
Process finished with exit code 0
4.5 预测
模型训练好了,我们可以用它在测试集上看下效果。在工程根目录创建 predict.py 文件,编写如下代码:
python
import glob
import numpy as np
import torch
import os
import cv2
from model.unet_model import UNet
if __name__ == "__main__":
# 选择设备,有cuda用cuda,没有就用cpu
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 加载网络,图片单通道,分类为1。
net = UNet(n_channels=1, n_classes=1)
# 将网络拷贝到deivce中
net.to(device=device)
# 加载模型参数
net.load_state_dict(torch.load('best_model.pth', map_location=device))
# 测试模式
net.eval()
# 读取所有图片路径
tests_path = glob.glob('data/test/*.png')
# 遍历所有图片
for test_path in tests_path:
# 保存结果地址
save_res_path = test_path.split('.')[0] + '_res.png'
# 读取图片
img = cv2.imread(test_path)
# 转为灰度图
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
# 转为batch为1,通道为1,大小为512*512的数组
img = img.reshape(1, 1, img.shape[0], img.shape[1])
# 转为tensor
img_tensor = torch.from_numpy(img)
# 将tensor拷贝到device中,只用cpu就是拷贝到cpu中,用cuda就是拷贝到cuda中。
img_tensor = img_tensor.to(device=device, dtype=torch.float32)
# 预测
pred = net(img_tensor)
# 提取结果
pred = np.array(pred.data.cpu()[0])[0]
# 处理结果
pred[pred >= 0.5] = 255
pred[pred < 0.5] = 0
# 保存图片
cv2.imwrite(save_res_path, pred)
运行完后,在data/test目录下,看到预测结果:
5. U-Net与FCN的关系
U-Net是一种基于FCN的图像分割模型,它使用了FCN的思路和结构,并在其基础上添加了编码器-解码器结构,以提高分割的准确性和效率。
U-Net与FCN不同的是,U-Net的网络结构为对称的"U"型。具体来说,U-Net的下采样过程与FCN类似,特征图经过两次卷积操作以及一次池化操作,尺寸缩小为原来的一半,维度扩张为两倍。但是U-Net的上采样过程与FCN不同,相比起FCN的直接从底层特征图进行32倍或16倍等的上采样,U-Net的上采样过程与下采样过程对称,特征图经过两次卷积以及一次转置卷积操作后,尺寸扩大到原来两倍,维度为原来一半。
同时,U-Net的跳跃连接也与FCN不同。U-Net的跳跃连接将下采样时同尺寸的特征图裁剪后,与上采样时的特征图在维度上融合在一起,并作为解码器的下一个模块的输入。而不是FCN的将特征图相加后,直接上采样至原图的尺寸。
6. U-Net与生成模型
参考:深入浅出完整解析Stable Diffusion中U-Net的前世今生与核心知识 - 知乎
Unet为何可以作为生成模型(Diffusion)的backbone?/为何在AIGC时代,U-Net成为了Stable Diffusion这个划时代模型的关键结构?
这主要归结为U-net的四个特质:
- U-Net中Encoder模块的压缩特质。作为Encoder模块最初的应用,输入的图像经过下采样,抽取出比原图小得多的高维特征,相当于进行了压缩操作。这和Stable diffusion的latent逻辑不谋而合。
- U-Net中Decoder模块的去噪特质,作为Decoder模块最初的应用。
- U-Net整体结构上的简洁、稳定和高效,使得其在Stable Diffusion中能够从容的迭代去噪声,能够撑起Stable Diffusion的整个图像生成逻辑。
- Encoder-Decoder结构的强兼容性,让U-Net不管是在分割领域,还是在生成领域,都能和Transformer等新生代模型的从容融合。
6.1 Stable Diffusion中的U-Net
Stable Diffusion中的U-Net包含约860M的参数,在float32的精度下,约占3.4G的存储空间。
在上图中可以看到,U-Net是Stable Diffusion中的核心模块。U-Net主要在"扩散"循环中对高斯噪声矩阵进行迭代降噪,并且每次预测的噪声都由文本和timesteps进行引导,将预测的噪声在随机高斯噪声矩阵上去除,最终将随机高斯噪声矩阵转换成图片的隐特征。
在U-Net执行"扩散"循环的过程中,Content Embedding始终保持不变,而Time Embedding每次都会发生变化。每次U-Net预测的噪声都在Latent特征中减去,并且将迭代后的Latent作为U-Net的新输入。
6.2 U-Net在生成模型中的改进
Stable Diffusion中的U-Net,在Encoder-Decoder结构的基础上,增加了Time Embedding模块,Spatial Transformer(Cross Attention)模块和self-attention模块。
(1)Time Embedding模块
首先,什么是Time Embedding呢?
Time Embedding(时间嵌入)是一种在时间序列数据中用于表示时间信息的技术。时间序列数据是指按照时间顺序排列的数据,例如股票价格、天气数据、传感器数据等。时间嵌入的目的是将时间作为一个特征进行编码,以便在深度学习模型中更好地学习时间相关性特征。
Time Embedding的基本思想是将时间信息映射到一个连续的向量空间,使得时间之间的关系可以被模型学习和利用。
Time Embedding的使用可以帮助深度学习模型更好地理解时间相关性,从而提高模型的性能。比如在Stable Diffusion中,将Time Embedding引入U-Net中,帮助其在扩散过程中从容预测噪声。
Stable Diffusion需要迭代多次对噪音进行逐步预测,使用Time Embedding就可以将time编码到网络中,从而在每一次迭代中让U-Net更加合适的噪声预测。
讲完Time Embedding的核心基础知识,我们再解析一下Stable Diffusion中U-Net的Time Embeddings模块是如何构造的:
可以看到,Time Embeddings模块 + Encoder模块中原本的卷积层,组成了一个Residual Block结构。它包含两个卷积层,一个Time Embedding和一个skip Connection。而这里的全连接层将Time Embedding变换为和Latent Feature一样的维度。最后通过两者的加和完成time的编码。
(2)Spatial Transformer(Cross Attention)模块
在Stable Diffusion中,使用了Spatial Transformer来表示类Cross Attention模块。
Cross Attention是一种多头注意力机制,它可以在两个不同的输入序列之间建立关联,并且可以将其中一个输入序列的信息传递给另一个输入序列。
在计算机视觉中,Cross Attention可以用于将图像与文本之间的关联建立。例如,在图像字幕生成任务中,Cross Attention可以将图像中的区域与生成的文字之间建立关联,以便生成更准确的描述。
Stable Diffusion中使用Cross Attention模块控制文本信息和图像信息的融合交互,通俗来说,控制U-Net把噪声矩阵的某一块与文本里的特定信息相对应。
讲完Cross Attention的核心基础知识,我们再解析一下Stable Diffusion中U-Net的Cross Attention模块是如何构造的:
可以看到,Latent Feature和Context Embedding作为输入,将两者进行Cross Attenetion操作,将图像信息和文本信息进行了融合,整体上是一个经典的Transformer流程。
6.3 U-Net在生成模型中训练目标的变化
在Stable Diffusion中,U-Net在不断的训练过程中主要学会了一件事,那就是去噪!去噪!去噪!
想要让U-Net能够高效去噪,并获得图像的隐特征,我们就要让U-Net知道什么是噪声数据。
于是我们在训练的预处理过程中,向训练集有策略地加入噪声。
这个加噪策略主要包括设定不同级别的噪声,比如说0-100共101个强度的噪声,在每个Batch中,随机加入1-n个101强度序列中的噪声,生成噪声图片。
加噪+噪声强度+加噪次数+原数据集,构成了Stable Diffusion中U-Net训练数据的基石。
有了数据预处理的大逻辑,在训练过程中,U-Net需要在已知噪声强度的条件下,不断学习提升从噪声图片中计算出噪声的能力。
需要注意的是,Stable Diffusion中的U-Net并不直接输出无噪声的原数据,而是去预测原数据上所加过的噪声。
Stable Diffusion中U-Net的训练过程
如上图所示,Stable Diffusion中U-Net的训练一共分四步:
- 从训练集中选取一张加噪过的图片和噪声强度,比如上图的加噪街道图和噪声强度3。
- 将数据输入U-Nnet,并且预测噪声矩阵。
- 将预测的噪声矩阵和实际噪声矩阵(Label)进行误差的计算。
- 通过反向传播更新U-Net的参数。
在推理阶段中,我们将U-Net预测的噪声不断在噪声图片中减去就能恢复出图片的隐特征了。
当我们完成了U-Net在Stable Diffusion中的训练,如果我们再将噪声强度和噪声图输入U-Net,那么U-Net就能较准确地预测出有加在原素材上的噪声:
Stable Diffusion中U-Net预测噪声
有了U-Net对噪声的强预测能力,在Stable Diffusion的推理过程中,我们就可以使用U-Net循环预测噪声,并在噪声图上逐步减去这些被预测出来的噪声,从而得到一个我们想要的高质量的图像隐特征,去噪流程如下图所示:
总结
本次周报深入探讨了U-Net网络在医学图像分割和生成模型中的应用。通过对U-Net结构的逐层解析,结合数据增强和模型训练,验证了其在语义分割任务中的卓越表现。实验表明,U-Net凭借其特有的跳跃连接机制,能够在上下文信息捕获与精确定位之间取得良好的平衡。此外,U-Net在Stable Diffusion中的应用进一步彰显了其在生成模型领域的重要性,为未来的研究和优化提供了方向。