深度学习——基于 ResNet18 的图像分类训练

PyTorch 基于 ResNet18 的图像分类训练与验证全流程解析


一、项目概述

本文实现了一个基于 PyTorch 框架的图像分类模型,使用 ResNet18 作为预训练骨干网络(Backbone),并在其基础上进行迁移学习(Transfer Learning)。整个流程涵盖了:

  • 数据预处理与增强

  • 自定义 Dataset 与 DataLoader

  • 模型微调与参数冻结

  • 训练与验证循环

  • 学习率调度策略(ReduceLROnPlateau)

该项目的核心目标是利用已有的强大视觉特征提取网络(ResNet18)对新的小规模数据集进行分类任务,从而提升训练效率与模型性能。


二、模型部分解析:ResNet18 微调机制

复制代码
import torch
import torchvision.models as models
from torch import nn, optim

resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

这里加载了 torchvision.models 中的预训练 ResNet18 模型,其权重参数来自 ImageNet 大规模数据集的训练结果。

接着,冻结网络的所有参数,防止在训练过程中被更新:

复制代码
for param in resnet_model.parameters():
    print(param)
    param.requires_grad = False

原理说明:

  • 迁移学习的关键思路在于"保留特征提取层"。

  • 早期卷积层学习的是通用特征(如边缘、纹理),可直接用于新任务。

  • 仅需微调后几层或分类头层(fc层),可显著减少训练量。

然后替换最后一层全连接层(fc层):

复制代码
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features, 20)

🔹 原始 ResNet18 的输出为 1000 类(ImageNet)。

🔹 这里改为 20 类,适配自定义数据集。

最后仅选择需要更新的参数:

复制代码
params_to_update = []
for param in resnet_model.parameters():
    if param.requires_grad == True:
        params_to_update.append(param)

这意味着优化器只会更新新加入的全连接层参数。


三、数据预处理与增强(Data Augmentation)

数据增强可提升模型泛化能力,代码中定义了两种处理策略:

复制代码
from torchvision import transforms

1. 训练集增强 data_transforms['train']

包含大量随机性增强操作:

复制代码
transforms.Compose([
    transforms.Resize([300,300]),
    transforms.RandomRotation(45),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

📘 数据增强效果:

  • 旋转翻转灰度转换可提升模型在多视角条件下的鲁棒性。

  • 归一化操作确保输入分布与预训练模型保持一致。

2. 验证集预处理 data_transforms['valid']

复制代码
transforms.Compose([
    transforms.Resize([256,256]),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

验证集通常不使用随机增强,以保持结果的可重复性和客观性。


四、自定义 Dataset 与 DataLoader

1. Dataset 类定义

复制代码
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np

class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.file_path = file_path
        self.transform = transform
        self.imgs = []
        self.labels = []
        with open(self.file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path,label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

该类通过文本文件(如 train2.txt)加载图片路径和标签。

每一行格式为:

复制代码
image_path label

2. 数据访问接口

复制代码
def __len__(self):
    return len(self.imgs)

def __getitem__(self, index):
    image = Image.open(self.imgs[index])
    if self.transform:
        image = self.transform(image)
    label = torch.from_numpy(np.array(self.labels[index], dtype=np.int64))
    return image, label

⚙️ Dataset 必备方法:

  • __len__():返回数据集大小。

  • __getitem__():返回一条样本及其标签。

3. 数据加载器 DataLoader

复制代码
training_data = food_dataset('./train2.txt', transform=data_transforms['train'])
test_data = food_dataset('./test2.txt', transform=data_transforms['valid'])

train_dataloader = DataLoader(training_data, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=16, shuffle=True)

✅ DataLoader 的作用:

  • 自动打包 batch

  • 支持多线程加载(num_workers

  • 支持数据打乱(shuffle)


五、训练环境与优化器设置

1. 自动选择设备

复制代码
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

同时兼容:

  • NVIDIA GPU (CUDA)

  • Apple M1/M2 GPU (MPS)

  • CPU

2. 定义损失函数与优化器

复制代码
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params_to_update, lr=0.001)

使用 交叉熵损失函数 处理多分类任务,优化器为 Adam

3. 学习率调度器

复制代码
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True
)

🔹 当验证集 Loss 连续 3 轮未改善时,学习率减半。

🔹 可有效防止过拟合与梯度振荡。


六、训练与验证循环

1. 训练函数

复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    batch_size_num = 1
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_size_num % 100 == 0:
            print(f"loss: {loss.item():>7f} [number:{batch_size_num}]")
        batch_size_num += 1

🔁 每 100 个 batch 打印一次 loss。

2. 测试函数

复制代码
def Test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test result:\n Accuracy:{100*correct}%, Avg loss:{test_loss}")
    return test_loss

🧮 评估指标:

  • Accuracy(准确率)

  • Avg loss(平均验证损失)


七、主训练循环

复制代码
epochs = 50
for t in range(epochs):
    print(f"---------------\nepoch {t+1}")
    train(train_dataloader, model, loss_fn, optimizer)
    val_loss = Test(test_dataloader, model, loss_fn)
    scheduler.step(val_loss)
print("Done!")

共进行 50 轮训练,每轮包括:

  1. 模型训练

  2. 验证集测试

  3. 根据验证集 loss 调整学习率

💡 随着 epoch 增加,loss 应逐渐下降,accuracy 提升。


八、总结

模块 作用 特点
ResNet18 特征提取主干 使用 ImageNet 预训练权重
Dataset 读取图片与标签 支持 transform 自动增强
DataLoader 批量化输入 shuffle 提升训练效果
train() 前向传播与反向传播 更新梯度
Test() 模型评估 计算平均损失与准确率
ReduceLROnPlateau 学习率调整 自动降低学习率防止过拟合
相关推荐
林炳然4 小时前
Python-Basic Day-1 基本元素(数字、字符串)
python
weixin_307779134 小时前
在Linux服务器上使用Jenkins和Poetry实现Python项目自动化
linux·开发语言·python·自动化·jenkins
今天没有盐4 小时前
内置基础类型之布尔值类型(bool)与时间与日期类型
python·编程语言
koo3644 小时前
李宏毅机器学习笔记25
人工智能·笔记·机器学习
Empty_7774 小时前
Python编程之常用模块
开发语言·网络·python
余俊晖4 小时前
如何让多模态大模型学会“自动思考”-R-4B训练框架核心设计与训练方法
人工智能·算法·机器学习
hzp6664 小时前
Magnus:面向大规模机器学习工作负载的综合数据管理方法
人工智能·深度学习·机器学习·大模型·llm·数据湖·大数据存储
Q_Q5110082854 小时前
python+uniapp基于微信小程序的学院设备报修系统
spring boot·python·微信小程序·django·flask·uni-app
GitNohup4 小时前
安装Anaconda和Pytorch
pytorch·anaconda