深度学习——基于ResNet18迁移学习的图像分类模型

基于ResNet18迁移学习的20类图像分类模型实现

一、项目背景与设计目标

在深度学习视觉任务中,卷积神经网络(CNN) 已经成为图像分类、检测与识别的核心工具。然而,从零开始训练一个CNN模型往往需要数十万甚至上百万的标注样本,训练成本高昂。因此,迁移学习(Transfer Learning) 成为一种极为实用的策略。

本文采用 PyTorch 框架 ,基于 ResNet18 预训练模型,对20类食物图像进行分类训练。通过冻结卷积层、仅训练全连接层的方式,我们能够充分利用ResNet在ImageNet上的学习能力,在小数据集上快速实现高准确率。


二、环境配置与模块导入

首先导入所需的库模块:

复制代码
import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
import torch.optim as optim

这些模块的功能如下:

  • torch:核心深度学习框架,提供张量计算和GPU加速。

  • torchvision.models:包含大量预训练模型,如ResNet、VGG、DenseNet等。

  • nn:神经网络构建模块。

  • Dataset/DataLoader:用于自定义数据集与批量加载。

  • transforms:图像数据增强工具。

  • PIL.Image:图像读取与处理。

  • optim:优化器模块(如Adam、SGD等)。


三、加载与修改预训练模型

迁移学习的第一步是加载一个在大型数据集(如ImageNet)上训练好的模型。代码如下:

复制代码
resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)

此时模型包含完整的ResNet18结构(卷积层、残差块、全连接层等),并自带预训练权重。

为了实现迁移学习,我们通常 冻结前面的卷积层权重,只微调最后的分类层:

复制代码
for param in resnet_model.parameters():
    param.requires_grad = False

这一步可以避免破坏原有的特征提取能力,从而提高小样本任务的训练稳定性与效率。

接着替换最后的全连接层:

复制代码
in_features = resnet_model.fc.in_features
resnet_model.fc = nn.Sequential(
    nn.Linear(in_features, 512),
    nn.ReLU(inplace=True),
    nn.Dropout(0.5),
    nn.Linear(512, 20)
)

修改的含义如下:

  • 输入特征数(in_features):保留ResNet最后一层输出的特征维度。

  • 中间层512神经元:增加网络非线性表达能力。

  • Dropout(0.5):防止过拟合。

  • 输出层20:对应目标数据集的20个类别。


四、训练参数配置

只训练新加入的全连接层参数:

复制代码
params_to_update = [p for p in resnet_model.parameters() if p.requires_grad]

并将模型放置于可用的计算设备(GPU、MPS或CPU)上:

复制代码
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device")

这种自动选择机制可在不同平台上无缝运行。


五、数据预处理与增强

数据增强(Data Augmentation)能显著提高模型的泛化性能。此处定义了训练集与验证集的不同预处理:

复制代码
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([300, 300]),
        transforms.RandomRotation(45),
        transforms.CenterCrop(224),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomGrayscale(p=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'valid': transforms.Compose([
        transforms.Resize([224, 224]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

增强手段包括旋转、翻转、灰度化等,能让模型在不同光照、角度下都具备鲁棒性。


六、自定义数据集类

通过继承 torch.utils.data.Dataset 实现自定义数据加载:

复制代码
class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.imgs, self.labels = [], []
        with open(file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)
        self.transform = transform

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        image = Image.open(self.imgs[idx])
        if self.transform:
            image = self.transform(image)
        label = torch.from_numpy(np.array(self.labels[idx], dtype=np.int64))
        return image, label

该类通过 .txt 文件读取样本路径与类别标签,实现灵活的数据管理。


七、加载数据与构建迭代器

复制代码
training_data = food_dataset(file_path='./train2.txt', transform=data_transforms['train'])
test_data = food_dataset(file_path='./test2.txt', transform=data_transforms['valid'])

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

DataLoader 支持批量加载与随机打乱(shuffle),是PyTorch训练循环的核心组件。


八、训练与验证流程设计

(1)训练函数
复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        pred = model(X)
        loss = loss_fn(pred, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

说明:

  • 模型设置为训练模式 model.train()

  • 前向传播得到预测结果。

  • 计算损失后反向传播梯度并更新参数。

(2)验证函数
复制代码
best_acc = 0
def test(dataloader, model, loss_fn):
    global best_acc
    model.eval()
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")
    if correct > best_acc:
        best_acc = correct
        torch.save(model.state_dict(), "best_model.pth")
    return test_loss

该函数实现模型评估与 最佳模型保存 功能(当验证准确率提升时保存参数)。


九、优化器与学习率调度

使用 Adam 优化器自适应学习率调度器

复制代码
optimizer = torch.optim.Adam(params_to_update, lr=0.001)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', factor=0.5, patience=5, threshold=0.00001)

说明:

  • Adam 结合动量与自适应学习率机制,训练稳定。

  • ReduceLROnPlateau:当验证集准确率长时间不提升时,自动减小学习率以细化优化。


十、完整训练流程

主循环如下:

复制代码
epochs = 10
for t in range(epochs):
    print(f"Epoch {t + 1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    loss = test(test_dataloader, model, loss_fn)
    scheduler.step(loss)
print("Done!")
print(f"Best accuracy: {(100 * best_acc):>0.1f}%")

每一轮训练后会打印验证集性能,并根据结果动态调整学习率。最终输出最高准确率。


十一、性能与改进建议

(1) 性能特点
  • 迁移学习使模型快速收敛;

  • 数据增强显著提升泛化能力;

  • 仅微调全连接层降低训练难度;

  • 自动保存最佳模型保证结果稳定。

(2) 可优化方向
  • 使用 混合精度训练(AMP) 提升GPU效率;

  • 调整 Batch Size学习率衰减策略

  • 应用 K-Fold交叉验证 提高鲁棒性;

  • 在更大数据集上解冻部分残差层进行微调。


十二、结语

本文完整展示了一个基于 ResNet18 迁移学习 的20类图像分类任务,从模型加载、参数冻结、数据增强、训练与验证流程,到优化器与学习率调度的全流程实现。

通过冻结特征提取层、仅微调分类层的设计,我们能够以极低的训练成本获得高准确率模型,体现了迁移学习在现实任务中的高效性与实用价值。

相关推荐
TDengine (老段)3 小时前
TDengine IDMP 工业数据建模 —— 数据情景化
大数据·数据库·人工智能·时序数据库·iot·tdengine·涛思数据
Omics Pro3 小时前
端到端单细胞空间组学数据分析
大数据·数据库·人工智能·算法·数据挖掘·数据分析·aigc
zzb15803 小时前
Agent记忆与检索
java·人工智能·python·学习·ai
这张生成的图像能检测吗3 小时前
(论文速读)MoECLIP:零射异常检测补丁专家
人工智能·深度学习·计算机视觉·异常检测·clip·zero-shot方法
TOSUN同星3 小时前
研发周期缩短、成本压力大?同星云平台用“数字孪生+AI”重构研发模式
人工智能·重构
wzl202612133 小时前
从0到1搭建私域数据中台——公域引流的数据采集与分析
python·自动化·企业微信
Deepoch3 小时前
Deepoc具身模型:让智能轮椅从“避障”转向“预判”
人工智能·科技·开发板·具身模型·deepoc
源码之家3 小时前
大数据毕业设计汽车推荐系统 Django框架 可视化 协同过滤算法 数据分析 大数据 机器学习(建议收藏)✅
大数据·python·算法·django·汽车·课程设计·美食
TopDawn3 小时前
自然语言处理
人工智能·自然语言处理
HealthScience3 小时前
COM Surrogate的dllhost.exe高占用(磁盘)解决
python