深度学习——基于ResNet18迁移学习的图像分类模型

基于ResNet18迁移学习的20类图像分类模型实现

一、项目背景与设计目标

在深度学习视觉任务中,卷积神经网络(CNN) 已经成为图像分类、检测与识别的核心工具。然而,从零开始训练一个CNN模型往往需要数十万甚至上百万的标注样本,训练成本高昂。因此,迁移学习(Transfer Learning) 成为一种极为实用的策略。

本文采用 PyTorch 框架 ,基于 ResNet18 预训练模型,对20类食物图像进行分类训练。通过冻结卷积层、仅训练全连接层的方式,我们能够充分利用ResNet在ImageNet上的学习能力,在小数据集上快速实现高准确率。


二、环境配置与模块导入

首先导入所需的库模块:

复制代码
import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import torch.optim as optim

这些模块的功能如下:

  • torch:核心深度学习框架,提供张量计算和GPU加速。

  • torchvision.models:包含大量预训练模型,如ResNet、VGG、DenseNet等。

  • nn:神经网络构建模块。

  • Dataset/DataLoader:用于自定义数据集与批量加载。

  • transforms:图像数据增强工具。

  • PIL.Image:图像读取与处理。

  • optim:优化器模块(如Adam、SGD等)。


三、加载与修改预训练模型

迁移学习的第一步是加载一个在大型数据集(如ImageNet)上训练好的模型。代码如下:

复制代码
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

此时模型包含完整的ResNet18结构(卷积层、残差块、全连接层等),并自带预训练权重。

为了实现迁移学习,我们通常 冻结前面的卷积层权重,只微调最后的分类层:

复制代码
for param in resnet_model.parameters():
    param.requires_grad = False

这一步可以避免破坏原有的特征提取能力,从而提高小样本任务的训练稳定性与效率。

接着替换最后的全连接层:

复制代码
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Sequential(
    nn.Linear(in_features, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(512, 20)
)

修改的含义如下:

  • 输入特征数(in_features):保留ResNet最后一层输出的特征维度。

  • 中间层512神经元:增加网络非线性表达能力。

  • Dropout(0.5):防止过拟合。

  • 输出层20:对应目标数据集的20个类别。


四、训练参数配置

只训练新加入的全连接层参数:

复制代码
params_to_update = [p for p in resnet_model.parameters() if p.requires_grad]

并将模型放置于可用的计算设备(GPU、MPS或CPU)上:

复制代码
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

这种自动选择机制可在不同平台上无缝运行。


五、数据预处理与增强

数据增强(Data Augmentation)能显著提高模型的泛化性能。此处定义了训练集与验证集的不同预处理:

复制代码
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([300, 300]),
        transforms.RandomRotation(45),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomGrayscale(p=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

增强手段包括旋转、翻转、灰度化等,能让模型在不同光照、角度下都具备鲁棒性。


六、自定义数据集类

通过继承 torch.utils.data.Dataset 实现自定义数据加载:

复制代码
class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.imgs, self.labels = [], []
        with open(file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image = Image.open(self.imgs[idx])
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(self.labels[idx], dtype=np.int64))
        return image, label

该类通过 .txt 文件读取样本路径与类别标签,实现灵活的数据管理。


七、加载数据与构建迭代器

复制代码
training_data = food_dataset(file_path='./train2.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='./test2.txt', transform=data_transforms['valid'])

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

DataLoader 支持批量加载与随机打乱(shuffle),是PyTorch训练循环的核心组件。


八、训练与验证流程设计

(1)训练函数
复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

说明:

  • 模型设置为训练模式 model.train()

  • 前向传播得到预测结果。

  • 计算损失后反向传播梯度并更新参数。

(2)验证函数
复制代码
best_acc = 0
def test(dataloader, model, loss_fn):
    global best_acc
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")
    if correct > best_acc:
        best_acc = correct
        torch.save(model.state_dict(), "best_model.pth")
    return test_loss

该函数实现模型评估与 最佳模型保存 功能(当验证准确率提升时保存参数)。


九、优化器与学习率调度

使用 Adam 优化器自适应学习率调度器

复制代码
optimizer = torch.optim.Adam(params_to_update, lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5, threshold=0.00001)

说明:

  • Adam 结合动量与自适应学习率机制,训练稳定。

  • ReduceLROnPlateau:当验证集准确率长时间不提升时,自动减小学习率以细化优化。


十、完整训练流程

主循环如下:

复制代码
epochs = 10
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    loss = test(test_dataloader, model, loss_fn)
    scheduler.step(loss)
print("Done!")
print(f"Best accuracy: {(100 * best_acc):>0.1f}%")

每一轮训练后会打印验证集性能,并根据结果动态调整学习率。最终输出最高准确率。


十一、性能与改进建议

(1) 性能特点
  • 迁移学习使模型快速收敛;

  • 数据增强显著提升泛化能力;

  • 仅微调全连接层降低训练难度;

  • 自动保存最佳模型保证结果稳定。

(2) 可优化方向
  • 使用 混合精度训练(AMP) 提升GPU效率;

  • 调整 Batch Size学习率衰减策略

  • 应用 K-Fold交叉验证 提高鲁棒性;

  • 在更大数据集上解冻部分残差层进行微调。


十二、结语

本文完整展示了一个基于 ResNet18 迁移学习 的20类图像分类任务,从模型加载、参数冻结、数据增强、训练与验证流程,到优化器与学习率调度的全流程实现。

通过冻结特征提取层、仅微调分类层的设计,我们能够以极低的训练成本获得高准确率模型,体现了迁移学习在现实任务中的高效性与实用价值。

相关推荐
十三画者7 小时前
【文献分享】利用 GeneTEA 对基因描述进行自然语言处理以进行过表达分析
人工智能·自然语言处理
洞见新研社7 小时前
家庭机器人,从科幻到日常的二十年突围战
大数据·人工智能·机器人
qzhqbb7 小时前
神经网络 - 循环神经网络
人工智能·rnn·神经网络
newxtc7 小时前
【湖北政务服务网-注册_登录安全分析报告】
人工智能·selenium·测试工具·安全·政务
Oxo Security7 小时前
【AI安全】提示词注入
人工智能·安全·网络安全·ai
跳跳糖炒酸奶7 小时前
第十章、GPT1:Improving Language Understanding by Generative Pre-Training(代码部分)
人工智能·自然语言处理·大模型·transformer·gpt1
Chubxu7 小时前
从零本地跑通 Suna:一套可复刻的调试实践
人工智能
大叔_爱编程7 小时前
基于Python的历届奥运会数据可视化分析系统-django+spider
python·django·毕业设计·源码·课程设计·spider·奥运会数据可视化
小白狮ww7 小时前
模型不再是一整块!Hunyuan3D-Part 实现可控组件式 3D 生成
人工智能·深度学习·机器学习·教程·3d模型·hunyuan3d·3d创作
York·Zhang7 小时前
AI 下的 Agent 技术全览
人工智能·大模型·agent