一、模型架构
一图说明:


主要由编码器和解码器的部分组成,每个阶段相对称,看起来像"U",所以就叫他U-Net
"编码器"抽象语义,但失去了空间细节;"解码器"还原图像时,补回细节得到精确的分割效果
1.encoder(编码器):从输入图像中提取特征

编码器=卷积+下采样
- 卷积(横的箭头):3X3的卷积层提取特征、每个卷积层后,ReLU激活函数被元素的应用到每个特征
- 下采样(向下的箭头):每个阶段后,2X2的最大池化操作对特征进行下采样(步长为2的操作,相当于在图像上滚动一个不重叠的窗口,并选择最大值,降低了特征的空间维度,为了补偿这一点,每次下采样操作后通道数都会翻倍)
2.decoder(解码器):对中间特征进行上采样并产生最终输出
解码器的任务是:把语义信息强但模糊的位置,重新变成清晰的、原始尺寸的分割图。

解码器=上采样+跳跃连接+卷积融合
- 上采样(向上的箭头):对当前特征集进行上采样,然后应用2X2卷积层将通道数减半,上采样操作用于恢复在编码阶段丢失的特征的空间分辨率
- 跳跃连接: (在每个采样的相对应阶段)简单复制编码器对称的特征,将编码器里浅层特征和解码器的深层特征进行融合concat
-
- 编码:包含更多语义信息,这个东西是自行车
- 解码:包含更多空间信息,这些是自行车所在的像素
- 结合在一起就可以得到高精准的分割
- 卷积融合(横的箭头):3X3的卷积层提取特征、每个卷积层后,ReLU激活函数被元素的应用到每个特征上
-
- concat拼接后,通道数变多,不同来源的信息"各说各话",需要多个 3×3的卷积层卷积一下统一一下语言(也就是跨通道的信息融合)。
3.瓶颈


瓶颈(网络中介特征的桥梁) 是连接编码器和解码器的"最低分辨率"部分。虽然空间分辨率最小,但特征表达最丰富(感受野最大)。通常这里会使用两次卷积而不再做下采样,因为已经是最低层级了,要增强表达力,不再减小尺寸。
- 首先我们对特征进行降采样(下)------然后将他们通过可识别的卷积层(横)------最后将他们再次上采样到瓶颈期的相应分辨率(上)
4.一张图片经过Unet的经历
编码器过程:



卷积+下采样的这个过程一致重复,直到达到瓶颈部分
瓶颈过程:

解码器过程:




5.文字理解------每次学完重新品读理解的更到位
我对分割的理解:想要知道当前像素属于哪一个类别。传统方法需要由它周围的一小片区域推出来。但是现在只要这个像素的抽象层级够高,在原始输入图像中的感受野够大。那么也是可以得到这个像素属于哪一个类别,图像分类网络不断的卷积,相当于一个信息不断抽象,感受野不断扩大的过程,卷积到一定层级,这一层所包含的信息已经足够用来分类了。但是这个过程损失了图片的空间分辨率信息,所以抽象到一定层级后又必须进行图像尺寸的还原。
还原的过程中,由于分割物体还涉及到切分物体的边缘,以及上采样过程中,图像尺寸虽然增大了,但是有些像素是填充而来的并不是原来真实的信息,所以这个时候就要将浅层特征进行融合,一方面是浅层特征包含了边缘信息,一方面是浅层特征可以补充由于原先编码网络下采样损失的信息,这样上采样扩大尺寸后的图,信息才够完全。然后解码器上采样之后还要进行卷积的学习,是因为浅层特征和深层特征融合是concat的方式,需要用卷积进行跨通道交流。另一方面考虑,是继续进行学习。使得性能更好,所以上采样后还需要进行两次卷积。
二、模型源码讲解
文件结构:
less
├── src: 搭建U-Net模型代码
├── train_utils: 训练、验证以及多GPU训练相关模块
├── my_dataset.py: 自定义dataset用于读取DRIVE数据集(视网膜血管分割)
├── train.py: 以单GPU为例进行训练
├── train_multi_GPU.py: 针对使用多GPU的用户使用
├── predict.py: 简易的预测脚本,使用训练好的权重进行预测测试
└── compute_mean_std.py: 统计数据集各通道的均值和标准差

1.模型搭建
(1)卷积模块

ini
# 定义一个双卷积模块:两个连续的Conv2d + BN + ReLU
class DoubleConv(nn.Sequential):
def __init__(self, in_channels, out_channels, mid_channels=None):
if mid_channels is None:
mid_channels = out_channels
super(DoubleConv, self).__init__(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), # 第一个卷积
nn.BatchNorm2d(mid_channels), # 批归一化
nn.ReLU(inplace=True), # 激活函数
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), # 第二个卷积
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
(2)下采样模块

ruby
# 下采样模块:MaxPool + DoubleConv
class Down(nn.Sequential):
def __init__(self, in_channels, out_channels):
super(Down, self).__init__(
nn.MaxPool2d(2, stride=2), # 最大池化层,尺寸缩小一半
DoubleConv(in_channels, out_channels) # 双卷积处理特征
)
(3)上采样模块

ini
# 上采样模块:支持双线性插值或反卷积(转置卷积)
class Up(nn.Module):
def __init__(self, in_channels, out_channels, bilinear=True):
super(Up, self).__init__()
if bilinear:
# 使用双线性插值上采样
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
# 上采样后通道数减半(因为concat后通道会变成2倍)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
# 使用转置卷积上采样
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
x1 = self.up(x1) # 上采样
# 计算上采样后与跳跃连接 特征图的尺寸差异
diff_y = x2.size()[2] - x1.size()[2]
diff_x = x2.size()[3] - x1.size()[3]
# 使用padding确保尺寸一致,方便拼接
x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
diff_y // 2, diff_y - diff_y // 2])
x = torch.cat([x2, x1], dim=1) # 拼接跳跃连接特征图和当前特征图(在通道维度)
x = self.conv(x) # 拼接后的特征图再通过DoubleConv
return x
(3)输出模块

(4)UNet主体结构
ini
# UNet 主体结构
class UNet(nn.Module):
def __init__(self,
in_channels: int = 1, # 输入通道数(如灰度图是1,RGB图是3)
num_classes: int = 2, # 输出类别数
bilinear: bool = True, # 是否使用双线性插值上采样
base_c: int = 64): # 初始通道数
super(UNet, self).__init__()
self.in_channels = in_channels
self.num_classes = num_classes
self.bilinear = bilinear
# 编码器部分(下采样)
self.in_conv = DoubleConv(in_channels, base_c) # 第一层双卷积
self.down1 = Down(base_c, base_c * 2) # 第二层:下采样后通道加倍
self.down2 = Down(base_c * 2, base_c * 4)
self.down3 = Down(base_c * 4, base_c * 8)
factor = 2 if bilinear else 1 # 如果使用插值,则最后一层通道减半,匹配上采样
self.down4 = Down(base_c * 8, base_c * 16 // factor)
# 解码器部分(上采样)
self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear)
self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear)
self.up4 = Up(base_c * 2, base_c, bilinear)
# 最后一层输出
self.out_conv = OutConv(base_c, num_classes)
(5)前向传播模块
ini
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
x1 = self.in_conv(x) # 编码器第1层输出
x2 = self.down1(x1) # 编码器第2层
x3 = self.down2(x2) # 编码器第3层
x4 = self.down3(x3) # 编码器第4层
x5 = self.down4(x4) # 编码器第5层(最深处)
x = self.up1(x5, x4) # 解码器第1层(与x4跳跃连接)
x = self.up2(x, x3) # 解码器第2层
x = self.up3(x, x2) # 解码器第3层
x = self.up4(x, x1) # 解码器第4层
logits = self.out_conv(x) # 输出层,得到最终分割图
return {"out": logits} # 返回字典形式输出,便于扩展其他输出(比如边界、特征图等)
区分一下Unet网络结构与向前传播
我们写 class UNet(nn.Module):
是在 定义模型的"框架"结构 ,而 def forward(self, x):
是在 告诉这个框架"数据怎么走一遍" 。
它俩缺一不可,合起来才是一个完整的神经网络模型。
类比理解:
想象你在设计一座流水线工厂:
角色 | 作用 |
---|---|
__init__() 方法(也就是 UNet 类里的各种层定义) |
👉 造好了流水线上的各个"机器"组件,比如卷积、上采样、下采样。只是"放好了" |
forward() 方法 |
👉 告诉工人:"原材料从哪里进来,先过哪台机器,再去哪,最后从哪出去" |
光有机器没有路线不行,路线没有机器也不行。UNet 是"框架",forward 是"运行路线"。
2.调整自己的训练集
训练自己的数据集,主要是调整my_dataset.py
官方的DRIVE数据集结构:

python
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
class DriveDataset(Dataset):
def __init__(self, root: str, train: bool, transforms=None):
super(DriveDataset, self).__init__()
# 判断是训练集还是测试集,设置相应的标志
self.flag = "training" if train else "test"
# 生成数据集路径,data_root是训练集或测试集的根目录
data_root = os.path.join(root, "DRIVE", self.flag)
# 检查路径是否存在,如果不存在就报错
assert os.path.exists(data_root), f"path '{data_root}' does not exists."
self.transforms = transforms
# 获取所有.tif格式的图像文件名
img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")]
# 生成图像路径的列表
self.img_list = [os.path.join(data_root, "images", i) for i in img_names]
# 为每个图像生成对应的人工标注路径
self.manual = [os.path.join(data_root, "1st_manual", i.split("_")[0] + "_manual1.gif")
for i in img_names]
# 检查每个人工标注文件是否存在
for i in self.manual:
if os.path.exists(i) is False:
raise FileNotFoundError(f"file {i} does not exists.")
# 为每个图像生成对应的ROI掩码路径
self.roi_mask = [os.path.join(data_root, "mask", i.split("_")[0] + f"_{self.flag}_mask.gif")
for i in img_names]
# 检查每个ROI掩码文件是否存在
for i in self.roi_mask:
if os.path.exists(i) is False:
raise FileNotFoundError(f"file {i} does not exists.")
def __getitem__(self, idx):
"""
这个方法返回数据集中某个特定位置(idx)的图像和对应的掩码。
1. 读取图像、人工标注图像和ROI掩码图像。
2. 将人工标注转化为[0, 1]范围的数值,ROI掩码会反转(因为它可能是黑白反转的)。
3. 合并人工标注和ROI掩码,得到最终的掩码。
4. 如果有转换函数(transforms),就对图像和掩码一起进行处理。
"""
# 打开原始图像并转换为RGB格式
img = Image.open(self.img_list[idx]).convert('RGB')
# 打开人工标注图像并转换为灰度模式
manual = Image.open(self.manual[idx]).convert('L')
# 将人工标注的值转换为[0, 1]之间
manual = np.array(manual) / 255
# 打开ROI掩码图像并转换为灰度模式
roi_mask = Image.open(self.roi_mask[idx]).convert('L')
# 反转ROI掩码图像,将黑色区域变成白色区域,白色区域变成黑色
roi_mask = 255 - np.array(roi_mask)
# 合并人工标注和ROI掩码,得到最终的掩码,超出[0, 255]范围的部分会被裁剪
mask = np.clip(manual + roi_mask, a_min=0, a_max=255)
# 将NumPy格式的掩码转换回PIL格式,因为转换函数(transforms)通常是针对PIL格式的
mask = Image.fromarray(mask)
# 如果提供了转换函数,就应用它们(例如,数据增强等)
if self.transforms is not None:
img, mask = self.transforms(img, mask)
# 返回处理过的图像和掩码
return img, mask
def __len__(self):
# 返回数据集中的样本数量,即图像的数量
return len(self.img_list)
@staticmethod
def collate_fn(batch):
"""
这个函数用于将多个样本合并成一个批次(batch)。
- batch:是由多个图像和目标掩码组成的列表
- 使用`cat_list`函数将图像和掩码拼接成一个统一大小的批次
"""
# 将批次中的所有图像和目标掩码分别取出
images, targets = list(zip(*batch))
# 将图像列表按最大尺寸拼接成一个批次
batched_imgs = cat_list(images, fill_value=0)
# 将掩码列表按最大尺寸拼接成一个批次
batched_targets = cat_list(targets, fill_value=255)
# 返回拼接好的图像和掩码
return batched_imgs, batched_targets
def cat_list(images, fill_value=0):
"""
这个函数将多个图像拼接成一个批次。
所有图像会被填充成相同的尺寸,填充部分用`fill_value`来填充。
- images:要拼接的图像列表
- fill_value:填充区域的值,默认是0(黑色)
"""
# 找到所有图像中最大的尺寸(宽和高)
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
# 创建一个形状为(batch_size, max_height, max_width)的零矩阵,用来存放批次中的图像
batch_shape = (len(images),) + max_size
# 创建一个全是`fill_value`的矩阵,作为初始的空白批次
batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
# 将每个图像复制到对应的位置,保证原图大小不变,超出部分用填充值填充
for img, pad_img in zip(images, batched_imgs):
pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
# 返回拼接好的批次图像
return batched_imgs
如果没有mask的话,就把对应mask的代码去掉就好,下面是没有蒙版的代码
python
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset
class DriveDataset(Dataset):
def __init__(self, root: str, train: bool, transforms=None):
super(DriveDataset, self).__init__()
# 根据是否是训练模式,设置数据集的标志为"training"或"test"
self.flag = "training" if train else "test"
# 生成数据集路径,data_root是训练集或测试集的根目录
data_root = os.path.join(root, "DRIVE", self.flag)
# 检查路径是否存在,如果不存在就报错
assert os.path.exists(data_root), f"path '{data_root}' does not exists."
self.transforms = transforms
# 获取所有.tif格式的图像文件名
img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")]
# 生成图像路径的列表
self.img_list = [os.path.join(data_root, "images", i) for i in img_names]
# 构建对应的人工标注文件路径列表
self.manual = [os.path.join(data_root, "1st_manual", i.split("_")[0] + "_manual1.gif")
for i in img_names]
# 检查人工标注文件是否存在
for i in self.manual:
if os.path.exists(i) is False:
raise FileNotFoundError(f"file {i} does not exists.")
def __getitem__(self, idx):
"""
这个方法返回数据集中某个特定位置(idx)的图像和对应的人工标注。
1. 读取图像文件。
2. 读取人工标注图像并转换为[0, 1]范围的标注。
3. 如果提供了转换函数(transforms),则对图像和标注进行相应处理。
"""
# 打开原始图像并转换为RGB格式
img = Image.open(self.img_list[idx]).convert('RGB')
# 打开人工标注图像并转换为灰度模式
manual = Image.open(self.manual[idx]).convert('L')
# 将人工标注的值转换为[0, 1]之间
manual = np.array(manual) / 255
# 将人工标注转换回PIL图像
manual = Image.fromarray((manual * 255).astype(np.uint8))
# 如果提供了数据增强或变换函数(transforms),则对图像和标注进行处理
if self.transforms is not None:
img, manual = self.transforms(img, manual)
# 返回图像和人工标注
return img, manual
def __len__(self):
# 返回数据集的大小,即图像的数量
return len(self.img_list)
@staticmethod
def collate_fn(batch):
"""
自定义批处理函数,用于将多个样本合并为一个批次。
- batch: 包含多个图像和人工标注的样本列表
- 使用cat_list函数将图像和标注批量处理,填充成相同的大小
"""
# 将批次中的所有图像和人工标注分开
images, targets = list(zip(*batch))
# 将图像列表按最大尺寸合并成一个批次
batched_imgs = cat_list(images, fill_value=0)
# 将标注列表按最大尺寸合并成一个批次
batched_targets = cat_list(targets, fill_value=255)
return batched_imgs, batched_targets
def cat_list(images, fill_value=0):
"""
这个函数将多个图像拼接成一个批次。
所有图像会被填充成相同的尺寸,填充部分用`fill_value`来填充。
- images:要拼接的图像列表
- fill_value:填充区域的值,默认是0(黑色)
"""
# 找到所有图像中最大的尺寸(宽和高)
max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
# 创建一个形状为(batch_size, max_height, max_width)的零矩阵,用来存放批次中的图像
batch_shape = (len(images),) + max_size
# 创建一个全是`fill_value`的矩阵,作为初始的空白批次
batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
# 将每个图像复制到对应的位置,保证原图大小不变,超出部分用填充值填充
for img, pad_img in zip(images, batched_imgs):
pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
# 返回拼接好的批次图像
return batched_imgs
3.dice损失计算

(1)什么是Dice?

(2) dice_coeff
是怎么计算的?
dice_coeff
计算的是一张图像的 Dice 系数:
- 先把预测的结果和真实标签展开成一维数组。
- 计算它们的 重叠区域(就是预测对的部分)。
- 然后算出 总区域(预测区域加上真实区域的大小)。
- 最后,计算出 Dice 系数:重叠区域越大,Dice 系数越高。
如果有某些部分需要忽略(比如标记为 ignore_index
的区域),它会自动跳过这些部分,确保不计算这些区域。
(3)multiclass_dice_coeff
是怎么计算的?
如果你在做多类分割(比如分割不同类型的物体),它会对每一类计算一个 Dice 系数,然后算出这些类的平均值。比如,假设你分割了人和车,它会分别计算人和车的 Dice 系数,最后取平均。
(4)dice_loss
是怎么计算的?
dice_loss
计算的是 Dice Loss ,就是模型优化的目标,值越小越好。这个损失值其实就是 1 - Dice 系数:
- Dice 系数越大,模型预测越准确,损失越小。
- 因为我们想让损失值最小化,所以目标是让 Dice 系数最大化。
计算函数如下:
python
import torch
import torch.nn as nn
def build_target(target: torch.Tensor, num_classes: int = 2, ignore_index: int = -100):
"""构建Dice系数需要的目标标签"""
dice_target = target.clone() # 克隆目标标签,避免直接修改原数据
if ignore_index >= 0:
# 如果存在ignore_index,找到目标标签中所有等于ignore_index的部分
ignore_mask = torch.eq(target, ignore_index)
dice_target[ignore_mask] = 0 # 将ignore_index部分设置为0
# 将目标标签进行one-hot编码,转换为[N, H, W, C]的格式
dice_target = nn.functional.one_hot(dice_target, num_classes).float()
dice_target[ignore_mask] = ignore_index # 恢复ignore_index的区域
else:
# 如果没有ignore_index,直接进行one-hot编码
dice_target = nn.functional.one_hot(dice_target, num_classes).float()
# 调整维度顺序,从[N, H, W, C]变为[N, C, H, W]
return dice_target.permute(0, 3, 1, 2)
def dice_coeff(x: torch.Tensor, target: torch.Tensor, ignore_index: int = -100, epsilon=1e-6):
"""计算一个batch中所有图片某个类别的Dice系数"""
d = 0. # 初始化Dice系数
batch_size = x.shape[0] # 获取batch的大小
for i in range(batch_size):
# 将第i张图像和目标标签展平为一维数组
x_i = x[i].reshape(-1)
t_i = target[i].reshape(-1)
if ignore_index >= 0:
# 如果有ignore_index,找到目标标签中不为ignore_index的区域
roi_mask = torch.ne(t_i, ignore_index)
x_i = x_i[roi_mask] # 只保留不为ignore_index的部分
t_i = t_i[roi_mask] # 目标标签也做同样的处理
# 计算交集(预测值和目标值的点积)
inter = torch.dot(x_i, t_i)
# 计算并集(预测区域总和 + 目标区域总和)
sets_sum = torch.sum(x_i) + torch.sum(t_i)
if sets_sum == 0:
# 如果并集为0(预测和目标都没有预测到任何目标),就直接返回2 * 交集
sets_sum = 2 * inter
# 计算Dice系数,并且避免除0错误,加入一个小常数epsilon
d += (2 * inter + epsilon) / (sets_sum + epsilon)
# 返回batch内所有图片的平均Dice系数
return d / batch_size
def multiclass_dice_coeff(x: torch.Tensor, target: torch.Tensor, ignore_index: int = -100, epsilon=1e-6):
"""计算所有类别的Dice系数平均值"""
dice = 0. # 初始化Dice系数
for channel in range(x.shape[1]): # 对每一个类别(通道)计算Dice系数
dice += dice_coeff(x[:, channel, ...], target[:, channel, ...], ignore_index, epsilon)
# 返回所有类别的平均Dice系数
return dice / x.shape[1]
def dice_loss(x: torch.Tensor, target: torch.Tensor, multiclass: bool = False, ignore_index: int = -100):
"""计算Dice损失(目标是最小化该损失)"""
x = nn.functional.softmax(x, dim=1) # 对模型输出进行softmax,转化为概率分布
fn = multiclass_dice_coeff if multiclass else dice_coeff # 根据是否多分类选择对应的函数
# 计算Dice损失,目标是最小化,所以是1减去Dice系数
return 1 - fn(x, target, ignore_index=ignore_index)
损失计算过程:
1.将真实的图像标签通过one-hot编码转换成多个通道,医学影像中我们只有两个类别:背景和前景,2个类别,2 个类别的标签会变成一个 2 通道的图像,就是说每个像素会有两个数值,一个代表类别 0 的概率,另一个代表类别 1 的概率。
学习器现在就可以看到这样:

为什么要用这种方式?
(1)多通道表示 :通过 One-hot 编码,每个类别都有一个独立的通道,可以清晰地区分每个类别的区域。这个特性在深度学习模型中很有用,尤其是对于 图像分割任务,我们需要对每个像素进行分类,判断它属于哪个类别。
(2)模型训练:在训练神经网络时,模型会根据这些 One-hot 编码的标签计算损失(比如 Dice 损失或者交叉熵损失)。这有助于网络理解和学习每个像素的类别。
2.接下来我们就计算单个类别的Dice系数, 对于每一张图像,我们把 预测结果 和 真实标签 展平(变成一维),然后比较它们的重叠部分。
然后,我们计算它们的 交集 和 并集,然后用公式算出 Dice 系数

- 计算 Dice 损失 : dice_loss
Dice 损失 就是 1 - Dice 系数: 也就是说,Dice 系数越高,损失越小,模型越好。
4.模型是如何优化调参的?
先看PyTorch 模型训练四大阶段:
阶段 | PyTorch 中做了什么 | 举例说明 |
---|---|---|
1. 初始化模型 | 定义模型结构:构造网络层,比如卷积、池化等(__init__() ) |
就是你写的 UNet(nn.Module) ,在 __init__() 里搭建各个模块 |
2. 前向传播 | 执行 forward() 方法,输入数据通过网络流动,得到输出结果(比如预测图像分割) |
out = model(input) 就是触发了 forward() |
3. 反向传播 | 自动计算损失函数对模型参数的梯度(loss.backward() ) |
loss = criterion(output, label) 后 loss.backward() |
4. 更新参数 | 用优化器更新参数,让模型更准确(optimizer.step() ) |
比如使用 Adam 优化器更新参数:optimizer.step() |
训练epoch: 执行一个训练周期的训练过程(一个epoch意味着训练数据集中的每个样本都被用来训练模型一次。 )
ini
def train_one_epoch(model, optimizer, data_loader, device, epoch, num_classes,
lr_scheduler, print_freq=10, scaler=None):
model.train() # 设置模型为训练模式
metric_logger = utils.MetricLogger(delimiter=" ") # 用于记录日志
metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) # 记录学习率
header = 'Epoch: [{}]'.format(epoch)
# 如果是二分类,设置交叉熵中的损失权重(背景和前景的权重不同)
if num_classes == 2:
loss_weight = torch.as_tensor([1.0, 2.0], device=device) # 背景权重1.0,前景权重2.0
else:
loss_weight = None
for image, target in metric_logger.log_every(data_loader, print_freq, header):
image, target = image.to(device), target.to(device)
# 使用自动混合精度(如果有的话)进行前向传播
with torch.cuda.amp.autocast(enabled=scaler is not None):
output = model(image) # 前向传播
# 计算损失
loss = criterion(output, target, loss_weight, num_classes=num_classes, ignore_index=255)
optimizer.zero_grad() # 清空梯度
if scaler is not None:
# 使用混合精度训练时,进行反向传播和参数更新
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
else:
# 常规反向传播
loss.backward()
optimizer.step()
lr_scheduler.step() # 更新学习率
# 获取当前学习率
lr = optimizer.param_groups[0]["lr"]
# 更新日志
metric_logger.update(loss=loss.item(), lr=lr)
return metric_logger.meters["loss"].global_avg, lr # 返回平均损失和当前学习率
步骤:
- 设置模型为训练模式(
model.train()
)。 - 设置损失权重(如果是二分类,前景的损失权重设置为2,背景为1)。
- 遍历训练数据,执行前向传播计算输出。
- 使用混合精度训练时,进行反向传播并更新参数。
- 如果没有混合精度,则进行常规反向传播。
- 更新学习率 (
lr_scheduler.step()
)。 - 返回平均损失和当前学习率。
首先,我们先初始化一个学习率(步长),然后进行前向传播(模型接收输入数据计算出预测结果的过程),计算出损失率(1-dice),再进行反向传播( 根据损失函数计算每个参数的梯度[每个参数对损失的影响]),然后进行梯度下降,利用刚刚的梯度来更新模型的参数,每次更新的幅度就是学习率的大小,在epoch训练的过程中不断调整学习率,最终找到最优的参数,得到一个收敛的模型。
再梳理一下:
1. 初始化学习率(步长):
-
- 在训练开始之前,我们设定一个初始的学习率(通常是一个较小的数值),这个学习率决定了优化过程中每次参数更新的步伐。
2. 前向传播:
-
- 模型接收输入数据(如图片、文本等),然后进行计算,得出预测结果。这是模型的前向传播过程。
3. 计算损失率:
-
- 将模型的预测结果和真实标签(ground truth)进行比较,计算出模型的误差,这个误差用一个损失函数来量化。比如,Dice损失(1 - Dice系数)就可以用来衡量模型预测的分割结果和真实标签之间的差异。
4. 反向传播:
-
- 根据损失值,通过反向传播算法计算出每个参数的梯度。梯度告诉我们每个参数对损失的贡献有多大,即它们需要调整的方向和幅度。
5. 梯度下降更新参数:
-
- 使用计算出来的梯度,按照学习率的大小来更新模型的参数。梯度下降就是根据"梯度"来调整参数,学习率(步长)控制每次调整的大小。
- 更新公式:
参数 = 参数 - 学习率 * 梯度
这里,学习率控制着每次更新的步伐,如果学习率太大,可能会跨过最优解;如果太小,收敛速度会变慢。
6. 调整学习率:
-
- 在训练过程中,学习率可以根据预设的策略进行调整。例如,训练到一定epoch后,逐渐减小学习率,让模型能更精细地调整参数,避免错过最优解。
- 这种调整可以通过学习率调度器(如学习率衰减)来实现。
7. 训练过程:
-
- 在多个epoch的训练过程中,模型会不断调整参数,每次的损失都会变小,最终逐渐收敛到一个最优的参数解。
8. 最终得到收敛的模型:
-
- 经过多次的前向传播、损失计算、反向传播和梯度下降后,模型会找到一组最优参数,这样模型在新数据上的预测效果就最好了。
总结:
- 训练过程的核心就是通过反向传播计算梯度,并通过梯度下降来更新模型的参数,学习率决定了每次更新的步伐。
- 随着训练的进行,通过不断调整学习率,可以让模型更好地收敛到一个最优解