深度学习——基于ResNet18迁移学习的图像分类模型

基于ResNet18迁移学习的20类图像分类模型实现

一、项目背景与设计目标

在深度学习视觉任务中,卷积神经网络(CNN) 已经成为图像分类、检测与识别的核心工具。然而,从零开始训练一个CNN模型往往需要数十万甚至上百万的标注样本,训练成本高昂。因此,迁移学习(Transfer Learning) 成为一种极为实用的策略。

本文采用 PyTorch 框架 ,基于 ResNet18 预训练模型,对20类食物图像进行分类训练。通过冻结卷积层、仅训练全连接层的方式,我们能够充分利用ResNet在ImageNet上的学习能力,在小数据集上快速实现高准确率。


二、环境配置与模块导入

首先导入所需的库模块:

复制代码
import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import torch.optim as optim

这些模块的功能如下:

  • torch:核心深度学习框架,提供张量计算和GPU加速。

  • torchvision.models:包含大量预训练模型,如ResNet、VGG、DenseNet等。

  • nn:神经网络构建模块。

  • Dataset/DataLoader:用于自定义数据集与批量加载。

  • transforms:图像数据增强工具。

  • PIL.Image:图像读取与处理。

  • optim:优化器模块(如Adam、SGD等)。


三、加载与修改预训练模型

迁移学习的第一步是加载一个在大型数据集(如ImageNet)上训练好的模型。代码如下:

复制代码
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

此时模型包含完整的ResNet18结构(卷积层、残差块、全连接层等),并自带预训练权重。

为了实现迁移学习,我们通常 冻结前面的卷积层权重,只微调最后的分类层:

复制代码
for param in resnet_model.parameters():
    param.requires_grad = False

这一步可以避免破坏原有的特征提取能力,从而提高小样本任务的训练稳定性与效率。

接着替换最后的全连接层:

复制代码
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Sequential(
    nn.Linear(in_features, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(512, 20)
)

修改的含义如下:

  • 输入特征数(in_features):保留ResNet最后一层输出的特征维度。

  • 中间层512神经元:增加网络非线性表达能力。

  • Dropout(0.5):防止过拟合。

  • 输出层20:对应目标数据集的20个类别。


四、训练参数配置

只训练新加入的全连接层参数:

复制代码
params_to_update = [p for p in resnet_model.parameters() if p.requires_grad]

并将模型放置于可用的计算设备(GPU、MPS或CPU)上:

复制代码
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

这种自动选择机制可在不同平台上无缝运行。


五、数据预处理与增强

数据增强(Data Augmentation)能显著提高模型的泛化性能。此处定义了训练集与验证集的不同预处理:

复制代码
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([300, 300]),
        transforms.RandomRotation(45),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomGrayscale(p=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

增强手段包括旋转、翻转、灰度化等,能让模型在不同光照、角度下都具备鲁棒性。


六、自定义数据集类

通过继承 torch.utils.data.Dataset 实现自定义数据加载:

复制代码
class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.imgs, self.labels = [], []
        with open(file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image = Image.open(self.imgs[idx])
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(self.labels[idx], dtype=np.int64))
        return image, label

该类通过 .txt 文件读取样本路径与类别标签,实现灵活的数据管理。


七、加载数据与构建迭代器

复制代码
training_data = food_dataset(file_path='./train2.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='./test2.txt', transform=data_transforms['valid'])

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

DataLoader 支持批量加载与随机打乱(shuffle),是PyTorch训练循环的核心组件。


八、训练与验证流程设计

(1)训练函数
复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

说明:

  • 模型设置为训练模式 model.train()

  • 前向传播得到预测结果。

  • 计算损失后反向传播梯度并更新参数。

(2)验证函数
复制代码
best_acc = 0
def test(dataloader, model, loss_fn):
    global best_acc
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")
    if correct > best_acc:
        best_acc = correct
        torch.save(model.state_dict(), "best_model.pth")
    return test_loss

该函数实现模型评估与 最佳模型保存 功能(当验证准确率提升时保存参数)。


九、优化器与学习率调度

使用 Adam 优化器自适应学习率调度器

复制代码
optimizer = torch.optim.Adam(params_to_update, lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5, threshold=0.00001)

说明:

  • Adam 结合动量与自适应学习率机制,训练稳定。

  • ReduceLROnPlateau:当验证集准确率长时间不提升时,自动减小学习率以细化优化。


十、完整训练流程

主循环如下:

复制代码
epochs = 10
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    loss = test(test_dataloader, model, loss_fn)
    scheduler.step(loss)
print("Done!")
print(f"Best accuracy: {(100 * best_acc):>0.1f}%")

每一轮训练后会打印验证集性能,并根据结果动态调整学习率。最终输出最高准确率。


十一、性能与改进建议

(1) 性能特点
  • 迁移学习使模型快速收敛;

  • 数据增强显著提升泛化能力;

  • 仅微调全连接层降低训练难度;

  • 自动保存最佳模型保证结果稳定。

(2) 可优化方向
  • 使用 混合精度训练(AMP) 提升GPU效率;

  • 调整 Batch Size学习率衰减策略

  • 应用 K-Fold交叉验证 提高鲁棒性;

  • 在更大数据集上解冻部分残差层进行微调。


十二、结语

本文完整展示了一个基于 ResNet18 迁移学习 的20类图像分类任务,从模型加载、参数冻结、数据增强、训练与验证流程,到优化器与学习率调度的全流程实现。

通过冻结特征提取层、仅微调分类层的设计,我们能够以极低的训练成本获得高准确率模型,体现了迁移学习在现实任务中的高效性与实用价值。

相关推荐
小钱c73 小时前
Python使用 pandas操作Excel文件并新增列数据
python·excel·pandas
sunkl_3 小时前
JoyAgent问数多表关联Bug修复
人工智能·自然语言处理
AI数据皮皮侠4 小时前
中国博物馆数据
大数据·人工智能·python·深度学习·机器学习
强哥之神4 小时前
从零理解 KV Cache:大语言模型推理加速的核心机制
人工智能·深度学习·机器学习·语言模型·llm·kvcache
中达瑞和-高光谱·多光谱4 小时前
多光谱图像颜色特征用于茶叶分类的研究进展
人工智能·分类·数据挖掘
格林威4 小时前
UV 紫外相机在半导体制造领域的应用
人工智能·数码相机·opencv·计算机视觉·视觉检测·制造·uv
wu_jing_sheng04 小时前
Python中使用HTTP 206状态码实现大文件下载的完整指南
开发语言·前端·python
精英的英4 小时前
【工具开发】适用于交叉编译环境的QT qmake项目转换vscode项目插件
人工智能·vscode·qt·开源软件