深度学习——基于 ResNet18 的图像分类训练

PyTorch 基于 ResNet18 的图像分类训练与验证全流程解析


一、项目概述

本文实现了一个基于 PyTorch 框架的图像分类模型,使用 ResNet18 作为预训练骨干网络(Backbone),并在其基础上进行迁移学习(Transfer Learning)。整个流程涵盖了:

  • 数据预处理与增强

  • 自定义 Dataset 与 DataLoader

  • 模型微调与参数冻结

  • 训练与验证循环

  • 学习率调度策略(ReduceLROnPlateau)

该项目的核心目标是利用已有的强大视觉特征提取网络(ResNet18)对新的小规模数据集进行分类任务,从而提升训练效率与模型性能。


二、模型部分解析:ResNet18 微调机制

复制代码
import torch
import torchvision.models as models
from torch import nn, optim

resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

这里加载了 torchvision.models 中的预训练 ResNet18 模型,其权重参数来自 ImageNet 大规模数据集的训练结果。

接着,冻结网络的所有参数,防止在训练过程中被更新:

复制代码
for param in resnet_model.parameters():
    print(param)
    param.requires_grad = False

原理说明:

  • 迁移学习的关键思路在于"保留特征提取层"。

  • 早期卷积层学习的是通用特征(如边缘、纹理),可直接用于新任务。

  • 仅需微调后几层或分类头层(fc层),可显著减少训练量。

然后替换最后一层全连接层(fc层):

复制代码
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features, 20)

🔹 原始 ResNet18 的输出为 1000 类(ImageNet)。

🔹 这里改为 20 类,适配自定义数据集。

最后仅选择需要更新的参数:

复制代码
params_to_update = []
for param in resnet_model.parameters():
    if param.requires_grad == True:
        params_to_update.append(param)

这意味着优化器只会更新新加入的全连接层参数。


三、数据预处理与增强(Data Augmentation)

数据增强可提升模型泛化能力,代码中定义了两种处理策略:

复制代码
from torchvision import transforms

1. 训练集增强 data_transforms['train']

包含大量随机性增强操作:

复制代码
transforms.Compose([
    transforms.Resize([300,300]),
    transforms.RandomRotation(45),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

📘 数据增强效果:

  • 旋转翻转灰度转换可提升模型在多视角条件下的鲁棒性。

  • 归一化操作确保输入分布与预训练模型保持一致。

2. 验证集预处理 data_transforms['valid']

复制代码
transforms.Compose([
    transforms.Resize([256,256]),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

验证集通常不使用随机增强,以保持结果的可重复性和客观性。


四、自定义 Dataset 与 DataLoader

1. Dataset 类定义

复制代码
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np

class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.file_path = file_path
        self.transform = transform
        self.imgs = []
        self.labels = []
        with open(self.file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path,label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

该类通过文本文件(如 train2.txt)加载图片路径和标签。

每一行格式为:

复制代码
image_path label

2. 数据访问接口

复制代码
def __len__(self):
    return len(self.imgs)

def __getitem__(self, index):
    image = Image.open(self.imgs[index])
    if self.transform:
        image = self.transform(image)
    label = torch.from_numpy(np.array(self.labels[index], dtype=np.int64))
    return image, label

⚙️ Dataset 必备方法:

  • __len__():返回数据集大小。

  • __getitem__():返回一条样本及其标签。

3. 数据加载器 DataLoader

复制代码
training_data = food_dataset('./train2.txt', transform=data_transforms['train'])
test_data = food_dataset('./test2.txt', transform=data_transforms['valid'])

train_dataloader = DataLoader(training_data, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=16, shuffle=True)

✅ DataLoader 的作用:

  • 自动打包 batch

  • 支持多线程加载(num_workers

  • 支持数据打乱(shuffle)


五、训练环境与优化器设置

1. 自动选择设备

复制代码
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

同时兼容:

  • NVIDIA GPU (CUDA)

  • Apple M1/M2 GPU (MPS)

  • CPU

2. 定义损失函数与优化器

复制代码
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params_to_update, lr=0.001)

使用 交叉熵损失函数 处理多分类任务,优化器为 Adam

3. 学习率调度器

复制代码
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True
)

🔹 当验证集 Loss 连续 3 轮未改善时,学习率减半。

🔹 可有效防止过拟合与梯度振荡。


六、训练与验证循环

1. 训练函数

复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    batch_size_num = 1
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_size_num % 100 == 0:
            print(f"loss: {loss.item():>7f} [number:{batch_size_num}]")
        batch_size_num += 1

🔁 每 100 个 batch 打印一次 loss。

2. 测试函数

复制代码
def Test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test result:\n Accuracy:{100*correct}%, Avg loss:{test_loss}")
    return test_loss

🧮 评估指标:

  • Accuracy(准确率)

  • Avg loss(平均验证损失)


七、主训练循环

复制代码
epochs = 50
for t in range(epochs):
    print(f"---------------\nepoch {t+1}")
    train(train_dataloader, model, loss_fn, optimizer)
    val_loss = Test(test_dataloader, model, loss_fn)
    scheduler.step(val_loss)
print("Done!")

共进行 50 轮训练,每轮包括:

  1. 模型训练

  2. 验证集测试

  3. 根据验证集 loss 调整学习率

💡 随着 epoch 增加,loss 应逐渐下降,accuracy 提升。


八、总结

模块 作用 特点
ResNet18 特征提取主干 使用 ImageNet 预训练权重
Dataset 读取图片与标签 支持 transform 自动增强
DataLoader 批量化输入 shuffle 提升训练效果
train() 前向传播与反向传播 更新梯度
Test() 模型评估 计算平均损失与准确率
ReduceLROnPlateau 学习率调整 自动降低学习率防止过拟合
相关推荐
习习.y9 小时前
python笔记梳理以及一些题目整理
开发语言·笔记·python
撸码猿9 小时前
《Python AI入门》第10章 拥抱AIGC——OpenAI API调用与Prompt工程实战
人工智能·python·aigc
qq_386218999 小时前
Gemini生成的自动搜索和下载论文的python脚本
开发语言·python
vx_vxbs6610 小时前
【SSM电影网站】(免费领源码+演示录像)|可做计算机毕设Java、Python、PHP、小程序APP、C#、爬虫大数据、单片机、文案
java·spring boot·python·mysql·小程序·php·idea
双翌视觉10 小时前
双翌全自动影像测量仪:以微米精度打造智能化制造
人工智能·机器学习·制造
编程小白_正在努力中11 小时前
神经网络深度解析:从神经元到深度学习的进化之路
人工智能·深度学习·神经网络·机器学习
烤汉堡12 小时前
Python入门到实战:post请求+cookie+代理
爬虫·python
luod12 小时前
Python异常链
python
我不是QI12 小时前
周志华《机器学习---西瓜书》 一
人工智能·python·机器学习·ai
今天没ID12 小时前
Python 编程实战:从基础语法到算法实现 (1)
python