基于 ResNet18 的迁移学习:食物图像分类实现

基于ResNet18的迁移学习实现食物图像分类

在计算机视觉领域,图像分类是经典任务之一,而面对特定领域的分类需求(如食物分类),从头训练深度神经网络不仅耗时耗力,还需要大量的标注数据。迁移学习作为一种高效的建模方法,能够将预训练模型在大规模数据集上学到的特征提取能力迁移到新任务中,大幅降低训练成本并提升模型效果。本文将以ResNet18为预训练模型,手把手教大家实现食物图像的20分类任务,全程使用PyTorch框架完成代码编写与模型训练。

一、迁移学习核心思路

迁移学习的核心是复用预训练模型的特征提取层,仅训练适配新任务的分类层

  1. 选用在ImageNet数据集上预训练的ResNet18模型,其卷积层等底层结构已能提取通用的图像特征(如边缘、纹理、形状等),这些特征对食物图像同样适用。

  2. 冻结预训练模型的所有特征提取层参数,避免训练时破坏已学到的通用特征。

  3. 替换ResNet18的最后全连接层(fc层),将原有的1000分类输出改为20分类(适配食物分类任务)。

  4. 仅训练新替换的全连接层参数,同时使用数据增强提升模型泛化能力,最终完成食物图像分类模型的训练。

二、环境准备

本次实验基于Python+PyTorch框架,需要安装以下核心依赖库:

复制代码

pip install torch torchvision pillow numpy

  • torch/torchvision:PyTorch核心框架,提供预训练模型、数据处理工具和神经网络模块。

  • pillow:Python图像处理库,用于读取和处理图像。

  • numpy:数值计算库,用于数据类型转换等操作。

同时确保电脑具备GPU(NVIDIA CUDA或Apple MPS)加速能力,大幅提升训练速度。

三、完整代码实现与详解

接下来将分模块讲解代码,从预训练模型加载、数据处理、数据集构建到模型训练与评估,实现端到端的食物分类模型开发。

3.1 导入核心库

首先导入实验所需的所有Python库,涵盖模型、数据处理、神经网络层等模块:

复制代码

import torch import torchvision.models as models # 包含各类预训练视觉模型 from torch import nn from torch.utils.data import Dataset, DataLoader # 自定义数据集和数据加载器 from torchvision import transforms # 图像变换与数据增强 from PIL import Image import numpy as np

3.2 加载预训练模型并改造

加载ResNet18预训练模型,冻结特征层参数,替换分类层以适配食物20分类任务:

复制代码

# 加载预训练ResNet18模型,使用默认预训练权重 resnet_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT) # 冻结所有特征层参数,禁止梯度更新 for param in resnet_model.parameters(): param.requires_grad = False # 获取原fc层的输入特征数,替换为20分类的全连接层 in_features = resnet_model.fc.in_features resnet_model.fc = nn.Linear(in_features, 20) # 20为食物分类的类别数

关键说明

  • weights=models.ResNet18_Weights.DEFAULT:加载官方在ImageNet上的预训练权重,替代旧版的pretrained=True(PyTorch新版本推荐用法)。

  • param.requires_grad = False:冻结参数,让这些层在训练时不更新梯度,仅保留特征提取能力。

  • 替换fc层:ResNet18的最后一层是全连接层,原输出为1000类(ImageNet),此处改为20类,仅该层参数参与后续训练。

3.3 图像变换与数据增强

针对训练集和验证集设计不同的图像变换策略,训练集使用数据增强提升泛化能力,验证集仅做基础变换保证数据一致性

复制代码

data_transforms = { 'train': transforms.Compose([ transforms.Resize([300, 300]), # 缩放图像至300*300 transforms.RandomRotation(45), # 随机旋转(-45,45)度 transforms.CenterCrop(224), # 中心裁剪至224*224(ResNet输入尺寸) transforms.RandomHorizontalFlip(p=0.5), # 50%概率水平翻转 transforms.RandomVerticalFlip(p=0.5), # 50%概率垂直翻转 transforms.RandomGrayscale(p=0.1), # 10%概率转为灰度图 transforms.ToTensor(), # 转为Tensor,像素值归一化至[0,1] # 按ImageNet均值和标准差归一化,与预训练模型保持一致 transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]), 'valid': transforms.Compose([ transforms.Resize([224, 224]), # 直接缩放至224*224 transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]), }

数据增强的意义:通过随机旋转、翻转、裁剪等操作,生成更多"虚拟训练样本",避免模型过拟合,提升对不同角度、尺度食物图像的识别能力。

归一化说明:必须使用ImageNet的均值和标准差,因为预训练模型是在该归一化规则下训练的,保证特征提取的一致性。

3.4 自定义食物数据集

PyTorch的Dataset类是自定义数据集的基础,此处实现读取食物图像路径和标签的自定义数据集,适配txt格式的样本清单(每行:图像路径 标签):

复制代码

class food_dataset(Dataset): def __init__(self, file_path, transform=None): self.file_path = file_path # 样本清单txt文件路径 self.imgs = [] # 存储所有图像路径 self.labels = [] # 存储所有图像标签 self.transform = transform # 图像变换策略 # 读取txt文件,解析图像路径和标签 with open(self.file_path) as f: samples = [x.strip().split(' ') for x in f.readlines()] for img_path, label in samples: self.imgs.append(img_path) self.labels.append(label) # 必须实现:返回数据集样本总数 def __len__(self): return len(self.imgs) # 必须实现:根据索引返回单个样本(图像+标签) def __getitem__(self, idx): # 读取图像,Pillow默认读取为RGB格式 image = Image.open(self.imgs[idx]) # 应用图像变换 if self.transform: image = self.transform(image) # 标签转换为64位整数Tensor,适配PyTorch损失函数 label = torch.from_numpy(np.array(self.labels[idx], dtype=np.int64)) return image, label

数据集格式要求 :需准备trainda.txt(训练集)和testda.txt(验证集),每行格式为xxx/xxx/food.jpg 0,其中0为类别标签(0-19)。

3.5 创建数据加载器

通过DataLoader将自定义数据集封装为批量迭代器,实现批量读取、随机打乱、多进程加载(PyTorch自动实现):

复制代码

# 创建训练集和验证集 training_data = food_dataset(file_path='./trainda.txt', transform=data_transforms['train']) test_data = food_dataset(file_path='./testda.txt', transform=data_transforms['valid']) # 创建数据加载器,batch_size=64表示每次读取64个样本 train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True) test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

参数说明

  • batch_size=64:根据GPU显存调整,显存小则调小(如32、16)。

  • shuffle=True:训练集每次迭代前打乱样本顺序,避免模型学习样本顺序规律;验证集打乱仅为方便,不影响结果。

3.6 设备配置与模型初始化

自动检测并使用GPU(CUDA/MPS),将模型移至指定设备,同时定义损失函数、优化器和学习率调度器:

复制代码

# 自动选择训练设备:CUDA > MPS > CPU device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" print(f"Using {device} device") # 将模型移至指定设备 model = resnet_model.to(device) # 定义损失函数:交叉熵损失(适用于多分类任务) loss_fn = nn.CrossEntropyLoss() # 定义优化器:Adam优化器,仅优化可训练参数(此处为新fc层) optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # 学习率调度器:每训练10个epoch,学习率乘以0.5 scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

关键说明

  • 交叉熵损失:PyTorch的nn.CrossEntropyLoss已包含Softmax层,无需在模型最后手动添加。

  • Adam优化器:自适应学习率优化器,收敛速度快于SGD,适合迁移学习的小批量训练。

  • 学习率调度器:训练后期降低学习率,让模型在最优解附近收敛,提升精度。

3.7 定义训练和验证函数

训练函数

实现模型的单次epoch训练,包括前向传播、损失计算、反向传播、参数更新

复制代码

def train(dataloader, model, loss_fn, optimizer): model.train() # 将模型设为训练模式(启用Dropout/BatchNorm等层的训练特性) for X, y in dataloader: # 将数据移至指定设备 X, y = X.to(device), y.to(device) # 前向传播:获取模型预测结果 pred = model.forward(X) # 计算损失 loss = loss_fn(pred, y) # 梯度清零:避免上一批次梯度累积 optimizer.zero_grad() # 反向传播:计算梯度 loss.backward() # 优化器更新参数:仅更新可训练的fc层参数 optimizer.step()

验证函数

实现模型的单次epoch验证,关闭梯度计算以提升速度,计算验证集的准确率和平均损失,并保存最优模型的准确率:

复制代码

best_acc = 0 # 保存最优验证准确率 acc_s = [] # 记录每个epoch的验证准确率 loss_s = [] # 记录每个epoch的验证损失 def test(dataloader, model, loss_fn): global best_acc size = len(dataloader.dataset) # 验证集总样本数 num_batches = len(dataloader) # 验证集总批次数 model.eval() # 将模型设为评估模式(关闭Dropout/BatchNorm等层的训练特性) test_loss, correct = 0, 0 # 关闭梯度计算,节省显存并提升速度 with torch.no_grad(): for X, y in dataloader: X, y = X.to(device), y.to(device) pred = model.forward(X) # 累加损失和正确预测数 test_loss += loss_fn(pred, y).item() # 取预测概率最大的类别作为预测结果,统计正确数 correct += (pred.argmax(1) == y).type(torch.float).sum().item() # 计算平均损失和准确率 test_loss /= num_batches correct /= size print(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss:.4f}") # 记录准确率和损失 acc_s.append(correct) loss_s.append(test_loss) # 更新最优准确率 if correct > best_acc: best_acc = correct

3.8 模型训练与结果输出

设置训练轮数(epochs),循环执行训练和验证,每轮训练后更新学习率,最终输出最优验证准确率:

复制代码

epochs = 10 # 训练轮数,可根据效果调整 for t in range(epochs): print(f"Epoch {t+1}\n-------------------------------") train(train_dataloader, model, loss_fn, optimizer) scheduler.step() # 每轮训练后更新学习率 test(test_dataloader, model, loss_fn) print("Training done!") print(f"最优验证准确率为:{best_acc * 100:.2f}%")

四、关键优化点与注意事项

  1. 参数冻结:必须冻结预训练模型的特征层参数,否则训练时会覆盖已学到的通用特征,不仅训练速度慢,还容易过拟合。

  2. 输入尺寸:ResNet系列模型的标准输入尺寸为224*224,需保证最终输入图像的尺寸符合要求。

  3. 归一化规则:必须使用ImageNet的均值和标准差,与预训练模型的训练环境保持一致,否则模型特征提取能力会大幅下降。

  4. 数据格式 :标签必须转换为64位整数(np.int64),否则会与PyTorch的交叉熵损失函数数据类型不兼容。

  5. 模型模式 :训练时用model.train(),验证时用model.eval(),避免Dropout和BatchNorm层影响验证结果。

  6. 梯度清零 :每次批量训练前必须执行optimizer.zero_grad(),否则梯度会累积,导致参数更新错误。

五、模型改进方向

本文实现的基础版本已能完成食物分类任务,若想进一步提升模型准确率和泛化能力,可尝试以下改进策略:

  1. 微调(Fine-tuning):冻结部分特征层(如仅冻结前几层),让后几层特征层与分类层一起训练,适配食物图像的专属特征。

  2. 增加数据增强 :添加ColorJitter(颜色抖动)、RandomCrop(随机裁剪)等操作,进一步丰富训练样本。

  3. 调整超参数 :优化batch_size、学习率(如初始lr设为0.0001)、训练轮数,或更换优化器(如SGD+动量)。

  4. 使用更大的预训练模型:如ResNet50、ResNet101,提升特征提取能力(注意显存占用)。

  5. 添加早停(Early Stopping):当验证集损失连续多轮不下降时,停止训练,避免过拟合。

  6. 模型保存:在验证函数中添加模型保存代码,保存最优准确率对应的模型权重,方便后续推理使用:

    复制代码

    if correct > best_acc: best_acc = correct torch.save(model.state_dict(), './best_food_model.pth') # 保存最优模型

  7. 可视化结果:使用matplotlib绘制训练过程中的准确率和损失曲线,直观分析模型训练趋势。

六、总结

本文以ResNet18为预训练模型,通过迁移学习快速实现了食物图像20分类任务,核心是复用预训练特征、仅训练分类层,大幅降低了模型训练的成本。整个过程涵盖了PyTorch中自定义数据集、数据增强、模型改造、训练与验证的全流程,是迁移学习在计算机视觉领域的典型应用。

迁移学习不仅适用于食物分类,还可推广到花卉、车辆、医疗图像等各类特定领域的图像分类任务,只需根据任务需求调整分类层的类别数和数据集即可。掌握这一方法,能让我们在面对新的计算机视觉任务时,快速搭建高性能的模型,无需从头开始训练。

后续可基于本文的基础代码,尝试模型改进策略,进一步提升分类准确率,并将训练好的模型部署到实际应用中(如食物识别APP、智能点餐系统等)。

相关推荐
海上_数字船长2 小时前
LTN 学习机制解析:基于知识库满足度的符号学习与泛化
人工智能
阿里云大数据AI技术2 小时前
Qwen3.6-Plus on PAI-DSW:云端 AI 开发,一站搞定
人工智能
格林威2 小时前
SSD 写入速度测试命令(Linux)(基于工业相机高速存储)
linux·运维·开发语言·人工智能·数码相机·计算机视觉·工业相机
Hilaku2 小时前
OpenClaw 跟病毒的区别是什么?
前端·javascript·人工智能
逻辑君2 小时前
认知神经科学研究报告【20260008】
人工智能·深度学习·神经网络·机器学习
GIS数据转换器2 小时前
延凡智慧水务系统:引领行业变革的智能引擎
大数据·人工智能·无人机·智慧城市
行者无疆_ty3 小时前
小龙虾(OpenClaw)安装教程
人工智能·agent·openclaw·小龙虾
2601_949539453 小时前
家用新能源 SUV 核心技术科普:后排娱乐、空间工程与混动可靠性解析
大数据·网络·人工智能·算法·机器学习
北邮刘老师3 小时前
暗数据:智能体探索世界的下一步
人工智能·大模型·prompt·智能体·智能体互联网