用 PyTorch 训练图像分类器:完整实战

摘要 :前面五篇学完了理论和 PyTorch 基础,这次做一件完整的事------从零训练一个 CNN 图像分类器。我们会用 CIFAR-10 数据集(飞机、汽车、鸟类等 10 类),经历数据加载、CNN 模型设计、训练调优、结果可视化、模型保存与推理的全流程。每一段代码都可以直接复制运行。


一、项目概览

我们要做什么

复制代码
输入:一张 32×32 的彩色图片
输出:它属于 10 个类别中的哪一个

类别:飞机 ✈️ 汽车 🚗 鸟 🐦 猫 🐱 鹿 🦌
      狗 🐶 青蛙 🐸 马 🐴 船 🚢 卡车 🚛

技术栈

组件 用途
PyTorch 深度学习框架
torchvision 数据集 + 图像预处理
CIFAR-10 6 万张 32×32 彩色图像
CNN 卷积神经网络
matplotlib 结果可视化

完整流程

复制代码
数据加载 ─→ 数据预处理 ─→ 模型设计 ─→ 训练 ─→ 评估 ─→ 保存 ─→ 推理
  ↓            ↓            ↓          ↓       ↓        ↓       ↓
torchvision  transforms   nn.Module   循环    测试集    .pt    加载预测

二、第一步:数据加载与预处理

复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# ===== 1. 设备配置 =====
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
# 输出: Using device: cuda  (如果有 GPU)

数据增强与归一化

数据增强(Data Augmentation)是提升模型泛化能力的关键技巧------对训练图片做随机变换,相当于免费扩大了训练集:

复制代码
# ===== 2. 数据预处理 =====
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),   # 随机裁剪
    transforms.RandomHorizontalFlip(),       # 随机水平翻转
    transforms.ToTensor(),                   # PIL → Tensor [0,1]
    transforms.Normalize(                     # 归一化到 [-1, 1]
        mean=(0.4914, 0.4822, 0.4465),       # CIFAR-10 各通道均值
        std=(0.2470, 0.2435, 0.2616)         # CIFAR-10 各通道标准差
    ),
])

# 测试集不增强(只做归一化)
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.4914, 0.4822, 0.4465),
        std=(0.2470, 0.2435, 0.2616)
    ),
])

为什么数据增强有效?

数据增强模拟了真实世界中的变化------物体可能出现在图像的不同位置、角度、光照条件下。通过让模型看到"各种各样的同类物体",它学会了关注物体的本质特征 而非不相关的细节

加载数据集

复制代码
# ===== 3. 下载和加载 CIFAR-10 =====
train_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train
)
test_dataset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)

train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

print(f"训练集: {len(train_dataset)} 张")
print(f"测试集:  {len(test_dataset)} 张")
# 训练集: 50000 张
# 测试集:  10000 张

看一眼数据长什么样

复制代码
# ===== 4. 可视化样本 =====
def imshow(img):
    img = img / 2 + 0.5  # 反归一化 [-1,1] → [0,1]
    plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
    plt.axis('off')

# 取一个 batch
data_iter = iter(train_loader)
images, labels = next(data_iter)

# 显示 8 张
plt.figure(figsize=(12, 4))
for i in range(8):
    plt.subplot(2, 4, i+1)
    imshow(images[i])
    plt.title(classes[labels[i]])
plt.tight_layout()
plt.show()

三、第二步:设计 CNN 模型

我们设计一个适合 CIFAR-10 的 CNN。它比 LeNet 更深,但比 ResNet 更简单------在效果好和易于理解之间取得平衡。

复制代码
class CIFAR10CNN(nn.Module):
    """适合 CIFAR-10 的 CNN 架构"""
    
    def __init__(self, num_classes=10):
        super().__init__()
        
        # 特征提取部分 (卷积层)
        self.features = nn.Sequential(
            # Block 1: 32×32×3 → 32×32×32
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            
            # Block 2: 32×32×32 → 16×16×32
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 32×32 → 16×16
            
            # Block 3: 16×16×32 → 16×16×64
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            
            # Block 4: 16×16×64 → 8×8×64
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 16×16 → 8×8
            
            # Block 5: 8×8×64 → 8×8×128
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            
            # Block 6: 8×8×128 → 4×4×128
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 8×8 → 4×4
        )
        
        # 分类部分 (全连接层)
        self.classifier = nn.Sequential(
            nn.Dropout(0.3),           # 防止过拟合
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, num_classes),
        )
    
    def forward(self, x):
        x = self.features(x)
        x = x.view(x.size(0), -1)  # 展平: [batch, 128×4×4] = [batch, 2048]
        x = self.classifier(x)
        return x

# 实例化
model = CIFAR10CNN(num_classes=10).to(device)

total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"总参数量: {total_params:,}")
print(f"可训练参数量: {trainable_params:,}")
# 总参数量: 1,173,802  (~1.2M,轻量级)

架构可视化

复制代码
输入: [batch, 3, 32, 32]
    │
    ▼
Conv2d(3→32) + BN + ReLU                     [batch, 32, 32, 32]
Conv2d(32→32) + BN + ReLU + MaxPool(2×2)      [batch, 32, 16, 16]
    │
    ▼
Conv2d(32→64) + BN + ReLU                     [batch, 64, 16, 16]
Conv2d(64→64) + BN + ReLU + MaxPool(2×2)      [batch, 64, 8, 8]
    │
    ▼
Conv2d(64→128) + BN + ReLU                    [batch, 128, 8, 8]
Conv2d(128→128) + BN + ReLU + MaxPool(2×2)    [batch, 128, 4, 4]
    │
    ▼
展平 → Dropout → Linear(2048→256) → ReLU → Dropout → Linear(256→256) → ReLU → Linear(256→10)
    │
    ▼
输出: [batch, 10]  ← 各类别的得分 (logits)

设计要点

  • 卷积核统一为 3×3:三个 3×3 堆叠 = 一个 7×7 感受野,但参数少 45%
  • 每层后加 BatchNorm:稳定训练,允许更大的学习率
  • 通道数逐渐增加:3→32→64→128,底层检测基础特征需要更多通道
  • 特征图尺寸逐渐缩小:32→16→8→4,空间压缩、语义浓缩
  • Dropout 防止过拟合:在分类器部分随机丢弃 30% 的神经元

四、第三步:训练模型

复制代码
# ===== 5. 定义损失函数和优化器 =====
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=5e-4)

# 学习率调度器:每 30 个 epoch 学习率乘以 0.1
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

训练一个 epoch

复制代码
def train_epoch(model, loader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    avg_loss = running_loss / len(loader)
    accuracy = 100.0 * correct / total
    return avg_loss, accuracy

验证

复制代码
@torch.no_grad()
def evaluate(model, loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    
    return running_loss / len(loader), 100.0 * correct / total

执行训练

复制代码
# ===== 6. 训练 =====
num_epochs = 50
best_acc = 0.0

train_losses, train_accs = [], []
test_losses, test_accs = [], []

for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train_epoch(
        model, train_loader, criterion, optimizer, device
    )
    test_loss, test_acc = evaluate(
        model, test_loader, criterion, device
    )
    
    # 学习率调度
    scheduler.step()
    
    # 记录
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    test_losses.append(test_loss)
    test_accs.append(test_acc)
    
    # 保存最佳模型
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), 'best_model.pth')
    
    if epoch % 5 == 0 or epoch == 1:
        print(f"Epoch {epoch:2d}/{num_epochs} | "
              f"Train Loss={train_loss:.3f} Acc={train_acc:.2f}% | "
              f"Test  Loss={test_loss:.3f} Acc={test_acc:.2f}%")

print(f"\n训练完成!最佳测试准确率: {best_acc:.2f}%")

输出示例

复制代码
Epoch  1/50 | Train Loss=1.893 Acc=30.15% | Test  Loss=1.546 Acc=42.38%
Epoch  5/50 | Train Loss=0.995 Acc=64.34% | Test  Loss=0.902 Acc=68.41%
Epoch 10/50 | Train Loss=0.609 Acc=78.26% | Test  Loss=0.711 Acc=75.37%
Epoch 20/50 | Train Loss=0.361 Acc=87.27% | Test  Loss=0.597 Acc=80.52%
Epoch 30/50 | Train Loss=0.245 Acc=91.32% | Test  Loss=0.523 Acc=83.18%
Epoch 40/50 | Train Loss=0.166 Acc=94.08% | Test  Loss=0.527 Acc=83.87%
Epoch 50/50 | Train Loss=0.110 Acc=96.17% | Test  Loss=0.535 Acc=83.96%

训练完成!最佳测试准确率: 84.21%

五、第四步:结果分析

绘制训练曲线

复制代码
# ===== 7. 可视化训练过程 =====
plt.figure(figsize=(12, 4))

# 损失曲线
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

# 准确率曲线
plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Acc')
plt.plot(test_accs, label='Test Acc')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.show()

如何看懂这些曲线?

现象 含义 怎么办
训练损失↓ 测试损失↓ ✅ 正常学习 继续训练
训练损失↓ 测试损失↑ ⚠️ 过拟合 增加 Dropout、减少模型参数、加数据增强
训练损失↑ 测试损失↑ ❌ 没学到 检查学习率、模型结构、数据预处理
两者都停滞 ⏸ 收敛了 降低学习率试试

查看每类准确率

复制代码
# ===== 8. 各类别准确率 =====
model.load_state_dict(torch.load('best_model.pth'))
model.eval()

class_correct = [0] * 10
class_total = [0] * 10

with torch.no_grad():
    for inputs, labels in test_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = outputs.max(1)
        
        for label, pred in zip(labels, predicted):
            if label == pred:
                class_correct[label] += 1
            class_total[label] += 1

print("\n各类别准确率:")
for i in range(10):
    acc = 100.0 * class_correct[i] / class_total[i]
    print(f"  {classes[i]:8s}: {acc:.1f}%")

典型输出

复制代码
各类别准确率:
  plane  : 88.2%
  car    : 91.4%
  bird   : 76.8%    ← 鸟最难分类(形态多变)
  cat    : 75.3%    ← 猫也难(和狗容易混淆)
  deer   : 82.6%
  dog    : 78.1%
  frog   : 86.9%
  horse  : 84.5%
  ship   : 90.2%
  truck  : 89.6%

六、第五步:保存与推理

保存完整的模型

复制代码
# 保存方式一:仅参数(推荐)
torch.save(model.state_dict(), 'cifar10_cnn_weights.pth')

# 保存方式二:完整模型(含架构)
torch.save(model, 'cifar10_cnn_full.pth')

用保存的模型做推理

复制代码
# ===== 9. 推理新图片 =====
def predict_image(model, image_tensor, device, classes):
    """对单张图片做预测"""
    model.eval()
    
    # 增加 batch 维度: [3,32,32] → [1,3,32,32]
    if image_tensor.dim() == 3:
        image_tensor = image_tensor.unsqueeze(0)
    
    image_tensor = image_tensor.to(device)
    
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = torch.softmax(outputs, dim=1)
        confidence, predicted = probabilities.max(1)
    
    return classes[predicted.item()], confidence.item()

# 示例:从测试集中取一张
data_iter = iter(test_loader)
images, labels = next(data_iter)

img = images[5]   # 取第 6 张
true_label = classes[labels[5]]

pred_label, confidence = predict_image(model, img, device, classes)
print(f"真实: {true_label} | 预测: {pred_label} | 置信度: {confidence:.2%}")

批处理推理

复制代码
def predict_batch(model, dataloader, device):
    """对整个数据集做批量推理"""
    model.eval()
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.numpy())
    
    return np.array(all_preds), np.array(all_labels)

# 在测试集上运行
predictions, ground_truth = predict_batch(model, test_loader, device)

# 计算各类别准确率
from sklearn.metrics import classification_report
print(classification_report(ground_truth, predictions, target_names=classes))

七、进阶方向

这个项目达到了约 84% 的测试准确率。如果想更进一步,可以尝试:

改进方向 预期提升 实现方式
更深的结构 +3-5% 改用 ResNet-18 或更深的架构
更强的数据增强 +2-4% 添加 CutOut、MixUp、AutoAugment
学习率 Warmup +1-2% 前 5 个 epoch 学习率从 0 线性增加到目标值
标签平滑 +0.5-1% nn.CrossEntropyLoss(label_smoothing=0.1)
集成学习 +2-3% 训练 5 个不同初始化的模型,投票决策
迁移学习 +5-10% 用 ImageNet 预训练模型做微调

当前 SOTA(2026 年 CIFAR-10):准确率超过 99%,使用 Vision Transformer + 大规模预训练 + 强数据增强。但这个项目的价值不在于追求极致准确率,而在于理解整个流程------这段代码可以轻松迁移到任何一个图像分类任务。


八、总结

这篇文章的完整代码可以在你的知识库中直接运行。核心流程:

复制代码
torchvision  加载 CIFAR-10
    ↓
transforms  数据增强 + 归一化
    ↓
DataLoader  批量加载 + 打乱
    ↓
nn.Module   自定义 CNN 架构
    ↓
训练循环    前向 → 损失 → 反向 → 更新
    ↓
评估       准确率、各类别分析、可视化
    ↓
保存/推理   .pth 文件加载预测

核心三句话

  1. 数据准备决定上限------好的数据增强比改模型结构更有效
  2. 先让模型过拟合------先在小批量数据上把训练损失降到接近零,确保模型正确,再用正则化手段提升泛化
  3. PyTorch 训练模板------这个项目的代码结构可以迁移到 90% 的图像分类任务

下一篇文章可以沿着这条线继续深入------用预训练模型做迁移学习,或者从分类走向目标检测。

相关推荐
不爱土豆唯爱马铃薯39 分钟前
MONKEYCODE 教程系列MC-025 | 实战AI客服机器人
人工智能·数据挖掘
雪度娃娃40 分钟前
转向现代C++——保证const成员函数的线程安全性
开发语言·c++
刘婉晴41 分钟前
【火山「AI安全攻防」】恶意Skill检测引擎设计思路分享
人工智能·安全
小王毕业啦1 小时前
2009-2024年 各国清廉指数CPI(xlsx)
大数据·人工智能·数据挖掘·数据分析·社科数据·实证分析·经管数据
原来是猿1 小时前
深入理解 C++ unordered_map 与 unordered_set
开发语言·c++
满天星83035771 小时前
【Qt】信号和槽 (一)(概述和基本使用)
开发语言·c++·qt
syounger1 小时前
从遗留系统到AI运营:富士通转型折射日本企业的数字化再考
人工智能
l1t1 小时前
DeepSeek总结的 waddler,一个 Go 语言编写的从 YAML 文件运行的 ETL 管道
开发语言·golang·etl
DogDaoDao1 小时前
【GitHub】CodeGraph 深度解析:为 AI 编程代理构建预索引代码知识图谱
人工智能·程序员·github·知识图谱·ai编程·ai agent·codegraph