PyTorch 基于 ResNet18 的图像分类训练与验证全流程解析
一、项目概述
本文实现了一个基于 PyTorch 框架的图像分类模型,使用 ResNet18 作为预训练骨干网络(Backbone),并在其基础上进行迁移学习(Transfer Learning)。整个流程涵盖了:
-
数据预处理与增强
-
自定义 Dataset 与 DataLoader
-
模型微调与参数冻结
-
训练与验证循环
-
学习率调度策略(ReduceLROnPlateau)
该项目的核心目标是利用已有的强大视觉特征提取网络(ResNet18)对新的小规模数据集进行分类任务,从而提升训练效率与模型性能。
二、模型部分解析:ResNet18 微调机制
import torch
import torchvision.models as models
from torch import nn, optim
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
这里加载了 torchvision.models 中的预训练 ResNet18 模型,其权重参数来自 ImageNet 大规模数据集的训练结果。
接着,冻结网络的所有参数,防止在训练过程中被更新:
for param in resnet_model.parameters():
print(param)
param.requires_grad = False
✅ 原理说明:
迁移学习的关键思路在于"保留特征提取层"。
早期卷积层学习的是通用特征(如边缘、纹理),可直接用于新任务。
仅需微调后几层或分类头层(fc层),可显著减少训练量。
然后替换最后一层全连接层(fc层):
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Linear(in_features, 20)
🔹 原始 ResNet18 的输出为 1000 类(ImageNet)。
🔹 这里改为 20 类,适配自定义数据集。
最后仅选择需要更新的参数:
params_to_update = []
for param in resnet_model.parameters():
if param.requires_grad == True:
params_to_update.append(param)
这意味着优化器只会更新新加入的全连接层参数。
三、数据预处理与增强(Data Augmentation)
数据增强可提升模型泛化能力,代码中定义了两种处理策略:
from torchvision import transforms
1. 训练集增强 data_transforms['train']
包含大量随机性增强操作:
transforms.Compose([
transforms.Resize([300,300]),
transforms.RandomRotation(45),
transforms.CenterCrop(224),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomVerticalFlip(p=0.5),
transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),
transforms.RandomGrayscale(p=0.1),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
📘 数据增强效果:
旋转 、翻转 、灰度转换可提升模型在多视角条件下的鲁棒性。
归一化操作确保输入分布与预训练模型保持一致。
2. 验证集预处理 data_transforms['valid']
transforms.Compose([
transforms.Resize([256,256]),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406],
[0.229, 0.224, 0.225])
])
验证集通常不使用随机增强,以保持结果的可重复性和客观性。
四、自定义 Dataset 与 DataLoader
1. Dataset 类定义
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
class food_dataset(Dataset):
def __init__(self, file_path, transform=None):
self.file_path = file_path
self.transform = transform
self.imgs = []
self.labels = []
with open(self.file_path) as f:
samples = [x.strip().split(' ') for x in f.readlines()]
for img_path,label in samples:
self.imgs.append(img_path)
self.labels.append(label)
该类通过文本文件(如 train2.txt
)加载图片路径和标签。
每一行格式为:
image_path label
2. 数据访问接口
def __len__(self):
return len(self.imgs)
def __getitem__(self, index):
image = Image.open(self.imgs[index])
if self.transform:
image = self.transform(image)
label = torch.from_numpy(np.array(self.labels[index], dtype=np.int64))
return image, label
⚙️ Dataset 必备方法:
__len__()
:返回数据集大小。
__getitem__()
:返回一条样本及其标签。
3. 数据加载器 DataLoader
training_data = food_dataset('./train2.txt', transform=data_transforms['train'])
test_data = food_dataset('./test2.txt', transform=data_transforms['valid'])
train_dataloader = DataLoader(training_data, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=16, shuffle=True)
✅ DataLoader 的作用:
自动打包 batch
支持多线程加载(
num_workers
)支持数据打乱(shuffle)
五、训练环境与优化器设置
1. 自动选择设备
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")
同时兼容:
-
NVIDIA GPU (CUDA)
-
Apple M1/M2 GPU (MPS)
-
CPU
2. 定义损失函数与优化器
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params_to_update, lr=0.001)
使用 交叉熵损失函数 处理多分类任务,优化器为 Adam。
3. 学习率调度器
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
optimizer,
mode='min',
factor=0.5,
patience=3,
verbose=True
)
🔹 当验证集 Loss 连续 3 轮未改善时,学习率减半。
🔹 可有效防止过拟合与梯度振荡。
六、训练与验证循环
1. 训练函数
def train(dataloader, model, loss_fn, optimizer):
model.train()
batch_size_num = 1
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
loss = loss_fn(pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if batch_size_num % 100 == 0:
print(f"loss: {loss.item():>7f} [number:{batch_size_num}]")
batch_size_num += 1
🔁 每 100 个 batch 打印一次 loss。
2. 测试函数
def Test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval()
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test result:\n Accuracy:{100*correct}%, Avg loss:{test_loss}")
return test_loss
🧮 评估指标:
Accuracy(准确率)
Avg loss(平均验证损失)
七、主训练循环
epochs = 50
for t in range(epochs):
print(f"---------------\nepoch {t+1}")
train(train_dataloader, model, loss_fn, optimizer)
val_loss = Test(test_dataloader, model, loss_fn)
scheduler.step(val_loss)
print("Done!")
共进行 50 轮训练,每轮包括:
-
模型训练
-
验证集测试
-
根据验证集 loss 调整学习率
💡 随着 epoch 增加,loss 应逐渐下降,accuracy 提升。
八、总结
模块 | 作用 | 特点 |
---|---|---|
ResNet18 | 特征提取主干 | 使用 ImageNet 预训练权重 |
Dataset | 读取图片与标签 | 支持 transform 自动增强 |
DataLoader | 批量化输入 | shuffle 提升训练效果 |
train() | 前向传播与反向传播 | 更新梯度 |
Test() | 模型评估 | 计算平均损失与准确率 |
ReduceLROnPlateau | 学习率调整 | 自动降低学习率防止过拟合 |