深度学习——基于ResNet18迁移学习的图像分类模型

基于ResNet18迁移学习的20类图像分类模型实现

一、项目背景与设计目标

在深度学习视觉任务中,卷积神经网络(CNN) 已经成为图像分类、检测与识别的核心工具。然而,从零开始训练一个CNN模型往往需要数十万甚至上百万的标注样本,训练成本高昂。因此,迁移学习(Transfer Learning) 成为一种极为实用的策略。

本文采用 PyTorch 框架 ,基于 ResNet18 预训练模型,对20类食物图像进行分类训练。通过冻结卷积层、仅训练全连接层的方式,我们能够充分利用ResNet在ImageNet上的学习能力,在小数据集上快速实现高准确率。


二、环境配置与模块导入

首先导入所需的库模块:

复制代码
import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import torch.optim as optim

这些模块的功能如下:

  • torch:核心深度学习框架,提供张量计算和GPU加速。

  • torchvision.models:包含大量预训练模型,如ResNet、VGG、DenseNet等。

  • nn:神经网络构建模块。

  • Dataset/DataLoader:用于自定义数据集与批量加载。

  • transforms:图像数据增强工具。

  • PIL.Image:图像读取与处理。

  • optim:优化器模块(如Adam、SGD等)。


三、加载与修改预训练模型

迁移学习的第一步是加载一个在大型数据集(如ImageNet)上训练好的模型。代码如下:

复制代码
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

此时模型包含完整的ResNet18结构(卷积层、残差块、全连接层等),并自带预训练权重。

为了实现迁移学习,我们通常 冻结前面的卷积层权重,只微调最后的分类层:

复制代码
for param in resnet_model.parameters():
    param.requires_grad = False

这一步可以避免破坏原有的特征提取能力,从而提高小样本任务的训练稳定性与效率。

接着替换最后的全连接层:

复制代码
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Sequential(
    nn.Linear(in_features, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(512, 20)
)

修改的含义如下:

  • 输入特征数(in_features):保留ResNet最后一层输出的特征维度。

  • 中间层512神经元:增加网络非线性表达能力。

  • Dropout(0.5):防止过拟合。

  • 输出层20:对应目标数据集的20个类别。


四、训练参数配置

只训练新加入的全连接层参数:

复制代码
params_to_update = [p for p in resnet_model.parameters() if p.requires_grad]

并将模型放置于可用的计算设备(GPU、MPS或CPU)上:

复制代码
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

这种自动选择机制可在不同平台上无缝运行。


五、数据预处理与增强

数据增强(Data Augmentation)能显著提高模型的泛化性能。此处定义了训练集与验证集的不同预处理:

复制代码
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([300, 300]),
        transforms.RandomRotation(45),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomGrayscale(p=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

增强手段包括旋转、翻转、灰度化等,能让模型在不同光照、角度下都具备鲁棒性。


六、自定义数据集类

通过继承 torch.utils.data.Dataset 实现自定义数据加载:

复制代码
class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.imgs, self.labels = [], []
        with open(file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image = Image.open(self.imgs[idx])
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(self.labels[idx], dtype=np.int64))
        return image, label

该类通过 .txt 文件读取样本路径与类别标签,实现灵活的数据管理。


七、加载数据与构建迭代器

复制代码
training_data = food_dataset(file_path='./train2.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='./test2.txt', transform=data_transforms['valid'])

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

DataLoader 支持批量加载与随机打乱(shuffle),是PyTorch训练循环的核心组件。


八、训练与验证流程设计

(1)训练函数
复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

说明:

  • 模型设置为训练模式 model.train()

  • 前向传播得到预测结果。

  • 计算损失后反向传播梯度并更新参数。

(2)验证函数
复制代码
best_acc = 0
def test(dataloader, model, loss_fn):
    global best_acc
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")
    if correct > best_acc:
        best_acc = correct
        torch.save(model.state_dict(), "best_model.pth")
    return test_loss

该函数实现模型评估与 最佳模型保存 功能(当验证准确率提升时保存参数)。


九、优化器与学习率调度

使用 Adam 优化器自适应学习率调度器

复制代码
optimizer = torch.optim.Adam(params_to_update, lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5, threshold=0.00001)

说明:

  • Adam 结合动量与自适应学习率机制,训练稳定。

  • ReduceLROnPlateau:当验证集准确率长时间不提升时,自动减小学习率以细化优化。


十、完整训练流程

主循环如下:

复制代码
epochs = 10
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    loss = test(test_dataloader, model, loss_fn)
    scheduler.step(loss)
print("Done!")
print(f"Best accuracy: {(100 * best_acc):>0.1f}%")

每一轮训练后会打印验证集性能,并根据结果动态调整学习率。最终输出最高准确率。


十一、性能与改进建议

(1) 性能特点
  • 迁移学习使模型快速收敛;

  • 数据增强显著提升泛化能力;

  • 仅微调全连接层降低训练难度;

  • 自动保存最佳模型保证结果稳定。

(2) 可优化方向
  • 使用 混合精度训练(AMP) 提升GPU效率;

  • 调整 Batch Size学习率衰减策略

  • 应用 K-Fold交叉验证 提高鲁棒性;

  • 在更大数据集上解冻部分残差层进行微调。


十二、结语

本文完整展示了一个基于 ResNet18 迁移学习 的20类图像分类任务,从模型加载、参数冻结、数据增强、训练与验证流程,到优化器与学习率调度的全流程实现。

通过冻结特征提取层、仅微调分类层的设计,我们能够以极低的训练成本获得高准确率模型,体现了迁移学习在现实任务中的高效性与实用价值。

相关推荐
F_D_Z1 分钟前
【解决办法】网络训练报错AttributeError: module ‘jax.core‘ has no attribute ‘Shape‘.
开发语言·python·jax
Python私教5 分钟前
别让 API Key 裸奔:基于 TRAE SOLO 的大模型安全配置最佳实践
人工智能
Python私教7 分钟前
Vibe Coding 体验报告:我让 TRAE SOLO 替我重构了 2000 行屎山代码,结果...
人工智能
prog_61038 分钟前
【笔记】和各大AI语言模型写项目——手搓SDN后得到的经验
人工智能·笔记·语言模型
前端伪大叔14 分钟前
第29篇:99% 的量化新手死在挂单上:Freqtrade 隐藏技能揭秘
后端·python·github
zhangfeng113314 分钟前
深入剖析Kimi K2 Thinking与其他大规模语言模型(Large Language Models, LLMs)之间的差异
人工智能·语言模型·自然语言处理
paopao_wu31 分钟前
人脸检测与识别-InsightFace:特征向量提取与识别
人工智能·目标检测
Aevget43 分钟前
MyEclipse全新发布v2025.2——AI + Java 24 +更快的调试
java·ide·人工智能·eclipse·myeclipse
IT_陈寒1 小时前
React 18并发渲染实战:5个核心API让你的应用性能飙升50%
前端·人工智能·后端
韩曙亮1 小时前
【人工智能】AI 人工智能 技术 学习路径分析 ① ( Python语言 -> 微积分 / 概率论 / 线性代数 -> 机器学习 )
人工智能·python·学习·数学·机器学习·ai·微积分