基于PyTorch实现食物图像分类:从数据加载到CNN训练全流程

在计算机视觉领域,图像分类是经典且基础的任务,而食物图像分类因应用场景广泛(如饮食推荐、营养分析等)成为热门方向。本文将基于PyTorch框架,完整实现从自定义数据集加载、数据预处理,到CNN模型构建与训练的食物图像分类全流程。

一、项目背景与技术栈

1. 核心需求

实现20类食物图像的分类,输入为食物图片(3通道,256×256尺寸),输出为图片对应的食物类别。

2. 技术栈

• 数据处理:PIL(图像读取)、torch.utils.data(Dataset/DataLoader)

• 模型构建:PyTorch nn模块(卷积层、池化层、全连接层)

• 优化器与损失函数:Adam优化器、交叉熵损失函数

二、数据预处理与数据集构建

1. 数据集文件准备

首先需要将食物图像的路径和对应标签整理为文本文件(train.txt/test.txt),每行格式为图片路径 类别标签,例如:
./food_dataset/train/薯条/img_薯条_35.jpeg 0
./food_dataset/train/汉堡/img_汉堡_12.jpeg 1
可通过如下代码批量生成该文件(遍历数据集目录,自动标注类别):

python 复制代码
import os

def generate_label_file(root, dir, output_file):
    file_txt = open(output_file, 'w', encoding='utf-8')
    path = os.path.join(root, dir)
    # 获取所有子目录(对应不同食物类别)
    dirs = [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]
    for class_idx, class_name in enumerate(dirs):
        class_path = os.path.join(path, class_name)
        for file in os.listdir(class_path):
            if file.endswith(('.jpeg', '.png', '.jpg')):
                img_path = os.path.join(class_path, file)
                file_txt.write(f"{img_path} {class_idx}\n")
    file_txt.close()

# 生成训练/测试集标签文件
root = r'food_dataset'
generate_label_file(root, 'train', 'train.txt')
generate_label_file(root, 'test', 'test.txt')
2. 自定义Dataset类

继承PyTorch的Dataset类,实现自定义数据集加载逻辑,核心是重写__len__和__getitem__方法:

python 复制代码
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
from torchvision import transforms

# 数据预处理:统一尺寸+转Tensor
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize([256, 256]),
        transforms.ToTensor(),
    ]),
    'test': transforms.Compose([
        transforms.Resize([256, 256]),
        transforms.ToTensor(),
    ]),
}

class FoodDataset(Dataset):
    def __init__(self, file_path, transform=None):
        self.file_path = file_path
        self.imgs = []
        self.labels = []
        self.transform = transform
        
        # 读取标签文件,分离图片路径和标签
        with open(self.file_path, 'r', encoding='utf-8') as f:
            samples = [x.strip().split(' ') for x in f.readlines()]
            for img_path, label in samples:
                self.imgs.append(img_path)
                self.labels.append(label)

    def __len__(self):
        # 返回数据集总长度
        return len(self.imgs)
    
    def __getitem__(self, idx):
        # 按索引读取单张图片和标签
        image = Image.open(self.imgs[idx]).convert('RGB')  # 确保3通道
        if self.transform:
            image = self.transform(image)
        
        # 标签转为tensor(int64类型适配CrossEntropyLoss)
        label = torch.from_numpy(np.array(self.labels[idx], dtype=np.int64))
        return image, label

# 实例化数据集
train_dataset = FoodDataset(file_path='./train.txt', transform=data_transforms['train'])
test_dataset = FoodDataset(file_path='./test.txt', transform=data_transforms['test'])
3. 构建DataLoader

DataLoader负责批量加载数据、打乱顺序,提升训练效率:

python 复制代码
from torch.utils.data import DataLoader

# 批量大小64,训练集打乱,测试集可选打乱
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)

三、CNN模型构建

针对256×256的3通道图片,设计三层卷积+池化的CNN模型,最终通过全连接层输出20类分类结果:

python 复制代码
import torch.nn as nn

class FoodCNN(nn.Module):
    def __init__(self):
        super(FoodCNN, self).__init__()
        # 输入尺寸:(3, 256, 256)
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),  # 激活函数
            nn.MaxPool2d(kernel_size=2)  # 池化后尺寸:(16, 128, 128)
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2)  # 池化后尺寸:(32, 64, 64)
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(32, 128, 5, 1, 2),
            nn.ReLU()  # 输出尺寸:(128, 64, 64)
        )
        # 全连接层:展平后维度=128*64*64,输出20类
        self.fc = nn.Linear(128 * 64 * 64, 20)

    def forward(self, x):
        # 前向传播
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        # 展平:(batch_size, 128*64*64)
        x = x.view(x.size(0), -1)
        output = self.fc(x)
        return output

# 设备选择:优先使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FoodCNN().to(device)
print(model)  # 打印模型结构

四、模型训练与验证

1. 定义损失函数和优化器
python 复制代码
loss_fn = nn.CrossEntropyLoss()  # 交叉熵损失适配分类任务
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  # Adam优化器
2. 训练函数
python 复制代码
def train(dataloader, model, loss_fn, optimizer):
    model.train()  # 训练模式(启用Dropout/BatchNorm)
    total_loss = 0
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)
        
        # 前向传播
        pred = model(X)
        loss = loss_fn(pred, y)
        
        # 反向传播+优化
        optimizer.zero_grad()  # 清空梯度
        loss.backward()  # 梯度回传
        optimizer.step()  # 更新参数
        
        total_loss += loss.item()
        # 每1个批次打印损失
        if batch % 1 == 0:
            print(f"Batch {batch+1} | Loss: {loss.item():>7f}")
    avg_loss = total_loss / len(dataloader)
    print(f"Train Avg Loss: {avg_loss:>7f}")
3. 验证函数
python 复制代码
def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()  # 验证模式(禁用Dropout/BatchNorm)
    test_loss, correct = 0, 0
    with torch.no_grad():  # 禁用梯度计算,提升速度
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            # 计算准确率(取预测概率最大的类别)
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    # 计算平均损失和准确率
    test_loss /= num_batches
    correct /= size
    print(f"Test Result: \n Accuracy: {(100*correct):>0.1f}%, Avg Loss: {test_loss:>7f}\n")
4. 执行训练
python 复制代码
epochs = 10
for t in range(epochs):
    print(f"Epoch {t+1}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    test(test_dataloader, model, loss_fn)
print("Training Done!")

五、关键知识点解析

1. __getitem__的作用

自定义Dataset必须实现__getitem__,它支持通过索引(如dataset[0])获取单条数据,是DataLoader批量加载的基础,本文中该方法负责读取图片、应用预处理、转换标签格式。

2. PIL库的应用

PIL(Pillow)是Python图像处理核心库,本文中用于:

• Image.open():读取本地图片文件;

• convert('RGB'):确保图片为3通道(避免灰度图导致通道数不匹配);

• 配合torchvision.transforms完成尺寸调整、格式转换。

3. CNN维度计算

卷积层输出尺寸公式:(输入尺寸 - 卷积核尺寸 + 2*padding) / stride + 1
池化层输出尺寸公式:输入尺寸 / 池化核尺寸
本文中256×256图片经两次2×2池化后变为64×64,最终展平维度为128*64*64,与全连接层输入匹配。

六、优化方向

  1. 数据增强:添加随机裁剪、翻转、亮度调整等(transforms.RandomCrop/RandomHorizontalFlip),提升模型泛化能力;

  2. 正则化:加入Dropout层(nn.Dropout(0.5))、L2正则化,防止过拟合;

  3. 学习率调度:使用torch.optim.lr_scheduler动态调整学习率;

  4. 模型轻量化:采用MobileNet、ResNet等预训练模型迁移学习,减少训练成本。

总结

本文完整实现了基于PyTorch的食物图像分类流程,从数据集构建、数据加载,到CNN模型设计、训练验证,覆盖了计算机视觉分类任务的核心环节。通过自定义Dataset适配自有数据集,结合CNN的特征提取能力,最终实现20类食物的分类,也为其他图像分类任务提供了可复用的模板。

相关推荐
盼小辉丶3 小时前
PyTorch实战(35)——使用PyTorch Profiler分析模型推理性能
人工智能·pytorch·深度学习
Dxy12393102164 小时前
深度学习的优雅降温:PyTorch中CosineAnnealingLR的终极指南
人工智能·pytorch·深度学习
研究点啥好呢4 小时前
百度 人工智能工程师面试题精选
人工智能·pytorch·神经网络·百度·ai·面试·文心一言
行走__Wz13 小时前
【刘二大人】《PyTorch深度学习实践》——PyTorch实现线性回归代码(自用)
pytorch·深度学习·线性回归
查无此人byebye15 小时前
【保姆级教程】从零实现模块化Transformer对话生成模型(PyTorch完整代码)
pytorch·深度学习·transformer
fawubio_A17 小时前
毕业设计 深度学习卷积神经网络垃圾分类系统
python·cnn·毕业设计·毕设
红茶川17 小时前
[ExecuTorch 系列] 2. 导出官方支持的大语言模型
人工智能·pytorch·ai·端侧ai
shy^-^cky18 小时前
TensorFlow、PyTorch、PaddlePaddle 三大深度学习框架全维度对比表
pytorch·深度学习·tensorflow·paddlepaddle·飞桨
ZTLJQ18 小时前
深入理解CNN:卷积神经网络的原理与实战应用
人工智能·神经网络·cnn