U-Net网络+代码实操【保姆级教程理解一文全搞懂】

一、模型架构

一图说明:

主要由编码器和解码器的部分组成,每个阶段相对称,看起来像"U",所以就叫他U-Net

"编码器"抽象语义,但失去了空间细节;"解码器"还原图像时,补回细节得到精确的分割效果

1.encoder(编码器):从输入图像中提取特征

编码器=卷积+下采样

  • 卷积(横的箭头):3X3的卷积层提取特征、每个卷积层后,ReLU激活函数被元素的应用到每个特征
  • 下采样(向下的箭头):每个阶段后,2X2的最大池化操作对特征进行下采样(步长为2的操作,相当于在图像上滚动一个不重叠的窗口,并选择最大值,降低了特征的空间维度,为了补偿这一点,每次下采样操作后通道数都会翻倍)

2.decoder(解码器):对中间特征进行上采样并产生最终输出

解码器的任务是:把语义信息强但模糊的位置,重新变成清晰的、原始尺寸的分割图

解码器=上采样+跳跃连接+卷积融合

  • 上采样(向上的箭头):对当前特征集进行上采样,然后应用2X2卷积层将通道数减半,上采样操作用于恢复在编码阶段丢失的特征的空间分辨率
  • 跳跃连接: (在每个采样的相对应阶段)简单复制编码器对称的特征,将编码器里浅层特征和解码器的深层特征进行融合concat
    • 编码:包含更多语义信息,这个东西是自行车
    • 解码:包含更多空间信息,这些是自行车所在的像素
    • 结合在一起就可以得到高精准的分割
  • 卷积融合(横的箭头):3X3的卷积层提取特征、每个卷积层后,ReLU激活函数被元素的应用到每个特征上
    • concat拼接后,通道数变多,不同来源的信息"各说各话",需要多个 3×3的卷积层卷积一下统一一下语言(也就是跨通道的信息融合)。

3.瓶颈

瓶颈(网络中介特征的桥梁) 是连接编码器和解码器的"最低分辨率"部分。虽然空间分辨率最小,但特征表达最丰富(感受野最大)。通常这里会使用两次卷积而不再做下采样,因为已经是最低层级了,要增强表达力,不再减小尺寸。

  • 首先我们对特征进行降采样(下)------然后将他们通过可识别的卷积层(横)------最后将他们再次上采样到瓶颈期的相应分辨率(上)

4.一张图片经过Unet的经历

编码器过程:

卷积+下采样的这个过程一致重复,直到达到瓶颈部分

瓶颈过程:

解码器过程:

5.文字理解------每次学完重新品读理解的更到位

我对分割的理解:想要知道当前像素属于哪一个类别。传统方法需要由它周围的一小片区域推出来。但是现在只要这个像素的抽象层级够高,在原始输入图像中的感受野够大。那么也是可以得到这个像素属于哪一个类别,图像分类网络不断的卷积,相当于一个信息不断抽象,感受野不断扩大的过程,卷积到一定层级,这一层所包含的信息已经足够用来分类了。但是这个过程损失了图片的空间分辨率信息,所以抽象到一定层级后又必须进行图像尺寸的还原。

还原的过程中,由于分割物体还涉及到切分物体的边缘,以及上采样过程中,图像尺寸虽然增大了,但是有些像素是填充而来的并不是原来真实的信息,所以这个时候就要将浅层特征进行融合,一方面是浅层特征包含了边缘信息,一方面是浅层特征可以补充由于原先编码网络下采样损失的信息,这样上采样扩大尺寸后的图,信息才够完全。然后解码器上采样之后还要进行卷积的学习,是因为浅层特征和深层特征融合是concat的方式,需要用卷积进行跨通道交流。另一方面考虑,是继续进行学习。使得性能更好,所以上采样后还需要进行两次卷积。

二、模型源码讲解

文件结构:

less 复制代码
  ├── src: 搭建U-Net模型代码
  ├── train_utils: 训练、验证以及多GPU训练相关模块
  ├── my_dataset.py: 自定义dataset用于读取DRIVE数据集(视网膜血管分割)
  ├── train.py: 以单GPU为例进行训练
  ├── train_multi_GPU.py: 针对使用多GPU的用户使用
  ├── predict.py: 简易的预测脚本,使用训练好的权重进行预测测试
  └── compute_mean_std.py: 统计数据集各通道的均值和标准差

1.模型搭建

(1)卷积模块

ini 复制代码
# 定义一个双卷积模块:两个连续的Conv2d + BN + ReLU
class DoubleConv(nn.Sequential):
    def __init__(self, in_channels, out_channels, mid_channels=None):
        if mid_channels is None:
            mid_channels = out_channels
        super(DoubleConv, self).__init__(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),  # 第一个卷积
            nn.BatchNorm2d(mid_channels),  # 批归一化
            nn.ReLU(inplace=True),  # 激活函数
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),  # 第二个卷积
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

(2)下采样模块

ruby 复制代码
# 下采样模块:MaxPool + DoubleConv
class Down(nn.Sequential):
    def __init__(self, in_channels, out_channels):
        super(Down, self).__init__(
            nn.MaxPool2d(2, stride=2),  # 最大池化层,尺寸缩小一半
            DoubleConv(in_channels, out_channels)  # 双卷积处理特征
        )

(3)上采样模块

ini 复制代码
# 上采样模块:支持双线性插值或反卷积(转置卷积)
class Up(nn.Module):
    def __init__(self, in_channels, out_channels, bilinear=True):
        super(Up, self).__init__()
        if bilinear:
            # 使用双线性插值上采样
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            # 上采样后通道数减半(因为concat后通道会变成2倍)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            # 使用转置卷积上采样
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        x1 = self.up(x1)  # 上采样
        # 计算上采样后与跳跃连接 特征图的尺寸差异
        diff_y = x2.size()[2] - x1.size()[2]
        diff_x = x2.size()[3] - x1.size()[3]

        # 使用padding确保尺寸一致,方便拼接
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])

        x = torch.cat([x2, x1], dim=1)  # 拼接跳跃连接特征图和当前特征图(在通道维度)
        x = self.conv(x)  # 拼接后的特征图再通过DoubleConv
        return x

(3)输出模块

(4)UNet主体结构

ini 复制代码
# UNet 主体结构
class UNet(nn.Module):
    def __init__(self,
                 in_channels: int = 1,  # 输入通道数(如灰度图是1,RGB图是3)
                 num_classes: int = 2,  # 输出类别数
                 bilinear: bool = True,  # 是否使用双线性插值上采样
                 base_c: int = 64):  # 初始通道数
        super(UNet, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes
        self.bilinear = bilinear

        # 编码器部分(下采样)
        self.in_conv = DoubleConv(in_channels, base_c)  # 第一层双卷积
        self.down1 = Down(base_c, base_c * 2)  # 第二层:下采样后通道加倍
        self.down2 = Down(base_c * 2, base_c * 4)
        self.down3 = Down(base_c * 4, base_c * 8)
        factor = 2 if bilinear else 1  # 如果使用插值,则最后一层通道减半,匹配上采样
        self.down4 = Down(base_c * 8, base_c * 16 // factor)

        # 解码器部分(上采样)
        self.up1 = Up(base_c * 16, base_c * 8 // factor, bilinear)
        self.up2 = Up(base_c * 8, base_c * 4 // factor, bilinear)
        self.up3 = Up(base_c * 4, base_c * 2 // factor, bilinear)
        self.up4 = Up(base_c * 2, base_c, bilinear)

        # 最后一层输出
        self.out_conv = OutConv(base_c, num_classes)

(5)前向传播模块

ini 复制代码
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        x1 = self.in_conv(x)   # 编码器第1层输出
        x2 = self.down1(x1)    # 编码器第2层
        x3 = self.down2(x2)    # 编码器第3层
        x4 = self.down3(x3)    # 编码器第4层
        x5 = self.down4(x4)    # 编码器第5层(最深处)

        x = self.up1(x5, x4)   # 解码器第1层(与x4跳跃连接)
        x = self.up2(x, x3)    # 解码器第2层
        x = self.up3(x, x2)    # 解码器第3层
        x = self.up4(x, x1)    # 解码器第4层
        logits = self.out_conv(x)  # 输出层,得到最终分割图

        return {"out": logits}  # 返回字典形式输出,便于扩展其他输出(比如边界、特征图等)

区分一下Unet网络结构与向前传播

我们写 class UNet(nn.Module): 是在 定义模型的"框架"结构 ,而 def forward(self, x): 是在 告诉这个框架"数据怎么走一遍"

它俩缺一不可,合起来才是一个完整的神经网络模型。

类比理解:

想象你在设计一座流水线工厂:

角色 作用
__init__()方法(也就是 UNet 类里的各种层定义) 👉 造好了流水线上的各个"机器"组件,比如卷积、上采样、下采样。只是"放好了"
forward()方法 👉 告诉工人:"原材料从哪里进来,先过哪台机器,再去哪,最后从哪出去"

光有机器没有路线不行,路线没有机器也不行。UNet 是"框架",forward 是"运行路线"。

2.调整自己的训练集

训练自己的数据集,主要是调整my_dataset.py

官方的DRIVE数据集结构:

python 复制代码
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset


class DriveDataset(Dataset):
    def __init__(self, root: str, train: bool, transforms=None):
        super(DriveDataset, self).__init__()
        # 判断是训练集还是测试集,设置相应的标志
        self.flag = "training" if train else "test"
        # 生成数据集路径,data_root是训练集或测试集的根目录
        data_root = os.path.join(root, "DRIVE", self.flag)
        # 检查路径是否存在,如果不存在就报错
        assert os.path.exists(data_root), f"path '{data_root}' does not exists."
        self.transforms = transforms
        # 获取所有.tif格式的图像文件名
        img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")]
        # 生成图像路径的列表
        self.img_list = [os.path.join(data_root, "images", i) for i in img_names]
        # 为每个图像生成对应的人工标注路径
        self.manual = [os.path.join(data_root, "1st_manual", i.split("_")[0] + "_manual1.gif")
                       for i in img_names]
        # 检查每个人工标注文件是否存在
        for i in self.manual:
            if os.path.exists(i) is False:
                raise FileNotFoundError(f"file {i} does not exists.")

        # 为每个图像生成对应的ROI掩码路径
        self.roi_mask = [os.path.join(data_root, "mask", i.split("_")[0] + f"_{self.flag}_mask.gif")
                         for i in img_names]
        # 检查每个ROI掩码文件是否存在
        for i in self.roi_mask:
            if os.path.exists(i) is False:
                raise FileNotFoundError(f"file {i} does not exists.")

    def __getitem__(self, idx):
        """
        这个方法返回数据集中某个特定位置(idx)的图像和对应的掩码。
        1. 读取图像、人工标注图像和ROI掩码图像。
        2. 将人工标注转化为[0, 1]范围的数值,ROI掩码会反转(因为它可能是黑白反转的)。
        3. 合并人工标注和ROI掩码,得到最终的掩码。
        4. 如果有转换函数(transforms),就对图像和掩码一起进行处理。
        """
        # 打开原始图像并转换为RGB格式
        img = Image.open(self.img_list[idx]).convert('RGB')
        # 打开人工标注图像并转换为灰度模式
        manual = Image.open(self.manual[idx]).convert('L')
        # 将人工标注的值转换为[0, 1]之间
        manual = np.array(manual) / 255
        # 打开ROI掩码图像并转换为灰度模式
        roi_mask = Image.open(self.roi_mask[idx]).convert('L')
        # 反转ROI掩码图像,将黑色区域变成白色区域,白色区域变成黑色
        roi_mask = 255 - np.array(roi_mask)
        # 合并人工标注和ROI掩码,得到最终的掩码,超出[0, 255]范围的部分会被裁剪
        mask = np.clip(manual + roi_mask, a_min=0, a_max=255)

        # 将NumPy格式的掩码转换回PIL格式,因为转换函数(transforms)通常是针对PIL格式的
        mask = Image.fromarray(mask)

        # 如果提供了转换函数,就应用它们(例如,数据增强等)
        if self.transforms is not None:
            img, mask = self.transforms(img, mask)

        # 返回处理过的图像和掩码
        return img, mask

    def __len__(self):
        # 返回数据集中的样本数量,即图像的数量
        return len(self.img_list)

    @staticmethod
    def collate_fn(batch):
        """
        这个函数用于将多个样本合并成一个批次(batch)。
        - batch:是由多个图像和目标掩码组成的列表
        - 使用`cat_list`函数将图像和掩码拼接成一个统一大小的批次
        """
        # 将批次中的所有图像和目标掩码分别取出
        images, targets = list(zip(*batch))
        # 将图像列表按最大尺寸拼接成一个批次
        batched_imgs = cat_list(images, fill_value=0)
        # 将掩码列表按最大尺寸拼接成一个批次
        batched_targets = cat_list(targets, fill_value=255)
        # 返回拼接好的图像和掩码
        return batched_imgs, batched_targets


def cat_list(images, fill_value=0):
    """
    这个函数将多个图像拼接成一个批次。
    所有图像会被填充成相同的尺寸,填充部分用`fill_value`来填充。
    - images:要拼接的图像列表
    - fill_value:填充区域的值,默认是0(黑色)
    """
    # 找到所有图像中最大的尺寸(宽和高)
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
    # 创建一个形状为(batch_size, max_height, max_width)的零矩阵,用来存放批次中的图像
    batch_shape = (len(images),) + max_size
    # 创建一个全是`fill_value`的矩阵,作为初始的空白批次
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
    # 将每个图像复制到对应的位置,保证原图大小不变,超出部分用填充值填充
    for img, pad_img in zip(images, batched_imgs):
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
    # 返回拼接好的批次图像
    return batched_imgs

如果没有mask的话,就把对应mask的代码去掉就好,下面是没有蒙版的代码

python 复制代码
import os
from PIL import Image
import numpy as np
from torch.utils.data import Dataset


class DriveDataset(Dataset):
    def __init__(self, root: str, train: bool, transforms=None):
        super(DriveDataset, self).__init__()
        # 根据是否是训练模式,设置数据集的标志为"training"或"test"
        self.flag = "training" if train else "test"
        # 生成数据集路径,data_root是训练集或测试集的根目录
        data_root = os.path.join(root, "DRIVE", self.flag)
        # 检查路径是否存在,如果不存在就报错
        assert os.path.exists(data_root), f"path '{data_root}' does not exists."
        self.transforms = transforms
        # 获取所有.tif格式的图像文件名
        img_names = [i for i in os.listdir(os.path.join(data_root, "images")) if i.endswith(".tif")]
        # 生成图像路径的列表
        self.img_list = [os.path.join(data_root, "images", i) for i in img_names]
        # 构建对应的人工标注文件路径列表
        self.manual = [os.path.join(data_root, "1st_manual", i.split("_")[0] + "_manual1.gif")
                       for i in img_names]
        # 检查人工标注文件是否存在
        for i in self.manual:
            if os.path.exists(i) is False:
                raise FileNotFoundError(f"file {i} does not exists.")

    def __getitem__(self, idx):
        """
        这个方法返回数据集中某个特定位置(idx)的图像和对应的人工标注。
        1. 读取图像文件。
        2. 读取人工标注图像并转换为[0, 1]范围的标注。
        3. 如果提供了转换函数(transforms),则对图像和标注进行相应处理。
        """
        # 打开原始图像并转换为RGB格式
        img = Image.open(self.img_list[idx]).convert('RGB')
        # 打开人工标注图像并转换为灰度模式
        manual = Image.open(self.manual[idx]).convert('L')
        # 将人工标注的值转换为[0, 1]之间
        manual = np.array(manual) / 255
        # 将人工标注转换回PIL图像
        manual = Image.fromarray((manual * 255).astype(np.uint8))

        # 如果提供了数据增强或变换函数(transforms),则对图像和标注进行处理
        if self.transforms is not None:
            img, manual = self.transforms(img, manual)

        # 返回图像和人工标注
        return img, manual

    def __len__(self):
        # 返回数据集的大小,即图像的数量
        return len(self.img_list)

    @staticmethod
    def collate_fn(batch):
        """
        自定义批处理函数,用于将多个样本合并为一个批次。
        - batch: 包含多个图像和人工标注的样本列表
        - 使用cat_list函数将图像和标注批量处理,填充成相同的大小
        """
        # 将批次中的所有图像和人工标注分开
        images, targets = list(zip(*batch))
        # 将图像列表按最大尺寸合并成一个批次
        batched_imgs = cat_list(images, fill_value=0)
        # 将标注列表按最大尺寸合并成一个批次
        batched_targets = cat_list(targets, fill_value=255)
        return batched_imgs, batched_targets


def cat_list(images, fill_value=0):
    """
    这个函数将多个图像拼接成一个批次。
    所有图像会被填充成相同的尺寸,填充部分用`fill_value`来填充。
    - images:要拼接的图像列表
    - fill_value:填充区域的值,默认是0(黑色)
    """
    # 找到所有图像中最大的尺寸(宽和高)
    max_size = tuple(max(s) for s in zip(*[img.shape for img in images]))
    # 创建一个形状为(batch_size, max_height, max_width)的零矩阵,用来存放批次中的图像
    batch_shape = (len(images),) + max_size
    # 创建一个全是`fill_value`的矩阵,作为初始的空白批次
    batched_imgs = images[0].new(*batch_shape).fill_(fill_value)
    # 将每个图像复制到对应的位置,保证原图大小不变,超出部分用填充值填充
    for img, pad_img in zip(images, batched_imgs):
        pad_img[..., :img.shape[-2], :img.shape[-1]].copy_(img)
    # 返回拼接好的批次图像
    return batched_imgs

3.dice损失计算

(1)什么是Dice?

(2) dice_coeff 是怎么计算的?

dice_coeff 计算的是一张图像的 Dice 系数:

  • 先把预测的结果和真实标签展开成一维数组。
  • 计算它们的 重叠区域(就是预测对的部分)。
  • 然后算出 总区域(预测区域加上真实区域的大小)。
  • 最后,计算出 Dice 系数:重叠区域越大,Dice 系数越高。

如果有某些部分需要忽略(比如标记为 ignore_index 的区域),它会自动跳过这些部分,确保不计算这些区域。

(3)multiclass_dice_coeff 是怎么计算的?

如果你在做多类分割(比如分割不同类型的物体),它会对每一类计算一个 Dice 系数,然后算出这些类的平均值。比如,假设你分割了人和车,它会分别计算人和车的 Dice 系数,最后取平均。

(4)dice_loss 是怎么计算的?

dice_loss 计算的是 Dice Loss ,就是模型优化的目标,值越小越好。这个损失值其实就是 1 - Dice 系数

  • Dice 系数越大,模型预测越准确,损失越小。
  • 因为我们想让损失值最小化,所以目标是让 Dice 系数最大化。

计算函数如下:

python 复制代码
import torch
import torch.nn as nn


def build_target(target: torch.Tensor, num_classes: int = 2, ignore_index: int = -100):
    """构建Dice系数需要的目标标签"""
    dice_target = target.clone()  # 克隆目标标签,避免直接修改原数据
    if ignore_index >= 0:
        # 如果存在ignore_index,找到目标标签中所有等于ignore_index的部分
        ignore_mask = torch.eq(target, ignore_index)
        dice_target[ignore_mask] = 0  # 将ignore_index部分设置为0
        # 将目标标签进行one-hot编码,转换为[N, H, W, C]的格式
        dice_target = nn.functional.one_hot(dice_target, num_classes).float()
        dice_target[ignore_mask] = ignore_index  # 恢复ignore_index的区域
    else:
        # 如果没有ignore_index,直接进行one-hot编码
        dice_target = nn.functional.one_hot(dice_target, num_classes).float()

    # 调整维度顺序,从[N, H, W, C]变为[N, C, H, W]
    return dice_target.permute(0, 3, 1, 2)


def dice_coeff(x: torch.Tensor, target: torch.Tensor, ignore_index: int = -100, epsilon=1e-6):
    """计算一个batch中所有图片某个类别的Dice系数"""
    d = 0.  # 初始化Dice系数
    batch_size = x.shape[0]  # 获取batch的大小
    for i in range(batch_size):
        # 将第i张图像和目标标签展平为一维数组
        x_i = x[i].reshape(-1)
        t_i = target[i].reshape(-1)

        if ignore_index >= 0:
            # 如果有ignore_index,找到目标标签中不为ignore_index的区域
            roi_mask = torch.ne(t_i, ignore_index)
            x_i = x_i[roi_mask]  # 只保留不为ignore_index的部分
            t_i = t_i[roi_mask]  # 目标标签也做同样的处理

        # 计算交集(预测值和目标值的点积)
        inter = torch.dot(x_i, t_i)
        # 计算并集(预测区域总和 + 目标区域总和)
        sets_sum = torch.sum(x_i) + torch.sum(t_i)

        if sets_sum == 0:
            # 如果并集为0(预测和目标都没有预测到任何目标),就直接返回2 * 交集
            sets_sum = 2 * inter

        # 计算Dice系数,并且避免除0错误,加入一个小常数epsilon
        d += (2 * inter + epsilon) / (sets_sum + epsilon)

    # 返回batch内所有图片的平均Dice系数
    return d / batch_size


def multiclass_dice_coeff(x: torch.Tensor, target: torch.Tensor, ignore_index: int = -100, epsilon=1e-6):
    """计算所有类别的Dice系数平均值"""
    dice = 0.  # 初始化Dice系数
    for channel in range(x.shape[1]):  # 对每一个类别(通道)计算Dice系数
        dice += dice_coeff(x[:, channel, ...], target[:, channel, ...], ignore_index, epsilon)

    # 返回所有类别的平均Dice系数
    return dice / x.shape[1]


def dice_loss(x: torch.Tensor, target: torch.Tensor, multiclass: bool = False, ignore_index: int = -100):
    """计算Dice损失(目标是最小化该损失)"""
    x = nn.functional.softmax(x, dim=1)  # 对模型输出进行softmax,转化为概率分布
    fn = multiclass_dice_coeff if multiclass else dice_coeff  # 根据是否多分类选择对应的函数
    # 计算Dice损失,目标是最小化,所以是1减去Dice系数
    return 1 - fn(x, target, ignore_index=ignore_index)

损失计算过程:

1.将真实的图像标签通过one-hot编码转换成多个通道,医学影像中我们只有两个类别:背景和前景,2个类别,2 个类别的标签会变成一个 2 通道的图像,就是说每个像素会有两个数值,一个代表类别 0 的概率,另一个代表类别 1 的概率。

学习器现在就可以看到这样:

为什么要用这种方式?

(1)多通道表示 :通过 One-hot 编码,每个类别都有一个独立的通道,可以清晰地区分每个类别的区域。这个特性在深度学习模型中很有用,尤其是对于 图像分割任务,我们需要对每个像素进行分类,判断它属于哪个类别。

(2)模型训练:在训练神经网络时,模型会根据这些 One-hot 编码的标签计算损失(比如 Dice 损失或者交叉熵损失)。这有助于网络理解和学习每个像素的类别。

2.接下来我们就计算单个类别的Dice系数, 对于每一张图像,我们把 预测结果真实标签 展平(变成一维),然后比较它们的重叠部分。

然后,我们计算它们的 交集并集,然后用公式算出 Dice 系数

  1. 计算 Dice 损失 : dice_loss

Dice 损失 就是 1 - Dice 系数: 也就是说,Dice 系数越高,损失越小,模型越好。

4.模型是如何优化调参的?

先看PyTorch 模型训练四大阶段:

阶段 PyTorch 中做了什么 举例说明
1. 初始化模型 定义模型结构:构造网络层,比如卷积、池化等(__init__() 就是你写的 UNet(nn.Module),在 __init__()里搭建各个模块
2. 前向传播 执行 forward()方法,输入数据通过网络流动,得到输出结果(比如预测图像分割) out = model(input)就是触发了 forward()
3. 反向传播 自动计算损失函数对模型参数的梯度(loss.backward() loss = criterion(output, label)loss.backward()
4. 更新参数 用优化器更新参数,让模型更准确(optimizer.step() 比如使用 Adam优化器更新参数:optimizer.step()

训练epoch: 执行一个训练周期的训练过程(一个epoch意味着训练数据集中的每个样本都被用来训练模型一次。 )

ini 复制代码
def train_one_epoch(model, optimizer, data_loader, device, epoch, num_classes,
                    lr_scheduler, print_freq=10, scaler=None):
    model.train()  # 设置模型为训练模式
    metric_logger = utils.MetricLogger(delimiter="  ")  # 用于记录日志
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))  # 记录学习率
    header = 'Epoch: [{}]'.format(epoch)

    # 如果是二分类,设置交叉熵中的损失权重(背景和前景的权重不同)
    if num_classes == 2:
        loss_weight = torch.as_tensor([1.0, 2.0], device=device)  # 背景权重1.0,前景权重2.0
    else:
        loss_weight = None

    for image, target in metric_logger.log_every(data_loader, print_freq, header):
        image, target = image.to(device), target.to(device)

        # 使用自动混合精度(如果有的话)进行前向传播
        with torch.cuda.amp.autocast(enabled=scaler is not None):
            output = model(image)  # 前向传播
            # 计算损失
            loss = criterion(output, target, loss_weight, num_classes=num_classes, ignore_index=255)

        optimizer.zero_grad()  # 清空梯度
        if scaler is not None:
            # 使用混合精度训练时,进行反向传播和参数更新
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            # 常规反向传播
            loss.backward()
            optimizer.step()

        lr_scheduler.step()  # 更新学习率

        # 获取当前学习率
        lr = optimizer.param_groups[0]["lr"]
        # 更新日志
        metric_logger.update(loss=loss.item(), lr=lr)

    return metric_logger.meters["loss"].global_avg, lr  # 返回平均损失和当前学习率

步骤

  1. 设置模型为训练模式(model.train())。
  2. 设置损失权重(如果是二分类,前景的损失权重设置为2,背景为1)。
  3. 遍历训练数据,执行前向传播计算输出。
  4. 使用混合精度训练时,进行反向传播并更新参数。
  5. 如果没有混合精度,则进行常规反向传播。
  6. 更新学习率lr_scheduler.step())。
  7. 返回平均损失和当前学习率。

首先,我们先初始化一个学习率(步长),然后进行前向传播(模型接收输入数据计算出预测结果的过程),计算出损失率(1-dice),再进行反向传播( 根据损失函数计算每个参数的梯度[每个参数对损失的影响]),然后进行梯度下降,利用刚刚的梯度来更新模型的参数,每次更新的幅度就是学习率的大小,在epoch训练的过程中不断调整学习率,最终找到最优的参数,得到一个收敛的模型。

再梳理一下:

1. 初始化学习率(步长):

    • 在训练开始之前,我们设定一个初始的学习率(通常是一个较小的数值),这个学习率决定了优化过程中每次参数更新的步伐。

2. 前向传播:

    • 模型接收输入数据(如图片、文本等),然后进行计算,得出预测结果。这是模型的前向传播过程。

3. 计算损失率:

    • 将模型的预测结果和真实标签(ground truth)进行比较,计算出模型的误差,这个误差用一个损失函数来量化。比如,Dice损失(1 - Dice系数)就可以用来衡量模型预测的分割结果和真实标签之间的差异。

4. 反向传播:

    • 根据损失值,通过反向传播算法计算出每个参数的梯度。梯度告诉我们每个参数对损失的贡献有多大,即它们需要调整的方向和幅度。

5. 梯度下降更新参数:

    • 使用计算出来的梯度,按照学习率的大小来更新模型的参数。梯度下降就是根据"梯度"来调整参数,学习率(步长)控制每次调整的大小。
    • 更新公式:
      参数 = 参数 - 学习率 * 梯度 这里,学习率控制着每次更新的步伐,如果学习率太大,可能会跨过最优解;如果太小,收敛速度会变慢。

6. 调整学习率:

    • 在训练过程中,学习率可以根据预设的策略进行调整。例如,训练到一定epoch后,逐渐减小学习率,让模型能更精细地调整参数,避免错过最优解。
    • 这种调整可以通过学习率调度器(如学习率衰减)来实现。

7. 训练过程:

    • 在多个epoch的训练过程中,模型会不断调整参数,每次的损失都会变小,最终逐渐收敛到一个最优的参数解。

8. 最终得到收敛的模型:

    • 经过多次的前向传播、损失计算、反向传播和梯度下降后,模型会找到一组最优参数,这样模型在新数据上的预测效果就最好了。

总结:

  • 训练过程的核心就是通过反向传播计算梯度,并通过梯度下降来更新模型的参数,学习率决定了每次更新的步伐。
  • 随着训练的进行,通过不断调整学习率,可以让模型更好地收敛到一个最优解
相关推荐
niuniu_6666 分钟前
针对 Python 3.7.0,以下是 Selenium 版本的兼容性建议和安装步骤
开发语言·chrome·python·selenium·测试工具
苏卫苏卫苏卫9 分钟前
【Python】数据结构练习
开发语言·数据结构·笔记·python·numpy·pandas
DexterLien1 小时前
Flask + Pear Admin Layui 快速开发管理后台
python·flask·layui
码界筑梦坊1 小时前
基于Flask的笔记本电脑数据可视化分析系统
python·信息可视化·flask·毕业设计·电脑
_x_w1 小时前
【8】数据结构的栈与队列练习篇章
开发语言·数据结构·笔记·python·链表
海姐软件测试2 小时前
Jmeter如何使用MD5进行加密?
python·jmeter·压力测试
Cl_rown去掉l变成C2 小时前
第P10周:Pytorch实现车牌识别
人工智能·pytorch·python
Y1nhl2 小时前
Pyspark学习二:快速入门基本数据结构
大数据·数据结构·python·学习·算法·hdfs·pyspark
独好紫罗兰2 小时前
洛谷题单3-P1423 小玉在游泳-python-流程图重构
开发语言·python·算法
onejason3 小时前
利用 Python 爬虫获取淘宝商品 SKU 详细信息
python