深度学习——基于 ResNet18 的图像分类训练

PyTorch 基于 ResNet18 的图像分类训练与验证全流程解析


一、项目概述

本文实现了一个基于 PyTorch 框架的图像分类模型,使用 ResNet18 作为预训练骨干网络(Backbone),并在其基础上进行迁移学习(Transfer Learning)。整个流程涵盖了:

  • 数据预处理与增强

  • 自定义 Dataset 与 DataLoader

  • 模型微调与参数冻结

  • 训练与验证循环

  • 学习率调度策略(ReduceLROnPlateau)

该项目的核心目标是利用已有的强大视觉特征提取网络(ResNet18)对新的小规模数据集进行分类任务,从而提升训练效率与模型性能。


二、模型部分解析:ResNet18 微调机制

复制代码
import torch
import torchvision.models as models
from torch import nn, optim

resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

这里加载了 torchvision.models 中的预训练 ResNet18 模型,其权重参数来自 ImageNet 大规模数据集的训练结果。

接着,冻结网络的所有参数,防止在训练过程中被更新:

复制代码
for param in resnet_model.parameters():
    print(param)
    param.requires_grad = False

原理说明:

  • 迁移学习的关键思路在于"保留特征提取层"。

  • 早期卷积层学习的是通用特征(如边缘、纹理),可直接用于新任务。

  • 仅需微调后几层或分类头层(fc层),可显著减少训练量。

然后替换最后一层全连接层(fc层):

复制代码
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features, 20)

🔹 原始 ResNet18 的输出为 1000 类(ImageNet)。

🔹 这里改为 20 类,适配自定义数据集。

最后仅选择需要更新的参数:

复制代码
params_to_update = []
for param in resnet_model.parameters():
    if param.requires_grad == True:
        params_to_update.append(param)

这意味着优化器只会更新新加入的全连接层参数。


三、数据预处理与增强(Data Augmentation)

数据增强可提升模型泛化能力,代码中定义了两种处理策略:

复制代码
from torchvision import transforms

1. 训练集增强 data_transforms['train']

包含大量随机性增强操作:

复制代码
transforms.Compose([
    transforms.Resize([300,300]),
    transforms.RandomRotation(45),
    transforms.CenterCrop(224),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
    transforms.RandomGrayscale(p=0.1),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

📘 数据增强效果:

  • 旋转翻转灰度转换可提升模型在多视角条件下的鲁棒性。

  • 归一化操作确保输入分布与预训练模型保持一致。

2. 验证集预处理 data_transforms['valid']

复制代码
transforms.Compose([
    transforms.Resize([256,256]),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

验证集通常不使用随机增强,以保持结果的可重复性和客观性。


四、自定义 Dataset 与 DataLoader

1. Dataset 类定义

复制代码
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np

class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.file_path = file_path
        self.transform = transform
        self.imgs = []
        self.labels = []
        with open(self.file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path,label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

该类通过文本文件(如 train2.txt)加载图片路径和标签。

每一行格式为:

复制代码
image_path label

2. 数据访问接口

复制代码
def __len__(self):
    return len(self.imgs)

def __getitem__(self, index):
    image = Image.open(self.imgs[index])
    if self.transform:
        image = self.transform(image)
    label = torch.from_numpy(np.array(self.labels[index], dtype=np.int64))
    return image, label

⚙️ Dataset 必备方法:

  • __len__():返回数据集大小。

  • __getitem__():返回一条样本及其标签。

3. 数据加载器 DataLoader

复制代码
training_data = food_dataset('./train2.txt', transform=data_transforms['train'])
test_data = food_dataset('./test2.txt', transform=data_transforms['valid'])

train_dataloader = DataLoader(training_data, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=16, shuffle=True)

✅ DataLoader 的作用:

  • 自动打包 batch

  • 支持多线程加载(num_workers

  • 支持数据打乱(shuffle)


五、训练环境与优化器设置

1. 自动选择设备

复制代码
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

同时兼容:

  • NVIDIA GPU (CUDA)

  • Apple M1/M2 GPU (MPS)

  • CPU

2. 定义损失函数与优化器

复制代码
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params_to_update, lr=0.001)

使用 交叉熵损失函数 处理多分类任务,优化器为 Adam

3. 学习率调度器

复制代码
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    verbose=True
)

🔹 当验证集 Loss 连续 3 轮未改善时,学习率减半。

🔹 可有效防止过拟合与梯度振荡。


六、训练与验证循环

1. 训练函数

复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    batch_size_num = 1
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch_size_num % 100 == 0:
            print(f"loss: {loss.item():>7f} [number:{batch_size_num}]")
        batch_size_num += 1

🔁 每 100 个 batch 打印一次 loss。

2. 测试函数

复制代码
def Test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test result:\n Accuracy:{100*correct}%, Avg loss:{test_loss}")
    return test_loss

🧮 评估指标:

  • Accuracy(准确率)

  • Avg loss(平均验证损失)


七、主训练循环

复制代码
epochs = 50
for t in range(epochs):
    print(f"---------------\nepoch {t+1}")
    train(train_dataloader, model, loss_fn, optimizer)
    val_loss = Test(test_dataloader, model, loss_fn)
    scheduler.step(val_loss)
print("Done!")

共进行 50 轮训练,每轮包括:

  1. 模型训练

  2. 验证集测试

  3. 根据验证集 loss 调整学习率

💡 随着 epoch 增加,loss 应逐渐下降,accuracy 提升。


八、总结

模块 作用 特点
ResNet18 特征提取主干 使用 ImageNet 预训练权重
Dataset 读取图片与标签 支持 transform 自动增强
DataLoader 批量化输入 shuffle 提升训练效果
train() 前向传播与反向传播 更新梯度
Test() 模型评估 计算平均损失与准确率
ReduceLROnPlateau 学习率调整 自动降低学习率防止过拟合
相关推荐
独好紫罗兰20 分钟前
对python的再认识-基于数据结构进行-a006-元组-拓展
开发语言·数据结构·python
Dfreedom.22 分钟前
图像直方图完全解析:从原理到实战应用
图像处理·python·opencv·直方图·直方图均衡化
KYGALYX29 分钟前
逻辑回归详解
算法·机器学习·逻辑回归
铉铉这波能秀38 分钟前
LeetCode Hot100数据结构背景知识之集合(Set)Python2026新版
数据结构·python·算法·leetcode·哈希算法
啵啵鱼爱吃小猫咪1 小时前
机械臂能量分析
线性代数·机器学习·概率论
怒放吧德德1 小时前
Python3基础:基础实战巩固,从“会用”到“活用”
后端·python
aiguangyuan1 小时前
基于BERT的中文命名实体识别实战解析
人工智能·python·nlp
喵手1 小时前
Python爬虫实战:知识挖掘机 - 知乎问答与专栏文章的深度分页采集系统(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·采集知乎问答与专栏文章·采集知乎数据·采集知乎数据存储sqlite
铉铉这波能秀1 小时前
LeetCode Hot100数据结构背景知识之元组(Tuple)Python2026新版
数据结构·python·算法·leetcode·元组·tuple
kali-Myon1 小时前
2025春秋杯网络安全联赛冬季赛-day2
python·安全·web安全·ai·php·pwn·ctf