手写数字识别零基础实战:基于PyTorch的CNN完整拆解

最近在Kaggle上跑了一个经典的MNIST手写数字识别项目,用PyTorch搭了一个朴素的CNN,效果还不错,准确率能到99.3%左右。

我现在把整个jupyter notebook 代码贴出来,以供参考:github.com/anjuxi/-CNN...

项目概览

  • 数据集:MNIST(6万训练,1万测试,28×28灰度图)
  • 框架:PyTorch
  • 模型:3层卷积 + 2层全连接,带BatchNorm、Dropout
  • 加速:多GPU并行、混合精度训练
  • 指标:测试集准确率99.3%+

整个项目就一个jupyter文件,方便调试。


1. 环境与基础配置

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import time

首先导入必备的库。PyTorch那一套不用多说,torchvision帮我们处理MNIST,matplotlib用来可视化,time用来统计训练时间。

中间有两行重复设置中文字体的代码,是Kaggle环境的一个小坑。我原本想用SimHei显示中文标题,但Kaggle镜像里没装这个字体,导致满屏的findfont警告。实际在本地跑的话,改成'Arial Unicode MS'或者直接注释掉也行。

python 复制代码
batch_size = 1024*2

batch_size设成2048,这是一个相对激进的选择。一般情况下64、128比较常见,但我这里用了两张T4并行训练,显存够大,而且大batch能让每轮训练更快收敛。不过要注意,batch过大可能让模型精度略降,需要配合学习率调整,Adam的默认学习率0.001在这里依然表现不错。


2. 数据处理

2.1 数据预处理

python 复制代码
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
  • ToTensor():把PIL图片转成PyTorch张量,并且像素值从0-255映射到0-1。
  • Normalize:将单通道图像的均值与标准差归一化。MNIST的经验值是均值0.1307,标准差0.3081。这一步很重要,能让模型更容易收敛。

2.2 加载数据集

python 复制代码
train_dataset = torchvision.datasets.MNIST(
    root='./data', train=True, download=True, transform=transform
)
test_dataset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)

torchvision.datasets.MNIST会自动下载数据到./data目录。这里要注意,Kaggle环境有网络,下载很快。

2.3 构建DataLoader

python 复制代码
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True
)
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True
)
  • num_workers=4:开启4个子进程加载数据,在Kaggle这种多核环境里能加速IO。
  • pin_memory:启用锁页内存,能加快CPU到GPU的数据传输。
  • persistent_workers:让worker进程在epoch之间保持存活,避免重复创建的开销。这些小参数在大batch训练时提升明显。

2.4 数据探索

python 复制代码
print(f"训练集样本数量: {len(train_dataset)}")
print(f"测试集样本数量: {len(test_dataset)}")
print(f"图片尺寸: {train_dataset[0][0].shape}")
print(f"类别数量: {len(train_dataset.classes)}")

输出:

makefile 复制代码
训练集样本数量: 60000
测试集样本数量: 10000
图片尺寸: torch.Size([1, 28, 28])
类别数量: 10

然后我画了12张手写数字图,用plt.subplot(2, 6, i+1)排成两行六列。这里有个小教训:如果标题用了中文字体,Kaggle会缺字体导致方框乱码。后来我把标题改成英文或者在本地运行加中文字体就能解决。


3. 模型搭建

CNN结构如下:

python 复制代码
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        # Conv1: 1→32
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Conv2: 32→64
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Conv3: 64→128
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Dropout
        self.dropout = nn.Dropout(0.5)
        # FC
        self.fc1 = nn.Linear(128 * 3 * 3, 256)
        self.relu4 = nn.ReLU()
        self.dropout2 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(256, 10)

整个网络由三个卷积块 + 两个全连接层组成。设计思路:

  • Conv + BN + ReLU + MaxPool:标准三连,每次池化后尺寸减半。
  • 输入 28×28 → 经过3次池化 → 特征图大小变为 3×3(因为28不能被2整除,最后一次池化会变成28/2=14→14/2=7→7/2=3,floor下来是3×3)。
  • 计算全连接输入维度:128 * 3 * 3 = 1152,这个得算对,不然会报错。
  • Dropout(0.5):放在全连接层前后,防止过拟合。因为MNIST比较简单,我加了0.5的丢弃率,效果不错。
  • 最后一层输出10个类别,不用softmax是因为nn.CrossEntropyLoss自带 softmax。

前向传播

python 复制代码
def forward(self, x):
    x = self.pool1(self.relu1(self.bn1(self.conv1(x))))
    x = self.pool2(self.relu2(self.bn2(self.conv2(x))))
    x = self.pool3(self.relu3(self.bn3(self.conv3(x))))
    x = x.view(x.size(0), -1)        # 展平
    x = self.dropout(x)
    x = self.relu4(self.fc1(x))
    x = self.dropout2(x)
    x = self.fc2(x)
    return x

层叠写法,简洁直观。

多GPU与设备检测

python 复制代码
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = CNN()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)

Kaggle这次给的环境有两张T4,torch.cuda.device_count()检测到2,于是用nn.DataParallel包裹模型,自动做数据并行。最后把模型搬到GPU。


4. 训练配置

4.1 损失函数、优化器、学习率调度

python 复制代码
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
  • 交叉熵损失,多分类标配。
  • Adam优化器,初始学习率0.001,能快速收敛。
  • 学习率调度:每10个epoch衰减为原来的0.1倍。因为训练到后期loss几乎不降了,这样做可以让模型微调至更优解。我试过不衰减,最终精度会低0.1%左右。

4.2 混合精度加速

python 复制代码
scaler = torch.amp.GradScaler('cuda') if torch.cuda.is_available() else None

混合精度训练可以让计算速度加快,显存占用减少。GradScaler用来缩放loss防止梯度下溢,一般和autocast搭配使用。


5. 训练与测试函数

5.1 训练函数

python 复制代码
def train(model, train_loader, criterion, optimizer, device, scaler=None):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        if scaler is not None:
            with torch.cuda.amp.autocast():   # 前方高能:新版PyTorch这里会warning
                output = model(data)
                loss = criterion(output, target)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
        running_loss += loss.item()
        _, predicted = output.max(1)
        total += target.size(0)
        correct += predicted.eq(target).sum().item()
    avg_loss = running_loss / len(train_loader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy

这里混合精度部分使用了torch.cuda.amp.autocast(),但在PyTorch新版本里会提示FutureWarning,建议改成torch.amp.autocast('cuda')。我当时没改,控制台会刷一排警告,不影响运行但看着碍眼。

5.2 测试函数

python 复制代码
def test(model, test_loader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = criterion(output, target)
            running_loss += loss.item()
            _, predicted = output.max(1)
            total += target.size(0)
            correct += predicted.eq(target).sum().item()
    avg_loss = running_loss / len(test_loader)
    accuracy = 100. * correct / total
    return avg_loss, accuracy

常规的model.eval()torch.no_grad(),记得关掉Dropout和BN的更新,且不计算梯度。


6. 训练循环与保存最佳模型

python 复制代码
num_epochs = 50
best_accuracy = 0.0
for epoch in range(num_epochs):
    epoch_start = time.time()
    train_loss, train_acc = train(model, train_loader, criterion, optimizer, device, scaler)
    test_loss, test_acc = test(model, test_loader, criterion, device)
    scheduler.step()
    # ... 记录数据、打印日志 ...
    if test_acc > best_accuracy:
        best_accuracy = test_acc
        if isinstance(model, nn.DataParallel):
            torch.save(model.module.state_dict(), 'mnist_cnn_best.pth')
        else:
            torch.save(model.state_dict(), 'mnist_cnn_best.pth')

这里我设了50个epoch,其实10~15个epoch后就接近饱和了。保存最佳模型时要注意:如果模型是被DataParallel包裹的,保存model.state_dict()会带有module.前缀,后面加载时需要同样包裹;而保存model.module.state_dict()则是干净的原始结构。我用后一种方式,方便后续单卡推理。

训练输出大致如下:

yaml 复制代码
Epoch [1/50] Train Loss: 0.8236, Train Acc: 73.27% Test Loss: 0.1384, Test Acc: 95.75%
...
Epoch [15/50] Train Loss: 0.0281, Train Acc: 99.17% Test Loss: 0.0202, Test Acc: 99.36%

可以看到第1个epoch后测试准确率就有95.75%,说明模型学习能力很强。最终收敛在99.3%附近,之后基本不再增长。

7. 训练过程可视化

python 复制代码
plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
# ... 设置标签、图例 ...
plt.subplot(1, 2, 2)
plt.plot(train_accuracies, label='Train Accuracy')
plt.plot(test_accuracies, label='Test Accuracy')
# ...

画了loss和accuracy的曲线图。从图里明显看出:训练loss平滑下降,测试loss在10epoch左右基本走平,之后学习率衰减让loss又小降了一点。准确率曲线也很漂亮,训练和测试的差距很小,说明过拟合控制得不错。


8. 模型保存与加载

python 复制代码
if isinstance(model, nn.DataParallel):
    torch.save(model.module.state_dict(), 'mnist_cnn_model.pth')
else:
    torch.save(model.state_dict(), 'mnist_cnn_model.pth')

保存最终模型(不一定最佳)。后面如果要用,可以这样加载:

python 复制代码
model = CNN()
model.load_state_dict(torch.load('mnist_cnn_model.pth'))
model = model.to(device)
model.eval()

9. 预测结果可视化

python 复制代码
def visualize_predictions(model, test_loader, device, num_images=16):
    model.eval()
    images, labels = next(iter(test_loader))
    images = images.to(device)
    with torch.no_grad():
        outputs = model(images)
        _, predicted = outputs.max(1)
    images = images.cpu()
    predicted = predicted.cpu()
    # 画4x4网格
    ...

我随机取了一个batch的前16张图,画成4×4网格,真实标签用黑色,预测标签用绿色(正确)或红色(错误)。这波直接拿到100%正确率,一个都没错,看着很爽。


10. 混淆矩阵与分类报告

python 复制代码
from sklearn.metrics import confusion_matrix, classification_report

遍历整个测试集,收集所有预测结果和真实标签,然后用confusion_matrix生成矩阵,再用plt.imshow画热力图。主对角线很亮,其他位置基本没啥数字,说明模型各个类别都区分得很好。

分类报告:

markdown 复制代码
              precision    recall  f1-score   support
    accuracy                           0.99     10000
   macro avg       0.99      0.99      0.99     10000
weighted avg       0.99      0.99      0.99     10000

所有类别的F1都在0.99左右,模型很均衡。


11. 单个图片预测测试

最后我写了个predict_single_image函数,随机抽5张测试图,打印真实标签、预测标签和置信度,并给出Top-3预测概率。

python 复制代码
def predict_single_image(model, image, device):
    model.eval()
    with torch.no_grad():
        image = image.unsqueeze(0).to(device)
        output = model(image)
        probabilities = torch.softmax(output, dim=1)
        confidence, predicted = torch.max(probabilities, 1)
        return predicted.item(), confidence.item(), probabilities.cpu().numpy()[0]

输出示例:

yaml 复制代码
图片 3656:
  真实标签: 7
  预测标签: 7
  置信度: 1.0000
  Top-3 预测:
    1. 数字 7: 1.0000
    2. 数字 2: 0.0000
    3. 数字 9: 0.0000

模型非常自信,概率直接拉满。注意这里的"1.0000"其实是因为softmax输出的数值太小被四舍五入了,实际概率是0.9999几。这也说明模型对这些样例完全没有疑惑。


本文首发于掘金,作者Ailan Anjuxi,转载请注明出处。

相关推荐
jiucaixiuyang1 小时前
散户如何使用手机T0算法?
算法·量化·t0
阿Y加油吧2 小时前
算法二刷复盘:LeetCode 79 单词搜索 & 131 分割回文串(Java 回溯精讲)
java·算法·leetcode
徐新帅2 小时前
4164:【GESP2512七级】学习⼩组
算法
北顾笙9802 小时前
day30-数据结构力扣
数据结构·算法·leetcode
爱写代码的倒霉蛋2 小时前
天梯赛经验总结(细节篇)
经验分享·算法
Hello!!!!!!2 小时前
C++基础(五)——屏幕和文件输入输出
开发语言·c++·算法
Rnan-prince2 小时前
Count-Min Sketch:海量数据频率统计的“轻量级计数器“
python·算法
王老师青少年编程2 小时前
csp信奥赛C++高频考点专项训练之贪心算法 --【排序贪心】:加工生产调度
c++·算法·贪心·csp·信奥赛·排序贪心·加工生产调度
三毛的二哥2 小时前
BEV:MapTR
人工智能·算法·计算机视觉·3d