一、搭建卷积神经网络

利用搜集的食物图片数据集实现卷积神经网络的图像识别。


对上面的文件内容进行处理,使之生成一个trainda.txt和testda.txt文件
python
import os
def train_test_file(root,dir_name):
file_txt = open(dir_name+'da.txt','w')
path = os.path.join(root,dir_name)
for roots,directories, files in os.walk(path):#os.list_dir()
if len(directories) != 0:
dirs= directories
else:
now_dir = roots.split('\\')
for file in files:
path_1 = os.path.join(roots,file)
print(path_1)
file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')
file_txt.close()
root =r'.\food_dataset2'
train_dir = 'train'
test_dir = 'test'
train_test_file(root,train_dir)
train_test_file(root,test_dir)
生成的文件内容如下,保存着食物图片的路径和标签
二、数据预处理Dataset

三、数据增强
数据增强(Data Augmentation) 是一种在机器学习和深度学习领域常用的技术,主要用于增加训练数据的数量和多样性,通过在不改变数据标签的前提下,对原始数据进行一系列变换或扩展来生成新的训练样本。
核心目的:
-
解决数据不足问题:当训练数据量较少时,模型容易过拟合(即过度适应训练集,泛化能力差),数据增强可以"凭空"创造更多样本。
-
提升模型泛化能力:通过对数据添加扰动(如旋转、缩放、噪声等),让模型学习到更鲁棒的特征,提高对未见数据的适应能力。
-
增强数据多样性:模拟真实场景中的可能变化(如光照变化、物体位置变化等),使模型更贴近实际应用。
常见的数据增强方法(以图像数据为例):
1. 几何变换
-
旋转:将图像按一定角度旋转。
-
翻转:水平或垂直翻转图像。
-
缩放:随机放大或缩小图像。
-
裁剪:随机截取图像的一部分。
-
平移:沿水平或垂直方向移动图像。
2. 颜色变换
-
亮度调整:改变图像亮度。
-
对比度调整:增强或减弱对比度。
-
颜色抖动:随机调整色相、饱和度。
-
添加噪声:加入高斯噪声、椒盐噪声等。
3. 结构变换
-
随机遮挡:随机遮挡部分区域(模拟物体被遮挡的情况)。
-
混合图像:将多张图像混合(如MixUp、CutMix)。
4. 高级增强
-
生成对抗网络(GAN):用生成模型创造新样本。
-
风格迁移:改变图像风格但不改变内容。
其他领域的数据增强:
-
文本数据:同义词替换、随机插入/删除词语、回译(翻译成其他语言再译回)、句子重组等。
-
音频数据:改变语速、添加背景噪声、调整音高、时间拉伸等。
-
时序数据:添加时间维度上的抖动、缩放、窗口切片等。
为什么数据增强有效?
-
引入不变性:通过变换让模型学会"旋转后的猫还是猫",增强对视角、光照等因素的不变性。
-
正则化效果:相当于隐式地给模型增加了约束,防止过拟合。
-
低成本扩展数据:无需额外人工标注,自动生成新数据。
注意事项:
-
合理性:增强后的数据应保持标签不变(如翻转"6"可能变成"9",数字识别中需避免)。
-
适度性:过度的增强可能破坏原始信息,反而降低模型性能。
-
领域适配:不同任务需要选择针对性的增强方法(如医学影像需保持病理特征不变)。
四、保存最优模型
在训练结束后,我们可以将训练最好的一轮保存下来,方便下次直接使用
方法1:仅保存状态字典
python
torch.save(model.state_dict(), "best2025-12-30.pth")
保存内容:
-
只保存模型参数:权重(weights)、偏置(biases)、BN层的均值和方差等
-
不保存:模型架构定义、优化器状态、训练配置等
特点:
-
文件体积小
-
需要在加载时重新实例化模型类
-
更灵活,可以跨模型结构(如果兼容)
方法2:保存完整模型
python
torch.save(model, "best1.pth")
保存内容:
-
模型架构:类的定义(通过序列化方式)
-
模型参数 :与
state_dict()相同的内容 -
相关信息:模型类的位置路径、源代码等
特点:
-
文件体积大
-
可直接加载使用,无需模型定义
-
可能存在序列化/反序列化问题
两种方法对比
| 特性 | model.state_dict() |
model (完整模型) |
|---|---|---|
| 保存内容 | 仅参数 | 参数 + 架构 + 类定义 |
| 文件大小 | 小 | 大(可能大2-10倍) |
| 加载方式 | 需要先创建模型实例 | 直接加载即可 |
| 灵活性 | 高(可加载到不同架构) | 低(绑定了特定类) |
| 安全性 | 高 | 低(可能有安全隐患) |
| 代码变更影响 | 不受影响(如果接口一致) | 加载可能失败 |
| 推荐场景 | 生产部署、迁移学习 | 快速原型、短期实验 |
五、实际运用
1、导入库和模块
python
import torch
from torch.utils.data import Dataset, DataLoader # 用于处理数据集的工具
import numpy as np
from PIL import Image # 图像处理库
from torchvision import transforms # 数据预处理和增强工具
from torch import nn # 神经网络模块
2、数据预处理(包括数据增强)
python
data_transforms = {
'trainda': # 训练集的数据增强
transforms.Compose([
transforms.Resize([300,300]), # 第1步:统一调整到300×300
transforms.RandomRotation(45), # 第2步:随机旋转±45度
transforms.CenterCrop(256), # 第3步:从中心裁剪256×256
transforms.RandomHorizontalFlip(p=0.5), # 第4步:50%概率水平翻转
transforms.RandomVerticalFlip(p=0.5), # 第5步:50%概率垂直翻转
transforms.ColorJitter(brightness=0.2,contrast=0.1,saturation=0.1,hue=0.1), # 第6步:颜色扰动
transforms.RandomGrayscale(p=0.1), # 第7步:10%概率转为灰度图
transforms.ToTensor(), # 第8步:将PIL图像转为Tensor(0-1范围)
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]) # 第9步:标准化(使用ImageNet均值和标准差)
]),
'valid': # 验证集,不需要数据增强
transforms.Compose([
transforms.Resize([256,256]), # 第1步:调整到256×256
transforms.ToTensor(), # 第2步:转为Tensor
transforms.Normalize([0.485,0.456,0.406],[0.229,0.224,0.225]) # 第3步:标准化
]),
}
3、自定义数据集类
python
class food_dataset(Dataset):
def __init__(self, file_path, transform=None):
self.file_path = file_path
self.imgs = [] # 存储图片路径
self.labels = [] # 存储标签
self.transform = transform # 数据转换函数
# 第1步:读取txt文件
with open(self.file_path,'r') as f:
# 第2步:解析每行数据(假设格式:图片路径 标签)
samples = [x.strip().split(' ') for x in f.readlines()]
for img_path, label in samples:
# 第3步:保存图片路径和标签
self.imgs.append(img_path)
self.labels.append(label)
def __len__(self):
# 返回数据集大小
return len(self.imgs)
def __getitem__(self, idx):
# 第1步:根据索引读取图片
image = Image.open(self.imgs[idx]) # 打开图片
# 第2步:应用数据转换(如果有)
if self.transform:
image = self.transform(image) # 应用前面定义的数据增强
# 第3步:处理标签
label = self.labels[idx] # 获取标签
label = torch.from_numpy(np.array(label, dtype=np.int64)) # 转为Tensor
# 第4步:返回(图像,标签)对
return image, label
4、创建数据集和数据加载器
python
# 第1步:创建训练集和测试集实例
training_data = food_dataset(file_path='./trainda.txt', transform=data_transforms['trainda'])
test_data = food_dataset(file_path='./testda.txt', transform=data_transforms['valid'])
# 第2步:创建数据加载器
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True) # 训练集:批大小64,打乱顺序
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True) # 测试集:批大小64,打乱顺序
5、定义CNN模型
python
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
# 第1层卷积:3通道输入 → 16通道输出
self.conv1 = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=5, stride=1, padding=2), # 卷积:保持256×256尺寸
nn.ReLU(), # 激活函数
nn.MaxPool2d(kernel_size=2) # 最大池化:256×256 → 128×128
)
# 第2层卷积:16通道 → 32通道
self.conv2 = nn.Sequential(
nn.Conv2d(16, 32, 5, 1, 2), # 卷积:128×128 → 128×128
nn.ReLU(),
nn.Conv2d(32, 32, 5, 1, 2), # 再次卷积
nn.ReLU(),
nn.MaxPool2d(kernel_size=2) # 池化:128×128 → 64×64
)
# 第3层卷积:32通道 → 128通道
self.conv3 = nn.Sequential(
nn.Conv2d(32, 128, 5, 1, 2), # 卷积:64×64 → 64×64
nn.ReLU()
)
# 全连接层:64×64×128 → 20个类别
self.out = nn.Linear(128 * 64 * 64, 20) # 524,288维 → 20维
def forward(self, x):
# 前向传播流程:
# 第1步:卷积层1
x = self.conv1(x) # [batch, 3, 256, 256] → [batch, 16, 128, 128]
# 第2步:卷积层2
x = self.conv2(x) # [batch, 16, 128, 128] → [batch, 32, 64, 64]
# 第3步:卷积层3
x = self.conv3(x) # [batch, 32, 64, 64] → [batch, 128, 64, 64]
# 第4步:展平
x = x.view(x.size(0), -1) # [batch, 128, 64, 64] → [batch, 524288]
# 第5步:全连接层
output = self.out(x) # [batch, 524288] → [batch, 20]
return output
6、设置训练设备
python
# 第1步:检测可用的设备(优先顺序:CUDA → MPS → CPU)
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else 'cpu'
# 第2步:打印使用的设备
print(f"Using {device} device")
# 第3步:创建模型并移动到对应设备
model = CNN().to(device)
# 第4步:打印模型结构
print(model)
7、定义训练函数
python
def train(dataloader, model, loss_fn, optimizer):
# 第1步:设置为训练模式
model.train()
# 第2步:初始化批次计数器
batch_size_num = 1
# 第3步:遍历数据加载器
for X, y in dataloader:
# 第4步:将数据移动到设备
X, y = X.to(device), y.to(device)
# 第5步:前向传播
pred = model.forward(X) # 等价于 model(X)
# 第6步:计算损失
loss = loss_fn(pred, y)
# 第7步:梯度清零
optimizer.zero_grad()
# 第8步:反向传播
loss.backward()
# 第9步:参数更新
optimizer.step()
# 第10步:打印损失
loss_value = loss.item()
print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")
# 第11步:批次计数增加
batch_size_num += 1
8、定义测试函数和保存最优模型
python
def test(dataloader, model, loss_fn):
# 第1步:设置为评估模式
model.eval()
# 第2步:获取数据集和批次数量
size = len(dataloader.dataset)
num_batchs = len(dataloader)
# 第3步:初始化测试统计
test_loss, correct = 0, 0
# 第4步:不计算梯度
with torch.no_grad():
# 第5步:遍历测试数据
for X, y in dataloader:
# 第6步:数据移动
X, y = X.to(device), y.to(device)
# 第7步:前向传播
pred = model.forward(X)
# 第8步:累计损失
test_loss += loss_fn(pred, y).item()
# 第9步:计算正确预测数量
# pred.argmax(1): 取每行最大值索引(预测类别)
# (pred.argmax(1) == y): 比较预测和真实标签
# .type(torch.float): 转为浮点数(True=1.0, False=0.0)
# .sum().item(): 求和并转为Python标量
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
# 第10步:计算平均损失和准确率
test_loss /= num_batchs # 平均损失
correct /= size # 准确率(0-1范围)
# 第11步:打印结果
print(f"结果: \n Accuracy :{(100*correct)}%, Avg loss:{test_loss}")
# 第12步: 保存最优模型
best_acc = 0
if correct>best_acc:
best_acc=correct
torch.save(model.state_dict(),"best2025-12-30.pth")#获取模型中的全部w,b参数,后缀也可以是ph,t7
#torch.save(model,"best1.pth")#保存完整模型(w,b,模型cnn),模型信息也保存下来
9、设置损失函数和优化器
python
# 第1步:定义损失函数(交叉熵损失,适用于分类问题)
loss_fn = nn.CrossEntropyLoss()
# 第2步:定义优化器(Adam优化器,学习率0.001)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
10、训练循环
python
# 第1步:设置训练轮数
epochs = 10
# 第2步:开始训练
for t in range(epochs):
# 第3步:打印当前轮次
print(f"Epoch{t+1}\n开始训练")
# 第4步:训练一个epoch
train(train_dataloader, model, loss_fn, optimizer)
# 第5步:训练完成
print("训练结束!")
# 第6步:最终测试
test(test_dataloader, model, loss_fn)
代码执行流程总结:
数据准备阶段:
-
定义数据增强管道
-
创建自定义数据集类
-
加载数据到数据加载器
模型构建阶段:
-
定义CNN网络结构
-
设置训练设备(GPU/CPU)
-
创建模型实例
训练配置阶段:
-
定义损失函数(交叉熵)
-
定义优化器(Adam)
-
设置训练轮数
训练执行阶段:
-
循环10个epoch
-
每个epoch内:
a. 设置为训练模式
b. 遍历所有训练批次
c. 前向传播 → 计算损失 → 反向传播 → 参数更新
d. 打印损失
测试评估阶段:
-
设置为评估模式
-
遍历测试数据
-
计算准确率和平均损失
-
打印最终结果
代码运行结果

可以观察到,训练得到的准确率不是很高,下一篇我将详细讲解提高准确率的方法
