深度学习——基于 ResNet18 的图像分类训练

PyTorch 基于 ResNet18 的图像分类训练与验证全流程解析


一、项目概述

本文实现了一个基于 PyTorch 框架的图像分类模型,使用 ResNet18 作为预训练骨干网络(Backbone),并在其基础上进行迁移学习(Transfer Learning)。整个流程涵盖了:

  • 数据预处理与增强

  • 自定义 Dataset 与 DataLoader

  • 模型微调与参数冻结

  • 训练与验证循环

  • 学习率调度策略(ReduceLROnPlateau)

该项目的核心目标是利用已有的强大视觉特征提取网络(ResNet18)对新的小规模数据集进行分类任务,从而提升训练效率与模型性能。


二、模型部分解析:ResNet18 微调机制

复制代码
import torch
import torchvision.models as models
from torch import nn, optim

resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

这里加载了 torchvision.models 中的预训练 ResNet18 模型,其权重参数来自 ImageNet 大规模数据集的训练结果。

接着,冻结网络的所有参数,防止在训练过程中被更新:

复制代码
for param in resnet_model.parameters():
    print(param)
    param.requires_grad = False

原理说明:

  • 迁移学习的关键思路在于"保留特征提取层"。

  • 早期卷积层学习的是通用特征(如边缘、纹理),可直接用于新任务。

  • 仅需微调后几层或分类头层(fc层),可显著减少训练量。

然后替换最后一层全连接层(fc层):

复制代码
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features, 20)

🔹 原始 ResNet18 的输出为 1000 类(ImageNet)。

🔹 这里改为 20 类,适配自定义数据集。

最后仅选择需要更新的参数:

复制代码
params_to_update = []
for param in resnet_model.parameters():
    if param.requires_grad == True:
        params_to_update.append(param)

这意味着优化器只会更新新加入的全连接层参数。


三、数据预处理与增强(Data Augmentation)

数据增强可提升模型泛化能力,代码中定义了两种处理策略:

复制代码
from torchvision import transforms

1. 训练集增强 data_transforms['train']

包含大量随机性增强操作:

复制代码
transforms.Compose([
    transforms.Resize([300,300]),
    transforms.RandomRotation(45),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

📘 数据增强效果:

  • 旋转翻转灰度转换可提升模型在多视角条件下的鲁棒性。

  • 归一化操作确保输入分布与预训练模型保持一致。

2. 验证集预处理 data_transforms['valid']

复制代码
transforms.Compose([
    transforms.Resize([256,256]),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

验证集通常不使用随机增强,以保持结果的可重复性和客观性。


四、自定义 Dataset 与 DataLoader

1. Dataset 类定义

复制代码
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np

class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.file_path = file_path
        self.transform = transform
        self.imgs = []
        self.labels = []
        with open(self.file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path,label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

该类通过文本文件(如 train2.txt)加载图片路径和标签。

每一行格式为:

复制代码
image_path label

2. 数据访问接口

复制代码
def __len__(self):
    return len(self.imgs)

def __getitem__(self, index):
    image = Image.open(self.imgs[index])
    if self.transform:
        image = self.transform(image)
    label = torch.from_numpy(np.array(self.labels[index], dtype=np.int64))
    return image, label

⚙️ Dataset 必备方法:

  • __len__():返回数据集大小。

  • __getitem__():返回一条样本及其标签。

3. 数据加载器 DataLoader

复制代码
training_data = food_dataset('./train2.txt', transform=data_transforms['train'])
test_data = food_dataset('./test2.txt', transform=data_transforms['valid'])

train_dataloader = DataLoader(training_data, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=16, shuffle=True)

✅ DataLoader 的作用:

  • 自动打包 batch

  • 支持多线程加载(num_workers

  • 支持数据打乱(shuffle)


五、训练环境与优化器设置

1. 自动选择设备

复制代码
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

同时兼容:

  • NVIDIA GPU (CUDA)

  • Apple M1/M2 GPU (MPS)

  • CPU

2. 定义损失函数与优化器

复制代码
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params_to_update, lr=0.001)

使用 交叉熵损失函数 处理多分类任务,优化器为 Adam

3. 学习率调度器

复制代码
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True
)

🔹 当验证集 Loss 连续 3 轮未改善时,学习率减半。

🔹 可有效防止过拟合与梯度振荡。


六、训练与验证循环

1. 训练函数

复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    batch_size_num = 1
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_size_num % 100 == 0:
            print(f"loss: {loss.item():>7f} [number:{batch_size_num}]")
        batch_size_num += 1

🔁 每 100 个 batch 打印一次 loss。

2. 测试函数

复制代码
def Test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test result:\n Accuracy:{100*correct}%, Avg loss:{test_loss}")
    return test_loss

🧮 评估指标:

  • Accuracy(准确率)

  • Avg loss(平均验证损失)


七、主训练循环

复制代码
epochs = 50
for t in range(epochs):
    print(f"---------------\nepoch {t+1}")
    train(train_dataloader, model, loss_fn, optimizer)
    val_loss = Test(test_dataloader, model, loss_fn)
    scheduler.step(val_loss)
print("Done!")

共进行 50 轮训练,每轮包括:

  1. 模型训练

  2. 验证集测试

  3. 根据验证集 loss 调整学习率

💡 随着 epoch 增加,loss 应逐渐下降,accuracy 提升。


八、总结

模块 作用 特点
ResNet18 特征提取主干 使用 ImageNet 预训练权重
Dataset 读取图片与标签 支持 transform 自动增强
DataLoader 批量化输入 shuffle 提升训练效果
train() 前向传播与反向传播 更新梯度
Test() 模型评估 计算平均损失与准确率
ReduceLROnPlateau 学习率调整 自动降低学习率防止过拟合
相关推荐
孟健21 小时前
Karpathy 用 200 行纯 Python 从零实现 GPT:代码逐行解析
python
HXhlx1 天前
CART决策树基本原理
算法·机器学习
码路飞1 天前
写了个 AI 聊天页面,被 5 种流式格式折腾了一整天 😭
javascript·python
曲幽1 天前
FastAPI压力测试实战:Locust模拟真实用户并发及优化建议
python·fastapi·web·locust·asyncio·test·uvicorn·workers
敏编程1 天前
一天一个Python库:jsonschema - JSON 数据验证利器
python
前端付豪1 天前
LangChain记忆:通过Memory记住上次的对话细节
人工智能·python·langchain
databook1 天前
ManimCE v0.20.1 发布:LaTeX 渲染修复与动画稳定性提升
python·动效
花酒锄作田2 天前
使用 pkgutil 实现动态插件系统
python
前端付豪2 天前
LangChain链 写一篇完美推文?用SequencialChain链接不同的组件
人工智能·python·langchain
曲幽2 天前
FastAPI实战:打造本地文生图接口,ollama+diffusers让AI绘画更听话
python·fastapi·web·cors·diffusers·lcm·ollama·dreamshaper8·txt2img