在计算机视觉领域,图像分类是经典且基础的任务,而食物图像分类因应用场景广泛(如饮食推荐、营养分析等)成为热门方向。本文将基于PyTorch框架,完整实现从自定义数据集加载、数据预处理,到CNN模型构建与训练的食物图像分类全流程。
一、项目背景与技术栈
1. 核心需求
实现20类食物图像的分类,输入为食物图片(3通道,256×256尺寸),输出为图片对应的食物类别。
2. 技术栈
• 数据处理:PIL(图像读取)、torch.utils.data(Dataset/DataLoader)
• 模型构建:PyTorch nn模块(卷积层、池化层、全连接层)
• 优化器与损失函数:Adam优化器、交叉熵损失函数
二、数据预处理与数据集构建
1. 数据集文件准备
首先需要将食物图像的路径和对应标签整理为文本文件(train.txt/test.txt),每行格式为图片路径 类别标签,例如:
./food_dataset/train/薯条/img_薯条_35.jpeg 0
./food_dataset/train/汉堡/img_汉堡_12.jpeg 1
可通过如下代码批量生成该文件(遍历数据集目录,自动标注类别):
python
import os
def generate_label_file(root, dir, output_file):
file_txt = open(output_file, 'w', encoding='utf-8')
path = os.path.join(root, dir)
# 获取所有子目录(对应不同食物类别)
dirs = [d for d in os.listdir(path) if os.path.isdir(os.path.join(path, d))]
for class_idx, class_name in enumerate(dirs):
class_path = os.path.join(path, class_name)
for file in os.listdir(class_path):
if file.endswith(('.jpeg', '.png', '.jpg')):
img_path = os.path.join(class_path, file)
file_txt.write(f"{img_path} {class_idx}\n")
file_txt.close()
# 生成训练/测试集标签文件
root = r'food_dataset'
generate_label_file(root, 'train', 'train.txt')
generate_label_file(root, 'test', 'test.txt')
2. 自定义Dataset类
继承PyTorch的Dataset类,实现自定义数据集加载逻辑,核心是重写__len__和__getitem__方法:
python
import torch
from torch.utils.data import Dataset
import numpy as np
from PIL import Image
from torchvision import transforms
# 数据预处理:统一尺寸+转Tensor
data_transforms = {
'train': transforms.Compose([
transforms.Resize([256, 256]),
transforms.ToTensor(),
]),
'test': transforms.Compose([
transforms.Resize([256, 256]),
transforms.ToTensor(),
]),
}
class FoodDataset(Dataset):
def __init__(self, file_path, transform=None):
self.file_path = file_path
self.imgs = []
self.labels = []
self.transform = transform
# 读取标签文件,分离图片路径和标签
with open(self.file_path, 'r', encoding='utf-8') as f:
samples = [x.strip().split(' ') for x in f.readlines()]
for img_path, label in samples:
self.imgs.append(img_path)
self.labels.append(label)
def __len__(self):
# 返回数据集总长度
return len(self.imgs)
def __getitem__(self, idx):
# 按索引读取单张图片和标签
image = Image.open(self.imgs[idx]).convert('RGB') # 确保3通道
if self.transform:
image = self.transform(image)
# 标签转为tensor(int64类型适配CrossEntropyLoss)
label = torch.from_numpy(np.array(self.labels[idx], dtype=np.int64))
return image, label
# 实例化数据集
train_dataset = FoodDataset(file_path='./train.txt', transform=data_transforms['train'])
test_dataset = FoodDataset(file_path='./test.txt', transform=data_transforms['test'])
3. 构建DataLoader
DataLoader负责批量加载数据、打乱顺序,提升训练效率:
python
from torch.utils.data import DataLoader
# 批量大小64,训练集打乱,测试集可选打乱
train_dataloader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=64, shuffle=True)
三、CNN模型构建
针对256×256的3通道图片,设计三层卷积+池化的CNN模型,最终通过全连接层输出20类分类结果:
python
import torch.nn as nn
class FoodCNN(nn.Module):
def __init__(self):
super(FoodCNN, self).__init__()
# 输入尺寸:(3, 256, 256)
self.conv1 = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=2),
nn.ReLU(), # 激活函数
nn.MaxPool2d(kernel_size=2) # 池化后尺寸:(16, 128, 128)
)
self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 5, 1, 2),
nn.ReLU(),
nn.Conv2d(32, 32, 5, 1, 2),
nn.ReLU(),
nn.MaxPool2d(2) # 池化后尺寸:(32, 64, 64)
)
self.conv3 = nn.Sequential(
nn.Conv2d(32, 128, 5, 1, 2),
nn.ReLU() # 输出尺寸:(128, 64, 64)
)
# 全连接层:展平后维度=128*64*64,输出20类
self.fc = nn.Linear(128 * 64 * 64, 20)
def forward(self, x):
# 前向传播
x = self.conv1(x)
x = self.conv2(x)
x = self.conv3(x)
# 展平:(batch_size, 128*64*64)
x = x.view(x.size(0), -1)
output = self.fc(x)
return output
# 设备选择:优先使用GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = FoodCNN().to(device)
print(model) # 打印模型结构
四、模型训练与验证
1. 定义损失函数和优化器
python
loss_fn = nn.CrossEntropyLoss() # 交叉熵损失适配分类任务
optimizer = torch.optim.Adam(model.parameters(), lr=0.001) # Adam优化器
2. 训练函数
python
def train(dataloader, model, loss_fn, optimizer):
model.train() # 训练模式(启用Dropout/BatchNorm)
total_loss = 0
for batch, (X, y) in enumerate(dataloader):
X, y = X.to(device), y.to(device)
# 前向传播
pred = model(X)
loss = loss_fn(pred, y)
# 反向传播+优化
optimizer.zero_grad() # 清空梯度
loss.backward() # 梯度回传
optimizer.step() # 更新参数
total_loss += loss.item()
# 每1个批次打印损失
if batch % 1 == 0:
print(f"Batch {batch+1} | Loss: {loss.item():>7f}")
avg_loss = total_loss / len(dataloader)
print(f"Train Avg Loss: {avg_loss:>7f}")
3. 验证函数
python
def test(dataloader, model, loss_fn):
size = len(dataloader.dataset)
num_batches = len(dataloader)
model.eval() # 验证模式(禁用Dropout/BatchNorm)
test_loss, correct = 0, 0
with torch.no_grad(): # 禁用梯度计算,提升速度
for X, y in dataloader:
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_fn(pred, y).item()
# 计算准确率(取预测概率最大的类别)
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
# 计算平均损失和准确率
test_loss /= num_batches
correct /= size
print(f"Test Result: \n Accuracy: {(100*correct):>0.1f}%, Avg Loss: {test_loss:>7f}\n")
4. 执行训练
python
epochs = 10
for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
train(train_dataloader, model, loss_fn, optimizer)
test(test_dataloader, model, loss_fn)
print("Training Done!")
五、关键知识点解析
1. __getitem__的作用
自定义Dataset必须实现__getitem__,它支持通过索引(如dataset[0])获取单条数据,是DataLoader批量加载的基础,本文中该方法负责读取图片、应用预处理、转换标签格式。
2. PIL库的应用
PIL(Pillow)是Python图像处理核心库,本文中用于:
• Image.open():读取本地图片文件;
• convert('RGB'):确保图片为3通道(避免灰度图导致通道数不匹配);
• 配合torchvision.transforms完成尺寸调整、格式转换。
3. CNN维度计算
卷积层输出尺寸公式:(输入尺寸 - 卷积核尺寸 + 2*padding) / stride + 1
池化层输出尺寸公式:输入尺寸 / 池化核尺寸
本文中256×256图片经两次2×2池化后变为64×64,最终展平维度为128*64*64,与全连接层输入匹配。
六、优化方向
-
数据增强:添加随机裁剪、翻转、亮度调整等(transforms.RandomCrop/RandomHorizontalFlip),提升模型泛化能力;
-
正则化:加入Dropout层(nn.Dropout(0.5))、L2正则化,防止过拟合;
-
学习率调度:使用torch.optim.lr_scheduler动态调整学习率;
-
模型轻量化:采用MobileNet、ResNet等预训练模型迁移学习,减少训练成本。
总结
本文完整实现了基于PyTorch的食物图像分类流程,从数据集构建、数据加载,到CNN模型设计、训练验证,覆盖了计算机视觉分类任务的核心环节。通过自定义Dataset适配自有数据集,结合CNN的特征提取能力,最终实现20类食物的分类,也为其他图像分类任务提供了可复用的模板。


