深度学习——基于 ResNet18 的图像分类训练

PyTorch 基于 ResNet18 的图像分类训练与验证全流程解析


一、项目概述

本文实现了一个基于 PyTorch 框架的图像分类模型,使用 ResNet18 作为预训练骨干网络(Backbone),并在其基础上进行迁移学习(Transfer Learning)。整个流程涵盖了:

  • 数据预处理与增强

  • 自定义 Dataset 与 DataLoader

  • 模型微调与参数冻结

  • 训练与验证循环

  • 学习率调度策略(ReduceLROnPlateau)

该项目的核心目标是利用已有的强大视觉特征提取网络(ResNet18)对新的小规模数据集进行分类任务,从而提升训练效率与模型性能。


二、模型部分解析:ResNet18 微调机制

复制代码
import torch
import torchvision.models as models
from torch import nn, optim

resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

这里加载了 torchvision.models 中的预训练 ResNet18 模型,其权重参数来自 ImageNet 大规模数据集的训练结果。

接着,冻结网络的所有参数,防止在训练过程中被更新:

复制代码
for param in resnet_model.parameters():
    print(param)
    param.requires_grad = False

原理说明:

  • 迁移学习的关键思路在于"保留特征提取层"。

  • 早期卷积层学习的是通用特征(如边缘、纹理),可直接用于新任务。

  • 仅需微调后几层或分类头层(fc层),可显著减少训练量。

然后替换最后一层全连接层(fc层):

复制代码
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features, 20)

🔹 原始 ResNet18 的输出为 1000 类(ImageNet)。

🔹 这里改为 20 类,适配自定义数据集。

最后仅选择需要更新的参数:

复制代码
params_to_update = []
for param in resnet_model.parameters():
    if param.requires_grad == True:
        params_to_update.append(param)

这意味着优化器只会更新新加入的全连接层参数。


三、数据预处理与增强(Data Augmentation)

数据增强可提升模型泛化能力,代码中定义了两种处理策略:

复制代码
from torchvision import transforms

1. 训练集增强 data_transforms['train']

包含大量随机性增强操作:

复制代码
transforms.Compose([
    transforms.Resize([300,300]),
    transforms.RandomRotation(45),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

📘 数据增强效果:

  • 旋转翻转灰度转换可提升模型在多视角条件下的鲁棒性。

  • 归一化操作确保输入分布与预训练模型保持一致。

2. 验证集预处理 data_transforms['valid']

复制代码
transforms.Compose([
    transforms.Resize([256,256]),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

验证集通常不使用随机增强,以保持结果的可重复性和客观性。


四、自定义 Dataset 与 DataLoader

1. Dataset 类定义

复制代码
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np

class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.file_path = file_path
        self.transform = transform
        self.imgs = []
        self.labels = []
        with open(self.file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path,label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

该类通过文本文件(如 train2.txt)加载图片路径和标签。

每一行格式为:

复制代码
image_path label

2. 数据访问接口

复制代码
def __len__(self):
    return len(self.imgs)

def __getitem__(self, index):
    image = Image.open(self.imgs[index])
    if self.transform:
        image = self.transform(image)
    label = torch.from_numpy(np.array(self.labels[index], dtype=np.int64))
    return image, label

⚙️ Dataset 必备方法:

  • __len__():返回数据集大小。

  • __getitem__():返回一条样本及其标签。

3. 数据加载器 DataLoader

复制代码
training_data = food_dataset('./train2.txt', transform=data_transforms['train'])
test_data = food_dataset('./test2.txt', transform=data_transforms['valid'])

train_dataloader = DataLoader(training_data, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=16, shuffle=True)

✅ DataLoader 的作用:

  • 自动打包 batch

  • 支持多线程加载(num_workers

  • 支持数据打乱(shuffle)


五、训练环境与优化器设置

1. 自动选择设备

复制代码
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

同时兼容:

  • NVIDIA GPU (CUDA)

  • Apple M1/M2 GPU (MPS)

  • CPU

2. 定义损失函数与优化器

复制代码
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params_to_update, lr=0.001)

使用 交叉熵损失函数 处理多分类任务,优化器为 Adam

3. 学习率调度器

复制代码
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True
)

🔹 当验证集 Loss 连续 3 轮未改善时,学习率减半。

🔹 可有效防止过拟合与梯度振荡。


六、训练与验证循环

1. 训练函数

复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    batch_size_num = 1
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_size_num % 100 == 0:
            print(f"loss: {loss.item():>7f} [number:{batch_size_num}]")
        batch_size_num += 1

🔁 每 100 个 batch 打印一次 loss。

2. 测试函数

复制代码
def Test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test result:\n Accuracy:{100*correct}%, Avg loss:{test_loss}")
    return test_loss

🧮 评估指标:

  • Accuracy(准确率)

  • Avg loss(平均验证损失)


七、主训练循环

复制代码
epochs = 50
for t in range(epochs):
    print(f"---------------\nepoch {t+1}")
    train(train_dataloader, model, loss_fn, optimizer)
    val_loss = Test(test_dataloader, model, loss_fn)
    scheduler.step(val_loss)
print("Done!")

共进行 50 轮训练,每轮包括:

  1. 模型训练

  2. 验证集测试

  3. 根据验证集 loss 调整学习率

💡 随着 epoch 增加,loss 应逐渐下降,accuracy 提升。


八、总结

模块 作用 特点
ResNet18 特征提取主干 使用 ImageNet 预训练权重
Dataset 读取图片与标签 支持 transform 自动增强
DataLoader 批量化输入 shuffle 提升训练效果
train() 前向传播与反向传播 更新梯度
Test() 模型评估 计算平均损失与准确率
ReduceLROnPlateau 学习率调整 自动降低学习率防止过拟合
相关推荐
Sunhen_Qiletian5 分钟前
Python 类继承详解:深度学习神经网络架构的构建艺术
python·深度学习·神经网络
程序员大雄学编程28 分钟前
用Python来学微积分34-定积分的基本性质及其应用
开发语言·python·数学·微积分
LHZSMASH!36 分钟前
神经流形:大脑功能几何基础的革命性视角
人工智能·深度学习·神经网络·机器学习
Q_Q51100828542 分钟前
python+django/flask的莱元元电商数据分析系统_电商销量预测
spring boot·python·django·flask·node.js·php
青云交1 小时前
Java 大视界 --Java 大数据在智慧农业农产品市场价格预测与种植决策支持中的应用实战
机器学习·智慧农业·数据安全·农业物联网·价格预测·java 大数据·种植决策
林一百二十八1 小时前
Python实现手写数字识别
开发语言·python
大明者省1 小时前
图像卷积操值超过了255怎么处理
深度学习·神经网络·机器学习
Q26433650231 小时前
【有源码】基于Hadoop+Spark的起点小说网大数据可视化分析系统-基于Python大数据生态的网络文学数据挖掘与可视化系统
大数据·hadoop·python·信息可视化·数据分析·spark·毕业设计
大叔_爱编程2 小时前
基于Python的历届奥运会数据可视化分析系统-django+spider
python·django·毕业设计·源码·课程设计·spider·奥运会数据可视化
小白狮ww2 小时前
模型不再是一整块!Hunyuan3D-Part 实现可控组件式 3D 生成
人工智能·深度学习·机器学习·教程·3d模型·hunyuan3d·3d创作