一、知识回顾要点
- 数据增强:通过对训练图像做翻转、裁剪、旋转等变换,扩充数据集,提升模型泛化能力,避免过拟合。
- 卷积神经网络定义写法:按层堆叠的方式搭建模型,核心包含卷积层、池化层、全连接层等组件,需明确输入输出维度与层间连接逻辑。
- Batch 归一化:对一个批次数据的分布进行标准化调整,加速训练收敛、稳定梯度,在图像分类任务中尤为常用。
- 特征图:特指卷积操作输出的二维 / 三维数据,承载了输入图像的局部特征信息,是卷积层的核心输出形式。
- 学习率调度器:动态修改基础学习率,在训练后期降低学习率,让模型更稳定地收敛到最优解。
二、卷积操作标准流程
- 特征提取阶段 :输入 → 卷积层 → Batch 归一化层(可选) → 池化层 → 激活函数 → 下一层
- 卷积层:提取局部特征
- Batch 归一化:优化数据分布,加速训练
- 池化层:降维并保留关键特征
- 激活函数:引入非线性,增强模型表达能力
- 分类输出阶段:Flatten(展平特征图) → Dense(全连接层,可搭配 Dropout 防过拟合) → Dense(输出层,对应分类类别数)
💡 补充说明
- Batch 归一化通常放在卷积层之后、激活函数之前,也有部分架构将其放在激活函数后,需根据具体任务选择。
- 特征图仅由卷积层产生,池化层输出是特征图的降采样版本,全连接层输出为一维向量,不再称为特征图。
- 学习率调度器需配合优化器使用,常见策略有 StepLR、CosineAnnealing 等。
🧩 简单 CNN PyTorch 代码模板(含数据增强、BatchNorm、学习率调度器)
python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR
# ---------------------- 1. 数据增强与加载 ----------------------
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4), # 随机裁剪
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), # 归一化
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# 以CIFAR10为例
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=2)
# ---------------------- 2. CNN模型定义 ----------------------
class SimpleCNN(nn.Module):
def __init__(self, num_classes=10):
super(SimpleCNN, self).__init__()
# 特征提取部分:卷积 -> BatchNorm -> 池化 -> 激活
self.features = nn.Sequential(
# 第一层
nn.Conv2d(3, 32, kernel_size=3, padding=1), # 输入通道3,输出32
nn.BatchNorm2d(32), # Batch归一化
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2), # 池化降维
# 第二层
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
# 第三层
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=2, stride=2),
)
# 分类部分:Flatten -> Dense -> Dropout -> Output
self.classifier = nn.Sequential(
nn.Flatten(),
nn.Linear(128 * 4 * 4, 512), # 32x32经3次池化后为4x4
nn.ReLU(inplace=True),
nn.Dropout(0.5), # Dropout防过拟合
nn.Linear(512, num_classes),
)
def forward(self, x):
x = self.features(x)
x = self.classifier(x)
return x
# ---------------------- 3. 训练配置 ----------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)
scheduler = StepLR(optimizer, step_size=30, gamma=0.1) # 学习率调度器:每30轮学习率×0.1
# ---------------------- 4. 训练循环 ----------------------
def train(epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = criterion(output, target)
loss.backward()
optimizer.step()
if batch_idx % 100 == 0:
print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} ({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
def test():
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += criterion(output, target).item()
pred = output.argmax(dim=1, keepdim=True)
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.2f}%)\n')
# 运行训练
for epoch in range(1, 61):
train(epoch)
test()
scheduler.step() # 更新学习率