深度学习——基于ResNet18迁移学习的图像分类模型

基于ResNet18迁移学习的20类图像分类模型实现

一、项目背景与设计目标

在深度学习视觉任务中,卷积神经网络(CNN) 已经成为图像分类、检测与识别的核心工具。然而,从零开始训练一个CNN模型往往需要数十万甚至上百万的标注样本,训练成本高昂。因此,迁移学习(Transfer Learning) 成为一种极为实用的策略。

本文采用 PyTorch 框架 ,基于 ResNet18 预训练模型,对20类食物图像进行分类训练。通过冻结卷积层、仅训练全连接层的方式,我们能够充分利用ResNet在ImageNet上的学习能力,在小数据集上快速实现高准确率。


二、环境配置与模块导入

首先导入所需的库模块:

复制代码
import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import torch.optim as optim

这些模块的功能如下:

  • torch:核心深度学习框架,提供张量计算和GPU加速。

  • torchvision.models:包含大量预训练模型,如ResNet、VGG、DenseNet等。

  • nn:神经网络构建模块。

  • Dataset/DataLoader:用于自定义数据集与批量加载。

  • transforms:图像数据增强工具。

  • PIL.Image:图像读取与处理。

  • optim:优化器模块(如Adam、SGD等)。


三、加载与修改预训练模型

迁移学习的第一步是加载一个在大型数据集(如ImageNet)上训练好的模型。代码如下:

复制代码
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

此时模型包含完整的ResNet18结构(卷积层、残差块、全连接层等),并自带预训练权重。

为了实现迁移学习,我们通常 冻结前面的卷积层权重,只微调最后的分类层:

复制代码
for param in resnet_model.parameters():
    param.requires_grad = False

这一步可以避免破坏原有的特征提取能力,从而提高小样本任务的训练稳定性与效率。

接着替换最后的全连接层:

复制代码
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Sequential(
    nn.Linear(in_features, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(512, 20)
)

修改的含义如下:

  • 输入特征数(in_features):保留ResNet最后一层输出的特征维度。

  • 中间层512神经元:增加网络非线性表达能力。

  • Dropout(0.5):防止过拟合。

  • 输出层20:对应目标数据集的20个类别。


四、训练参数配置

只训练新加入的全连接层参数:

复制代码
params_to_update = [p for p in resnet_model.parameters() if p.requires_grad]

并将模型放置于可用的计算设备(GPU、MPS或CPU)上:

复制代码
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

这种自动选择机制可在不同平台上无缝运行。


五、数据预处理与增强

数据增强(Data Augmentation)能显著提高模型的泛化性能。此处定义了训练集与验证集的不同预处理:

复制代码
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([300, 300]),
        transforms.RandomRotation(45),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomGrayscale(p=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

增强手段包括旋转、翻转、灰度化等,能让模型在不同光照、角度下都具备鲁棒性。


六、自定义数据集类

通过继承 torch.utils.data.Dataset 实现自定义数据加载:

复制代码
class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.imgs, self.labels = [], []
        with open(file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image = Image.open(self.imgs[idx])
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(self.labels[idx], dtype=np.int64))
        return image, label

该类通过 .txt 文件读取样本路径与类别标签,实现灵活的数据管理。


七、加载数据与构建迭代器

复制代码
training_data = food_dataset(file_path='./train2.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='./test2.txt', transform=data_transforms['valid'])

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

DataLoader 支持批量加载与随机打乱(shuffle),是PyTorch训练循环的核心组件。


八、训练与验证流程设计

(1)训练函数
复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

说明:

  • 模型设置为训练模式 model.train()

  • 前向传播得到预测结果。

  • 计算损失后反向传播梯度并更新参数。

(2)验证函数
复制代码
best_acc = 0
def test(dataloader, model, loss_fn):
    global best_acc
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")
    if correct > best_acc:
        best_acc = correct
        torch.save(model.state_dict(), "best_model.pth")
    return test_loss

该函数实现模型评估与 最佳模型保存 功能(当验证准确率提升时保存参数)。


九、优化器与学习率调度

使用 Adam 优化器自适应学习率调度器

复制代码
optimizer = torch.optim.Adam(params_to_update, lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5, threshold=0.00001)

说明:

  • Adam 结合动量与自适应学习率机制,训练稳定。

  • ReduceLROnPlateau:当验证集准确率长时间不提升时,自动减小学习率以细化优化。


十、完整训练流程

主循环如下:

复制代码
epochs = 10
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    loss = test(test_dataloader, model, loss_fn)
    scheduler.step(loss)
print("Done!")
print(f"Best accuracy: {(100 * best_acc):>0.1f}%")

每一轮训练后会打印验证集性能,并根据结果动态调整学习率。最终输出最高准确率。


十一、性能与改进建议

(1) 性能特点
  • 迁移学习使模型快速收敛;

  • 数据增强显著提升泛化能力;

  • 仅微调全连接层降低训练难度;

  • 自动保存最佳模型保证结果稳定。

(2) 可优化方向
  • 使用 混合精度训练(AMP) 提升GPU效率;

  • 调整 Batch Size学习率衰减策略

  • 应用 K-Fold交叉验证 提高鲁棒性;

  • 在更大数据集上解冻部分残差层进行微调。


十二、结语

本文完整展示了一个基于 ResNet18 迁移学习 的20类图像分类任务,从模型加载、参数冻结、数据增强、训练与验证流程,到优化器与学习率调度的全流程实现。

通过冻结特征提取层、仅微调分类层的设计,我们能够以极低的训练成本获得高准确率模型,体现了迁移学习在现实任务中的高效性与实用价值。

相关推荐
feasibility.1 分钟前
反爬十层妖塔:现代爬虫攻防的立体战争
爬虫·python·科技·scrapy·rust·go·硬件
还在忙碌的吴小二3 分钟前
今日AI行业热点新闻
人工智能
十八旬10 分钟前
快速安装ClaudeCode完整指南
开发语言·windows·python·claude
Bode_200211 分钟前
AIoT 技术难点
人工智能·制造
deming_su44 分钟前
AI产品架构师核心理论知识点文档
人工智能
XD7429716361 小时前
科技晚报|2026年5月13日:AI 开始补全库审查、移动入口和弹性调度
人工智能·科技·开发者工具·科技晚报
dFObBIMmai1 小时前
如何在 CSS 中实现元素的绝对定位,使其不受窗口尺寸变化影响
jvm·数据库·python
卷Java1 小时前
2026年4月AI军备竞赛全景:DeepSeek V4 vs GPT-5.5 vs Gemini vs Claude
人工智能·gpt·大模型
人月神话-Lee1 小时前
【图像处理】亮度与对比度——图像的线性变换
图像处理·人工智能·ios·ai编程·swift
WL_Aurora1 小时前
Python 算法基础篇之集合
python·算法