基于 ResNet18 迁移学习的 20 类食物图像分类

在计算机视觉的图像分类任务中,从头搭建并训练深度神经网络不仅需要大量的标注数据,还需要耗费大量的计算资源。迁移学习作为一种高效的建模方法,能够借助预训练模型的特征提取能力,快速适配新的分类任务。本文将以 20 类食物分类为例,详细讲解如何基于 PyTorch 框架,利用 ResNet18 预训练模型实现迁移学习,从模型改造、数据处理到训练测试。

一、迁移学习逻辑

迁移学习的核心实现逻辑分为两步:

  1. 冻结预训练模型的卷积层:保留其已学习的特征提取能力,避免训练过程中破坏原有权重;

  2. 替换全连接层:ResNet18 原全连接层输出为 1000 类(适配 ImageNet),本次任务为 20 类食物分类,因此将原全连接层替换为输入特征数不变、输出为 20 的新全连接层,仅训练该层的参数,大幅降低训练成本。

同时,为了提升模型泛化能力,对训练集进行数据增强,验证集仅做标准化处理;训练过程中使用 Adam 优化器更新参数,结合学习率调度器动态调整学习率,最终实现高效的模型训练。

二、预训练模型改造与参数冻结

这一步是迁移学习的核心,需要完成加载预训练 ResNet18、冻结卷积层参数、替换全连接层三个关键操作,同时筛选出需要训练的参数。

2.1 加载预训练 ResNet18 模型

PyTorch 的torchvision.models模块提供了封装好的 ResNet18,通过指定weights=models.ResNet18_Weights.DEFAULT可以直接加载在 ImageNet 上预训练的权重,无需手动下载权重文件:

python 复制代码
import torch
import torchvision.models as models
from torch import nn

# 加载ResNet18预训练模型
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

2.2 冻结卷积层所有参数

将模型所有参数的requires_grad属性设置为False,表示这些参数在训练过程中不计算梯度,也不会被优化器更新,即实现参数冻结:

python 复制代码
# 冻结所有预训练参数(卷积层为主)
for param in resnet_model.parameters():
    param.requires_grad = False

2.3 替换全连接层并筛选训练参数

ResNet18 的全连接层(fc)是模型的分类头,需要根据任务需求替换。首先获取原全连接层的输入特征数(由卷积层的输出特征决定,固定为 512),然后替换为输出 20 的新全连接层;最后筛选出requires_grad=True的参数(仅新全连接层的参数),作为后续优化器的更新对象:

python 复制代码
# 获取原全连接层的输入特征数
in_features = resnet_model.fc.in_features
# 替换为输出20的新全连接层(适配20类食物分类)
resnet_model.fc = nn.Linear(in_features, 20)

# 筛选需要训练的参数(仅全连接层参数)
params_to_update = []
for param in resnet_model.parameters():
    if param.requires_grad == True:
        params_to_update.append(param)

三、图像数据处理:变换、加载与封装

食物分类的原始数据为图像文件和对应的标签文件(trainda.txt/testda.txt),标签文件中每行格式为图像路径 类别标签。需要通过数据变换、自定义 Dataset、DataLoader 封装三步,将原始数据转换为 PyTorch 可训练的格式。

3.1 图像变换策略:训练集增强,验证集标准化

图像变换的核心原则是:训练集做数据增强提升泛化能力,验证 / 测试集仅做与训练集一致的标准化处理,避免引入额外噪声。本次使用torchvision.transforms实现变换

python 复制代码
from torchvision import transforms

# 定义训练集和验证集的图像变换
data_transforms = {
    "train": transforms.Compose([
        transforms.Resize([300, 300]),  # 缩放为300*300,为后续裁剪预留空间
        transforms.RandomRotation(45),  # 随机旋转-45~45度,增强旋转鲁棒性
        transforms.CenterCrop(224),     # 中心裁剪为224*224,匹配ResNet输入要求
        transforms.ToTensor(),          # 转换为Tensor,格式[C, H, W],值归一化到0~1
        # 标准化:使用ImageNet的均值和标准差,与预训练模型保持一致
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    "valid": transforms.Compose([
        transforms.Resize([256, 256]),  # 缩放为256*256
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

3.2 自定义 Dataset:解析标签文件并加载图像

PyTorch 的Dataset是抽象类,需要自定义子类实现__init____len____getitem__三个核心方法,实现标签文件解析、图像读取、格式转换的功能,将图像和标签封装为可索引的数据集

python 复制代码
from torch.utils.data import Dataset
from PIL import Image
import numpy as np

class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        """
        初始化:解析标签文件,保存图像路径和标签
        :param file_path: 标签文件路径(trainda.txt/testda.txt)
        :param transform: 图像变换策略
        """
        self.file_path = file_path
        self.imgs = []  # 保存所有图像路径
        self.labels = []  # 保存所有图像对应的标签
        self.transform = transform
        # 解析标签文件
        with open(self.file_path) as f:
            # 每行按空格分割,得到[图像路径, 标签]
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

    def __len__(self):
        """返回数据集的样本总数"""
        return len(self.imgs)

    def __getitem__(self, idx):
        """根据索引获取单个样本(图像+标签),并完成格式转换"""
        # 读取PIL图像
        image = Image.open(self.imgs[idx])
        # 应用图像变换(转换为Tensor并标准化/增强)
        if self.transform:
            image = self.transform(image)
        # 标签转换为Tensor(int64类型,适配交叉熵损失)
        label = torch.from_numpy(np.array(self.labels[idx], dtype=np.int64))
        return image, label

3.3 DataLoader 封装:批量加载与数据打乱

通过torch.utils.data.DataLoader将自定义的 Dataset 封装为可迭代的批量数据加载器,实现批量读取、数据打乱、多进程加载(默认)等功能,是 PyTorch 训练的标准数据输入方式

python 复制代码
from torch.utils.data import DataLoader

# 实例化训练集和测试集
training_data = food_dataset(file_path=r'.\trainda.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path=r'.\testda.txt', transform=data_transforms['valid'])

# 封装为DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)  # 训练集打乱
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)      # 测试集可选择是否打乱

参数说明:

  • batch_size=64:每次读取 64 个样本为一个批次,批次大小可根据显存调整(显存不足则减小);

  • shuffle=True:每个 epoch 训练前打乱数据顺序,避免模型学习到数据的顺序规律,提升泛化能力。

四、设备配置与模型部署

python 复制代码
# 自动检测可用设备
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

# 将模型部署到指定设备
model = resnet_model.to(device)

五、训练与测试函数定义

5.1 训练函数:实现模型的一次 epoch 训练

训练函数的核心是将模型切换为训练模式(model.train()),遍历训练集的每个批次,完成前向传播预测、交叉熵损失计算、梯度清零、反向传播求梯度、优化器更新参数的流程,并打印批次损失值.

python 复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()  # 切换为训练模式:启用Dropout、BatchNorm的训练模式
    batch_size_num = 1  # 统计训练的批次数量
    for x, y in dataloader:
        # 将数据部署到指定设备
        x, y = x.to(device), y.to(device)
        # 前向传播:模型预测(.forward可省略,直接model(x)即可)
        pred = model(x)
        # 计算损失:交叉熵损失(适配分类任务)
        loss = loss_fn(pred, y)

        # 反向传播与参数更新
        optimizer.zero_grad()  # 梯度清零:避免上一批次的梯度累积
        loss.backward()        # 反向传播:计算所有可训练参数的梯度
        optimizer.step()       # 优化器更新:根据梯度调整参数

        # 提取损失值并按批次打印
        loss_value = loss.item()
        if batch_size_num % 2 == 0:
            print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")
        batch_size_num += 1
  • model.train():必须在训练前调用,启用模型中训练相关的层(如 BatchNorm、Dropout),确保训练过程的正确性;

  • optimizer.zero_grad():每次批次训练前必须清零梯度,否则梯度会在批次间累积,导致参数更新错误;

  • 前向传播时model(x)等价于model.forward(x),PyTorch 已对__call__方法做了封装,推荐直接使用model(x)

5.2 测试函数:实现模型的精度与损失评估

测试函数的核心是将模型切换为评估模式(model.eval()),关闭梯度计(torch.no_grad()),遍历测试集的所有批次,统计总损失和正确预测数,最终计算平均损失和准确率。

python 复制代码
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)  # 测试集总样本数
    num_batches = len(dataloader)   # 测试集总批次数
    model.eval()                    # 切换为评估模式:关闭Dropout、固定BatchNorm参数
    test_loss, correct = 0, 0       # 初始化总损失和正确预测数

    # 关闭梯度计算:测试过程无需求梯度,节省显存和计算资源
    with torch.no_grad():
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)
            pred = model(x)  # 前向传播预测
            # 累加批次损失
            test_loss += loss_fn(pred, y).item()
            # 累加正确预测数:pred.argmax(1)获取每行最大值的索引(预测类别),与真实标签比较
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()

    # 计算平均损失和准确率
    test_loss /= num_batches
    correct /= size
    print(f"Test result: Accuracy:{(100*correct):>0.1f}%,Avg loss: {test_loss:>8f}")
  • model.eval():必须在测试前调用,关闭 Dropout 层、固定 BatchNorm 的均值和方差,避免评估过程中模型参数变化,确保评估结果的稳定性;

  • torch.no_grad():上下文管理器,关闭梯度计算,大幅节省显存和计算资源,测试过程无需反向传播,因此无需计算梯度;

  • pred.argmax(1):对模型输出的预测值(形状 [batch_size, 20])按维度 1 取最大值索引,即模型预测的类别标签(20 类中概率最大的类)。

六、优化器、损失函数与学习率

6.1 损失函数:交叉熵损失

python 复制代码
loss_fn = nn.CrossEntropyLoss()

6.2 优化器:Adam 优化器

选择 Adam 优化器更新模型参数,Adam 结合了动量法和自适应学习率的优点,收敛速度快、调参简单,本次仅更新筛选出的全连接层参数(params_to_update),学习率设置为 0.001

python 复制代码
optimizer = torch.optim.Adam(params_to_update, lr=0.001)
# 若全模型微调,改为:optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

6.3 学习率调度器:StepLR

为了避免模型训练后期陷入局部最优,使用StepLR学习率调度器,每经过指定的 epoch 数,将学习率乘以衰减系数,实现动态学习率调整

python 复制代码
# 每5个epoch,学习率乘以0.5(衰减为原来的一半)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)

七、训练与测试

python 复制代码
# 设置训练的epoch数
epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n----------")
    train(train_dataloader, model, loss_fn, optimizer)  # 训练一个epoch
    scheduler.step()  # 更新学习率
    test(test_dataloader, model, loss_fn)              # 测试模型性能
print("Training Done!")

scheduler.step():每个 epoch 训练完成后调用,更新学习率

结果显示

相关推荐
Piar1231sdafa5 小时前
蓝莓目标检测——改进YOLO11-C2TSSA-DYT-Mona模型实现
人工智能·目标检测·计算机视觉
愚公搬代码5 小时前
【愚公系列】《AI短视频创作一本通》002-AI引爆短视频创作革命(短视频创作者必备的能力)
人工智能
数据猿视觉5 小时前
新品上市|奢音S5耳夹耳机:3.5g无感佩戴,178.8元全场景适配
人工智能
蚁巡信息巡查系统5 小时前
网站信息发布再巡查机制怎么建立?
大数据·人工智能·数据挖掘·内容运营
AI浩5 小时前
C-RADIOv4(技术报告)
人工智能·目标检测
Purple Coder5 小时前
AI赋予超导材料预测论文初稿
人工智能
Data_Journal5 小时前
Scrapy vs. Crawlee —— 哪个更好?!
运维·人工智能·爬虫·媒体·社媒营销
云边云科技_云网融合6 小时前
AIoT智能物联网平台:架构解析与边缘应用新图景
大数据·网络·人工智能·安全
康康的AI博客6 小时前
什么是API中转服务商?如何低成本高稳定调用海量AI大模型?
人工智能·ai
技术与健康6 小时前
AI Coding协作开发工作台 实战案例:为电商系统添加用户评论功能
人工智能