基于PyTorch的食品图像分类:数据增强与调优实战

在计算机视觉领域,图像分类是基础且核心的任务,而食品图像分类因食材种类多、外观相似度高,对模型的鲁棒性提出了更高要求。本文将以20类食品分类任务为例,分享如何通过数据增强、学习率调整、最优模型保存与加载等技巧,基于PyTorch搭建CNN模型并持续提升分类准确率。

一、项目背景与基础框架

1. 任务目标

构建一个CNN模型,实现20类食品图像的精准分类,核心挑战在于:

• 食品图像存在角度、光照、缩放等差异;

• 训练过程中易出现过拟合、学习率不合适导致收敛慢/不收敛;

• 需保证模型在测试集上的泛化能力。

2. 基础环境与数据准备

• 框架:PyTorch(兼顾灵活性与易用性);

• 数据格式:train.txt/test.txt存储图像路径与对应标签(每行格式:图像路径 类别标签);

• 设备:优先使用GPU(CUDA/MPS)加速训练,兜底使用CPU。

3. 基础Dataset与DataLoader构建

自定义food_dataset类加载数据,核心逻辑是读取图像路径与标签,对图像应用预处理,并转换为PyTorch张量:

python 复制代码
class food_dataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.file_path = file_path
        self.imgs = []
        self.labels = []
        self.transform = transform
        # 解析txt文件,收集图像路径与标签
        with open(self.file_path) as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

    def __len__(self):
        return len(self.imgs)
    
    def __getitem__(self, idx):
        # 读取图像并应用预处理
        image = Image.open(self.imgs[idx])
        if self.transform:
            image = self.transform(image)
        # 标签转换为int64张量(适配CrossEntropyLoss)
        label = torch.from_numpy(np.array(self.labels[idx], dtype=np.int64))
        return image, label

二、数据增强:解决过拟合与数据多样性

1. 为什么需要数据增强?

训练集数据量有限时,模型易记住训练样本特征(过拟合),数据增强通过对训练图像施加随机变换,生成"新样本",提升模型泛化能力。

2. 训练集与验证集差异化增强策略

• 训练集:施加随机旋转、翻转、色彩抖动等增强,模拟真实场景的图像变化;

• 验证/测试集:仅做缩放、标准化,保证数据一致性。

python 复制代码
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([300,300]),       # 缩放至300×300
        transforms.RandomRotation(45),      # 随机旋转±45°
        transforms.CenterCrop(256),         # 中心裁剪至256×256
        transforms.RandomHorizontalFlip(p=0.5),  # 随机水平翻转(50%概率)
        transforms.RandomVerticalFlip(p=0.5),    # 随机垂直翻转(50%概率)
        transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1), # 色彩抖动
        transforms.RandomGrayscale(p=0.1),  # 随机灰度化(10%概率)
        transforms.ToTensor(),              # 转换为张量(C×H×W,0-1归一化)
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]) # 基于ImageNet的标准化
    ]),
    'valid': transforms.Compose([
        transforms.Resize([256,256]),       # 仅缩放至256×256
        transforms.ToTensor(),
        transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225])
    ]),
}

无数据增强与有数据增强对比

关键细节

• 先缩放再裁剪:避免直接裁剪导致图像信息丢失;

• 色彩抖动参数适度:避免过度变换导致标签失真;

• 标准化使用ImageNet均值/方差:利用预训练的统计特征,加速收敛。

三、CNN模型搭建:兼顾特征提取与计算效率

设计轻量级CNN模型,分为3个卷积块+1个全连接层,核心是逐步提升通道数、缩小特征图尺寸:

python 复制代码
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # 卷积块1:提取低级特征(边缘、纹理)
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),  # 3→32通道,5×5卷积,padding=2保证尺寸不变
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)  # 池化后尺寸减半:256→128
        )
        # 卷积块2:提取中级特征(局部形状)
        self.conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 5, 1, 2), # 32→64通道
            nn.ReLU(),
            nn.Conv2d(64, 64, 5, 1, 2), # 加深卷积,强化特征
            nn.ReLU(),
            nn.MaxPool2d(2)              # 尺寸再减半:128→64
        )
        # 卷积块3:提取高级特征(物体部件)
        self.conv3 = nn.Sequential(
            nn.Conv2d(64, 128, 5, 1, 2),# 64→128通道
            nn.ReLU()
        )
        # 全连接层:映射到20类分类结果
        self.out = nn.Linear(128*64*64, 20)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.view(x.size(0), -1)  # 展平:(batch_size, 128×64×64)
        output = self.out(x)
        return output

四、核心调优技巧:学习率调整

固定学习率易出现"前期收敛慢、后期震荡不收敛",本文使用StepLR调度器,每10个epoch将学习率减半:

1. 优化器与调度器初始化
python 复制代码
# 基础优化器:Adam(自适应学习率,收敛更快)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# 学习率调度器:step_size=10(每10轮调整),gamma=0.5(学习率×0.5)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
2. 训练循环中更新学习率
python 复制代码
epochs = 100
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
    scheduler.step()  # 每个epoch结束后调整学习率

调优逻辑:

• 初始学习率0.001:保证前期快速收敛;

• 每10轮减半:后期减小学习率,精细调整参数,避免越过最优解。

五、最优模型保存与加载:避免训练白费

训练过程中模型准确率会波动,仅保存测试集准确率最高的模型参数,而非最终模型:

1. 保存最优模型
python 复制代码
best_acc = 0  # 初始化最优准确率
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    # 计算平均损失与准确率
    test_loss /= num_batches
    correct /= size
    print(f"Test result: \n Accuracy: {(100 * correct):.2f}%, Avg loss: {test_loss:.4f}")
    
    # 保存最优模型参数
    global best_acc
    if correct > best_acc:
        best_acc = correct
        # 仅保存参数(state_dict),节省空间且兼容性更好
        torch.save(model.state_dict(), 'best.pth')
2. 加载最优模型进行推理
python 复制代码
# 1. 初始化模型并加载参数
model = CNN().to(device)
# 加载参数(兼容不同设备,weights_only=True保证安全)
model.load_state_dict(torch.load('best.pth', map_location=device, weights_only=True))
model.eval()  # 切换至评估模式(禁用Dropout/BatchNorm等)

# 2. 推理验证
def test_ture(dataloader, model):
    result = []  # 预测标签
    labels = []  # 真实标签
    with torch.no_grad():  # 禁用梯度计算,加速推理
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            result.extend(pred.argmax(1).tolist())
            labels.extend(y.tolist())
    # 计算准确率
    correct = sum(p == l for p, l in zip(result, labels))
    accuracy = correct / len(labels) * 100
    return accuracy

accuracy = test_ture(test_dataloader, model)
print(f'最终测试准确率:{accuracy:.2f}%')

六、效果提升总结

优化手段 核心作用 准确率提升效果(示例)
数据增强 降低过拟合,提升泛化能力 +8%~12%
StepLR 学习率调整 加速收敛,后期精细调参 +3%~5%
保存最优模型 避免训练后期震荡导致的准确率下降 +2%~3%

七、进阶优化思路

  1. 数据增强升级:引入Albumentations库,添加随机裁剪、高斯噪声、透视变换等更丰富的增强;

  2. 模型改进:替换为ResNet、MobileNet等预训练模型(迁移学习),利用预训练特征进一步提升准确率;

  3. 学习率调度升级:使用CosineAnnealingLR(余弦退火)、ReduceLROnPlateau(按需调整);

  4. 正则化:添加Dropout层、L2正则化(weight_decay),进一步抑制过拟合。

八、总结

本文以食品图像分类为例,完整展示了从数据加载、数据增强、模型搭建,到学习率调优、最优模型保存与加载的全流程。核心思路是:通过数据增强扩大有效训练集,通过学习率调度优化收敛过程,通过保存最优模型锁定最佳效果。这些技巧不仅适用于食品分类,也可迁移至花卉、车辆、场景等各类图像分类任务,是提升CNN模型性能的通用方法论。

相关推荐
岁岁种桃花儿2 小时前
AI超级智能开发系列从入门到上天第十篇:SpringAI+云知识库服务
linux·运维·数据库·人工智能·oracle·llm
小马_xiaoen2 小时前
2026 AI 开发新风向:Skills 安装量 Top 10 深度解析
人工智能·skill
Surmon2 小时前
AI 代替不了这样的你
人工智能·ai编程
ZGi.ai2 小时前
一个 LLM 网关需要做哪些事? 多模型统一接入的工程设计
人工智能
FL16238631292 小时前
C#版winform实现FaceFusion人脸替换
人工智能
爱敲点代码的小哥2 小时前
Halcon图像处理:筛选、降噪与增强全解析
人工智能
小陈工2 小时前
2026年3月24日技术资讯洞察:边缘AI商业化,Java26正式发布与开源大模型成本革命
java·运维·开发语言·人工智能·python·容器·开源
云蝠呼叫大模型联络中心2 小时前
医疗智能客服系统架构设计与云蝠VoiceAgent API集成实践
人工智能·系统架构·api·医疗·voiceagent·ai 客服选型·智能客服 2026
工具箱大集合2 小时前
英语课件PPT免费模板2026实测优选清单
人工智能·ppt