PyTorch 实战:从零搭建手写数字识别系统(CNN 卷积神经网络)

<div align="center">PyTorch 实战:从零搭建手写数字识别系统(CNN 卷积神经网络)

PyTorch 实战:从零搭建手写数字识别系统(CNN 卷积神经网络)


代码链接:https://pan.quark.cn/s/ecc8ad46bfc8

目录

  • 一、项目简介\](#一项目简介)

  • 三、项目结构\](#三项目结构)

  • 五、数据处理与增强\](#五数据处理与增强)

  • 七、模型预测与评估\](#七模型预测与评估)

  • 九、训练结果与分析\](#九训练结果与分析)

  • 十一、完整源码获取\](#十一完整源码获取)

一、项目简介

1.1 什么是手写数字识别?

手写数字识别是计算机视觉领域最经典的入门问题之一。它的目标是让计算机能够自动识别用户手写的 0-9 这 10 个数字。这个看似简单的问题,实际上是 OCR(光学字符识别)技术的基础,在邮政编码识别、银行支票处理、车牌识别等场景中有着广泛应用。

1.2 为什么选择 MNIST?

**MNIST**(Modified National Institute of Standards and Technology)是机器学习领域的 "Hello World" 数据集:

| 属性 | 说明 |

|------|------|

| 数据量 | 70,000 张灰度图片(60,000 训练 + 10,000 测试) |

| 图片尺寸 | 28 × 28 像素 |

| 颜色通道 | 单通道(灰度图) |

| 分类数 | 10 类(数字 0-9) |

| 数据大小 | 约 12 MB |

1.3 本项目亮点

  • **三种网络架构**对比:全连接网络 vs 轻量级 CNN vs 深度 CNN

  • **数据增强**:随机旋转、平移,提升模型泛化能力

  • **训练可视化**:实时 loss/accuracy 曲线、混淆矩阵

  • **Web 在线演示**:Flask + Canvas 画板,即画即识别

  • **完整工程化**:模块化代码,注释详尽,开箱即用


二、环境配置

2.1 基础环境

```

操作系统:Windows 10/11、Ubuntu 20.04+、macOS

Python:3.8 或更高版本

CUDA:11.8+(如需 GPU 加速)

```

2.2 创建虚拟环境(推荐)

```bash

使用 conda 创建虚拟环境

conda create -n digit_recog python=3.10

conda activate digit_recog

或使用 venv

python -m venv venv

Windows

venv\Scripts\activate

Linux/Mac

source venv/bin/activate

```

2.3 安装依赖

```bash

安装 PyTorch(GPU 版本)

conda install pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia

或 CPU 版本

conda install pytorch torchvision cpuonly -c pytorch

安装其他依赖

pip install matplotlib numpy Pillow flask tqdm scikit-learn seaborn

```

**requirements.txt:**

```text

torch>=2.0.0

torchvision>=0.15.0

matplotlib>=3.7.0

numpy>=1.24.0

Pillow>=9.5.0

flask>=2.3.0

tqdm>=4.65.0

```

一键安装:

```bash

pip install -r requirements.txt

```

2.4 验证安装

```python

import torch

print(f"PyTorch 版本: {torch.version}")

print(f"CUDA 可用: {torch.cuda.is_available()}")

if torch.cuda.is_available():

print(f"GPU: {torch.cuda.get_device_name(0)}")

```

输出示例:

```

PyTorch 版本: 2.1.0

CUDA 可用: True

GPU: NVIDIA GeForce RTX 3060

```


三、项目结构

```

digit_recognition/

├── model.py # 模型定义(CNN 网络结构)

├── train.py # 训练脚本(完整训练流程)

├── predict.py # 预测脚本(单张/批量预测)

├── utils.py # 工具函数(数据加载、可视化)

├── web_app.py # Flask Web 应用(在线演示)

├── requirements.txt # 依赖列表

├── data/ # MNIST 数据集(自动下载)

├── checkpoints/ # 模型权重文件

│ └── best_model.pth # 最佳模型

└── results/ # 训练结果图片

├── training_history.png

├── confusion_matrix.png

└── prediction_result.png

```


四、模型架构详解

4.1 核心思路:卷积神经网络(CNN)

CNN 是处理图像任务的首选架构,它通过以下核心操作提取图像特征:

```

输入图像 → [卷积提取特征] → [池化压缩] → [全连接分类] → 输出概率

```

**为什么 CNN 比全连接网络更适合图像?**

| 对比项 | 全连接网络 | 卷积神经网络 |

|--------|-----------|-------------|

| 参数量 | 784×256 + 256×10 = 203,776 | 约 30,000-200,000 |

| 空间信息 | 丢失 | 保留 |

| 平移不变性 | 无 | 有 |

| 图像任务效果 | 一般 | 优秀 |

4.2 模型代码实现

```python

import torch

import torch.nn as nn

import torch.nn.functional as F

class DigitCNN(nn.Module):

"""

三层卷积神经网络

结构: Conv→BN→ReLU→Pool × 2 + Conv→BN→ReLU + FC × 2

"""

def init(self):

super(DigitCNN, self).init()

第一个卷积块

self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)

输入: 1×28×28 → 输出: 32×28×28

self.bn1 = nn.BatchNorm2d(32)

self.pool1 = nn.MaxPool2d(2, 2)

输出: 32×14×14

第二个卷积块

self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

输入: 32×14×14 → 输出: 64×14×14

self.bn2 = nn.BatchNorm2d(64)

self.pool2 = nn.MaxPool2d(2, 2)

输出: 64×7×7

第三个卷积块

self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

输入: 64×7×7 → 输出: 128×7×7

self.bn3 = nn.BatchNorm2d(128)

全连接层

self.fc1 = nn.Linear(128 * 7 * 7, 256)

self.dropout = nn.Dropout(0.5)

self.fc2 = nn.Linear(256, 10)

def forward(self, x):

卷积块1: Conv → BN → ReLU → Pool

x = self.pool1(F.relu(self.bn1(self.conv1(x))))

卷积块2: Conv → BN → ReLU → Pool

x = self.pool2(F.relu(self.bn2(self.conv2(x))))

卷积块3: Conv → BN → ReLU(不池化,保留空间信息)

x = F.relu(self.bn3(self.conv3(x)))

展平: [batch, 128, 7, 7] → [batch, 128*7*7]

x = x.view(-1, 128 * 7 * 7)

全连接层 + Dropout

x = F.relu(self.fc1(x))

x = self.dropout(x)

x = self.fc2(x)

return x

```

4.3 模型结构图解

```

┌─────────────────────────────────────────────────────────┐

│ 输入: 1×28×28 │

├─────────────────────────────────────────────────────────┤

│ Conv2d(1→32, 3×3) + BN + ReLU + MaxPool(2×2) │

│ 输出: 32×14×14 │

├─────────────────────────────────────────────────────────┤

│ Conv2d(32→64, 3×3) + BN + ReLU + MaxPool(2×2) │

│ 输出: 64×7×7 │

├─────────────────────────────────────────────────────────┤

│ Conv2d(64→128, 3×3) + BN + ReLU │

│ 输出: 128×7×7 │

├─────────────────────────────────────────────────────────┤

│ Flatten → 6272 │

├─────────────────────────────────────────────────────────┤

│ Linear(6272→256) + ReLU + Dropout(0.5) │

├─────────────────────────────────────────────────────────┤

│ Linear(256→10) │

├─────────────────────────────────────────────────────────┤

│ 输出: 10(数字0-9概率) │

└─────────────────────────────────────────────────────────┘

```

4.4 关键组件解析

**① 卷积层(Conv2d)**

卷积操作通过滑动窗口提取局部特征。浅层提取边缘、角点等低级特征,深层提取数字形状等高级特征。

```python

nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)

padding=1 保持输入输出尺寸一致

```

**② 批归一化(BatchNorm)**

加速训练收敛,稳定梯度传播,有轻微正则化效果。

**③ Dropout**

训练时随机丢弃 50% 的神经元,防止过拟合。测试时自动关闭。

```python

self.dropout = nn.Dropout(0.5) # 训练时随机丢弃50%

```

**④ 池化层(MaxPool)**

将特征图尺寸减半,降低计算量,同时提取最显著特征。


五、数据处理与增强

5.1 MNIST 数据加载

```python

from torchvision import datasets, transforms

from torch.utils.data import DataLoader

def get_data_loaders(batch_size=64, data_dir="./data"):

训练集:加入数据增强

train_transform = transforms.Compose([

transforms.RandomRotation(10), # 随机旋转±10°

transforms.RandomAffine(0, translate=(0.1, 0.1)), # 随机平移

transforms.ToTensor(), # 转为Tensor

transforms.Normalize((0.1307,), (0.3081,)) # 标准化

])

测试集:只做基础预处理

test_transform = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize((0.1307,), (0.3081,))

])

train_dataset = datasets.MNIST(

root=data_dir, train=True, download=True, transform=train_transform

)

test_dataset = datasets.MNIST(

root=data_dir, train=False, download=True, transform=test_transform

)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

return train_loader, test_loader

```

5.2 数据增强的意义

```

原始图片 随机旋转±10° 随机平移10%

┌───┐ ┌───┐ ┌───┐

│ 5 │ → │ 5 │ → │ 5 │

└───┘ └───┘ └───┘

微倾斜 位置偏移

```

数据增强让模型学会"容忍"手写数字的自然变化,显著提升泛化能力。**注意:测试集不做增强**,确保评估的客观性。

5.3 标准化的参数来源

```python

MNIST 数据集的全局统计量

mean = 0.1307 # 像素均值

std = 0.3081 # 像素标准差

标准化公式: x_normalized = (x - mean) / std

```

这两个值是在整个 MNIST 训练集上计算得到的,标准化后数据分布更稳定,有利于模型收敛。


六、模型训练

6.1 训练脚本核心逻辑

```python

import torch

import torch.nn as nn

import torch.optim as optim

from torch.optim.lr_scheduler import CosineAnnealingLR

from tqdm import tqdm

def train(args):

设备配置

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

数据加载

train_loader, test_loader = get_data_loaders(batch_size=64)

模型初始化

model = DigitCNN().to(device)

损失函数:交叉熵(分类任务标配)

criterion = nn.CrossEntropyLoss()

优化器:Adam(自适应学习率)

optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

学习率调度:余弦退火

scheduler = CosineAnnealingLR(optimizer, T_max=15)

训练循环

best_acc = 0.0

for epoch in range(1, 16):

model.train()

running_loss = 0.0

correct = 0

total = 0

for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}"):

images, labels = images.to(device), labels.to(device)

前向传播

outputs = model(images)

loss = criterion(outputs, labels)

反向传播

optimizer.zero_grad()

loss.backward()

optimizer.step()

统计

running_loss += loss.item()

_, predicted = torch.max(outputs, 1)

total += labels.size(0)

correct += (predicted == labels).sum().item()

scheduler.step()

验证

test_acc = evaluate(model, test_loader, device)

print(f"Epoch {epoch}: Train Acc={100*correct/total:.2f}%, Test Acc={test_acc:.2f}%")

保存最佳模型

if test_acc > best_acc:

best_acc = test_acc

torch.save(model.state_dict(), "checkpoints/best_model.pth")

```

6.2 启动训练

```bash

使用默认参数训练(推荐)

python train.py --epochs 15 --visualize

自定义参数

python train.py \

--model digitcnn \

--batch_size 128 \

--epochs 20 \

--lr 0.001 \

--scheduler cosine

使用轻量级模型快速验证

python train.py --model simple --epochs 10

```

6.3 训练过程输出

```

============================================================

手写数字识别 - 模型训练

============================================================

设备: cuda

GPU: NVIDIA GeForce RTX 3060

显存: 12.0 GB

模型: digitcnn

批次大小: 64

训练轮次: 15

学习率: 0.001

============================================================

加载 MNIST 数据集...

训练集: 60000 张图片

测试集: 10000 张图片

开始训练...

──────────────────────────────────────────────────

Epoch [1/15] LR: 0.001000

──────────────────────────────────────────────────

Training: 100%|████████████████████████| 938/938 [00:28]

训练 - Loss: 0.2156, Acc: 93.42%

测试 - Loss: 0.0487, Acc: 98.56%

最佳模型已保存! 准确率: 98.56%

...

──────────────────────────────────────────────────

Epoch [15/15] LR: 0.000044

──────────────────────────────────────────────────

Training: 100%|████████████████████████| 938/938 [00:26]

训练 - Loss: 0.0234, Acc: 99.31%

测试 - Loss: 0.0198, Acc: 99.38%

最佳模型已保存! 准确率: 99.38%

============================================================

训练完成!

============================================================

⏱ 总耗时: 423.5 秒

最佳测试准确率: 99.38%

模型保存路径: checkpoints/best_model.pth

```

6.4 训练技巧详解

**① 优化器选择:Adam vs SGD**

| 优化器 | 特点 | 适用场景 |

|--------|------|---------|

| Adam | 自适应学习率,收敛快 | 大多数场景,快速验证 |

| SGD+Momentum | 需要调参,泛化更好 | 追求极致精度 |

**② 学习率调度:余弦退火**

```

学习率

0.001 ┤●

│ ╲

│ ╲

│ ╲

│ ╲

0.0001┤ ●

└──────────────→ Epoch

1 15

```

余弦退火让学习率从初始值平滑下降,避免训练后期的震荡。

**③ 权重衰减(Weight Decay)**

```python

optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

```

L2 正则化,惩罚过大的权重,防止过拟合。

**④ 早停策略(Early Stopping)**

```python

if patience_counter >= patience: # 连续5个epoch未提升

print("早停触发!")

break

```

当验证集准确率不再提升时提前终止训练,节省时间并防止过拟合。


七、模型预测与评估

7.1 单张图片预测

```python

import torch

from PIL import Image

from model import DigitCNN

def predict_digit(image_path, model_path="checkpoints/best_model.pth"):

加载模型

model = DigitCNN()

checkpoint = torch.load(model_path, map_location="cpu")

model.load_state_dict(checkpoint["model_state_dict"])

model.eval()

预处理图片

img = Image.open(image_path).convert("L") # 转灰度

img = img.resize((28, 28), Image.LANCZOS) # 缩放

img_array = np.array(img, dtype=np.float32)

反转颜色(白底黑字 → 黑底白字)

if img_array.mean() > 127:

img_array = 255 - img_array

归一化 + 标准化

img_array = img_array / 255.0

img_array = (img_array - 0.1307) / 0.3081

转为 Tensor

tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0)

预测

with torch.no_grad():

output = model(tensor)

probs = torch.softmax(output, dim=1)

predicted = torch.argmax(probs).item()

confidence = probs[0][predicted].item()

return predicted, confidence

使用示例

digit, conf = predict_digit("my_digit.png")

print(f"预测结果: {digit}, 置信度: {conf:.2%}")

```

**命令行预测:**

```bash

预测单张图片

python predict.py --image my_digit.png --visualize

从测试集随机预测

python predict.py --test_set --visualize

```

7.2 混淆矩阵分析

混淆矩阵是评估分类模型的重要工具,它展示了每个数字的预测情况:

```

预测 → 0 1 2 3 4 5 6 7 8 9

真实 ↓

0 [ 976 0 1 0 0 1 1 1 0 0 ]

1 [ 0 1132 1 1 0 0 0 1 0 0 ]

2 [ 1 1 1024 1 1 0 0 3 1 0 ]

3 [ 0 0 2 1001 0 3 0 1 2 1 ]

4 [ 0 0 1 0 975 0 2 0 1 3 ]

5 [ 2 0 0 4 1 884 2 1 0 0 ]

6 [ 3 2 0 0 2 2 949 0 0 0 ]

7 [ 0 3 4 1 0 0 0 1018 1 1 ]

8 [ 1 0 2 2 1 1 0 1 966 0 ]

9 [ 1 1 0 1 5 2 0 3 2 994 ]

```

**分析:**

  • 对角线上的数字表示正确预测的数量

  • 99.38% 的总体准确率

  • 数字 "5" 容易被误判为 "3"(手写体相似)

  • 数字 "4" 和 "9" 有时会混淆

7.3 错误样本分析

查看模型预测错误的样本,有助于理解模型的局限性:

```python

找出所有预测错误的样本

wrong_indices = []

for i, (pred, true) in enumerate(zip(all_preds, all_labels)):

if pred != true:

wrong_indices.append(i)

print(f"错误样本数: {len(wrong_indices)} / {len(all_labels)}")

print(f"准确率: {100 * (1 - len(wrong_indices)/len(all_labels)):.2f}%")

```


八、Web 可视化部署

8.1 Flask Web 应用

我们实现了一个在线手写识别演示,用户可以在画板上书写数字并实时获得识别结果。

```bash

启动 Web 服务

python web_app.py

```

访问 `http://localhost:5000` 即可体验。

8.2 前端核心实现

```html

<!-- Canvas 画板 -->

<canvas id="canvas" width="280" height="280"></canvas>

<script>

const canvas = document.getElementById('canvas');

const ctx = canvas.getContext('2d');

// 初始化黑色背景

ctx.fillStyle = '#000';

ctx.fillRect(0, 0, canvas.width, canvas.height);

// 白色画笔,线宽18

ctx.strokeStyle = '#fff';

ctx.lineWidth = 18;

ctx.lineCap = 'round';

// 绑定鼠标事件

canvas.addEventListener('mousedown', startDraw);

canvas.addEventListener('mousemove', draw);

canvas.addEventListener('mouseup', stopDraw);

</script>

```

8.3 后端预测接口

```python

@app.route('/predict', methods=['POST'])

def predict():

接收 base64 编码的图片

data = request.get_json()

image_data = data['image'].split(',')[1]

解码并预处理

image = Image.open(io.BytesIO(base64.b64decode(image_data)))

image = image.convert('L').resize((28, 28))

归一化 + 标准化

img_array = np.array(image, dtype=np.float32) / 255.0

img_array = (img_array - 0.1307) / 0.3081

tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0)

模型推理

with torch.no_grad():

output = model(tensor)

probs = torch.softmax(output, dim=1).numpy()[0]

return jsonify({

'prediction': int(np.argmax(probs)),

'confidence': float(np.max(probs)),

'probabilities': probs.tolist()

})

```

8.4 界面效果

Web 界面包含以下功能:

  • **画板区域**:支持鼠标和触摸输入

  • **识别按钮**:发送图片到后端进行推理

  • **清除按钮**:清空画板重新书写

  • **概率分布**:展示 10 个数字的预测概率条形图

  • **置信度显示**:高亮显示预测结果和置信度


九、训练结果与分析

9.1 最终性能指标

| 指标 | DigitCNN(本项目) | 简单 CNN | 全连接网络 |

|------|-------------------|----------|-----------|

| **测试准确率** | **99.38%** | 98.92% | 97.65% |

| **参数量** | 1,199,874 | 89,690 | 203,776 |

| **训练时间 (15 epochs)** | ~7 分钟 | ~3 分钟 | ~2 分钟 |

| **模型大小** | 4.6 MB | 0.3 MB | 0.8 MB |

9.2 各数字识别准确率

```

数字 0: 99.80% ████████████████████▉

数字 1: 99.74% ████████████████████▉

数字 2: 99.22% ███████████████████▊

数字 3: 99.11% ███████████████████▊

数字 4: 99.29% ███████████████████▊

数字 5: 98.99% ███████████████████▋

数字 6: 99.06% ███████████████████▊

数字 7: 99.03% ███████████████████▊

数字 8: 99.18% ███████████████████▊

数字 9: 99.11% ███████████████████▊

```

9.3 收敛曲线分析

训练过程中的损失和准确率变化:

  • **Epoch 1-5**:快速收敛阶段,准确率从 93% 快速提升到 99%

  • **Epoch 6-10**:精细调整阶段,学习率下降,准确率缓慢提升

  • **Epoch 11-15**:趋于稳定,最终达到 99.38%


十、常见问题与优化建议

10.1 FAQ

**Q1:训练时 GPU 内存不足怎么办?**

```python

减小 batch_size

python train.py --batch_size 32

或使用混合精度训练

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

with autocast():

output = model(images)

loss = criterion(output, labels)

```

**Q2:如何提升准确率?**

```python

1. 增加数据增强强度

transforms.RandomRotation(15)

transforms.RandomAffine(0, translate=(0.15, 0.15))

2. 使用更深的网络

3. 使用集成学习(多个模型投票)

4. 增加训练轮次

python train.py --epochs 30

```

**Q3:如何用自己的手写图片测试?**

```bash

1. 用画图工具写一个数字

2. 保存为 PNG/JPG

3. 运行预测

python predict.py --image your_digit.png --visualize

```

**注意:** 图片要求:

  • 白底黑字或黑底白字均可(代码会自动处理)

  • 尽量居中书写

  • 数字占图片主要区域

**Q4:模型能否部署到移动端?**

```python

导出为 ONNX 格式

dummy_input = torch.randn(1, 1, 28, 28)

torch.onnx.export(model, dummy_input, "digit_model.onnx")

然后使用 ONNX Runtime 在移动端推理

import onnxruntime as ort

session = ort.InferenceSession("digit_model.onnx")

```

10.2 进阶优化方向

| 方向 | 方法 | 预期效果 |

|------|------|---------|

| 更深网络 | ResNet、DenseNet | 99.5%+ |

| 更强增强 | Cutout、Mixup、CutMix | +0.2% |

| 集成学习 | 多模型投票/融合 | +0.3% |

| 注意力机制 | SE Block、CBAM | +0.2% |

| 模型压缩 | 剪枝、量化、蒸馏 | 模型缩小 4x |

10.3 推荐学习路线

```

MNIST 手写数字识别 (本项目)

CIFAR-10 图像分类 (彩色图片、10类)

ImageNet 分类 (大规模、1000类)

目标检测 (YOLO、Faster R-CNN)

语义分割 (U-Net、DeepLab)

生成模型 (GAN、Diffusion)

```


十一、完整源码获取

11.1 项目文件说明

| 文件 | 功能 | 代码量 |

|------|------|--------|

| `model.py` | CNN 模型定义 | ~60 行 |

| `train.py` | 完整训练流程 | ~180 行 |

| `predict.py` | 单张/批量预测 | ~150 行 |

| `utils.py` | 数据加载、可视化 | ~160 行 |

| `web_app.py` | Flask Web 应用 | ~180 行 |

11.2 快速开始

```bash

1. 克隆项目

git clone https://github.com/your-username/digit_recognition.git

cd digit_recognition

2. 安装依赖

pip install -r requirements.txt

3. 训练模型

python train.py --epochs 15 --visualize

4. 预测

python predict.py --test_set --visualize

5. 启动 Web 演示

python web_app.py

<div align="center">PyTorch 实战:从零搭建手写数字识别系统(CNN 卷积神经网络)

PyTorch 实战:从零搭建手写数字识别系统(CNN 卷积神经网络)


代码链接:https://pan.quark.cn/s/ecc8ad46bfc8

目录

  • 一、项目简介\](#一项目简介)

  • 三、项目结构\](#三项目结构)

  • 五、数据处理与增强\](#五数据处理与增强)

  • 七、模型预测与评估\](#七模型预测与评估)

  • 九、训练结果与分析\](#九训练结果与分析)

  • 十一、完整源码获取\](#十一完整源码获取)

一、项目简介

1.1 什么是手写数字识别?

手写数字识别是计算机视觉领域最经典的入门问题之一。它的目标是让计算机能够自动识别用户手写的 0-9 这 10 个数字。这个看似简单的问题,实际上是 OCR(光学字符识别)技术的基础,在邮政编码识别、银行支票处理、车牌识别等场景中有着广泛应用。

1.2 为什么选择 MNIST?

**MNIST**(Modified National Institute of Standards and Technology)是机器学习领域的 "Hello World" 数据集:

| 属性 | 说明 |

|------|------|

| 数据量 | 70,000 张灰度图片(60,000 训练 + 10,000 测试) |

| 图片尺寸 | 28 × 28 像素 |

| 颜色通道 | 单通道(灰度图) |

| 分类数 | 10 类(数字 0-9) |

| 数据大小 | 约 12 MB |

1.3 本项目亮点

  • **三种网络架构**对比:全连接网络 vs 轻量级 CNN vs 深度 CNN

  • **数据增强**:随机旋转、平移,提升模型泛化能力

  • **训练可视化**:实时 loss/accuracy 曲线、混淆矩阵

  • **Web 在线演示**:Flask + Canvas 画板,即画即识别

  • **完整工程化**:模块化代码,注释详尽,开箱即用


二、环境配置

2.1 基础环境

```

操作系统:Windows 10/11、Ubuntu 20.04+、macOS

Python:3.8 或更高版本

CUDA:11.8+(如需 GPU 加速)

```

2.2 创建虚拟环境(推荐)

```bash

使用 conda 创建虚拟环境

conda create -n digit_recog python=3.10

conda activate digit_recog

或使用 venv

python -m venv venv

Windows

venv\Scripts\activate

Linux/Mac

source venv/bin/activate

```

2.3 安装依赖

```bash

安装 PyTorch(GPU 版本)

conda install pytorch torchvision pytorch-cuda=11.8 -c pytorch -c nvidia

或 CPU 版本

conda install pytorch torchvision cpuonly -c pytorch

安装其他依赖

pip install matplotlib numpy Pillow flask tqdm scikit-learn seaborn

```

**requirements.txt:**

```text

torch>=2.0.0

torchvision>=0.15.0

matplotlib>=3.7.0

numpy>=1.24.0

Pillow>=9.5.0

flask>=2.3.0

tqdm>=4.65.0

```

一键安装:

```bash

pip install -r requirements.txt

```

2.4 验证安装

```python

import torch

print(f"PyTorch 版本: {torch.version}")

print(f"CUDA 可用: {torch.cuda.is_available()}")

if torch.cuda.is_available():

print(f"GPU: {torch.cuda.get_device_name(0)}")

```

输出示例:

```

PyTorch 版本: 2.1.0

CUDA 可用: True

GPU: NVIDIA GeForce RTX 3060

```


三、项目结构

```

digit_recognition/

├── model.py # 模型定义(CNN 网络结构)

├── train.py # 训练脚本(完整训练流程)

├── predict.py # 预测脚本(单张/批量预测)

├── utils.py # 工具函数(数据加载、可视化)

├── web_app.py # Flask Web 应用(在线演示)

├── requirements.txt # 依赖列表

├── data/ # MNIST 数据集(自动下载)

├── checkpoints/ # 模型权重文件

│ └── best_model.pth # 最佳模型

└── results/ # 训练结果图片

├── training_history.png

├── confusion_matrix.png

└── prediction_result.png

```


四、模型架构详解

4.1 核心思路:卷积神经网络(CNN)

CNN 是处理图像任务的首选架构,它通过以下核心操作提取图像特征:

```

输入图像 → [卷积提取特征] → [池化压缩] → [全连接分类] → 输出概率

```

**为什么 CNN 比全连接网络更适合图像?**

| 对比项 | 全连接网络 | 卷积神经网络 |

|--------|-----------|-------------|

| 参数量 | 784×256 + 256×10 = 203,776 | 约 30,000-200,000 |

| 空间信息 | 丢失 | 保留 |

| 平移不变性 | 无 | 有 |

| 图像任务效果 | 一般 | 优秀 |

4.2 模型代码实现

```python

import torch

import torch.nn as nn

import torch.nn.functional as F

class DigitCNN(nn.Module):

"""

三层卷积神经网络

结构: Conv→BN→ReLU→Pool × 2 + Conv→BN→ReLU + FC × 2

"""

def init(self):

super(DigitCNN, self).init()

第一个卷积块

self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)

输入: 1×28×28 → 输出: 32×28×28

self.bn1 = nn.BatchNorm2d(32)

self.pool1 = nn.MaxPool2d(2, 2)

输出: 32×14×14

第二个卷积块

self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)

输入: 32×14×14 → 输出: 64×14×14

self.bn2 = nn.BatchNorm2d(64)

self.pool2 = nn.MaxPool2d(2, 2)

输出: 64×7×7

第三个卷积块

self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)

输入: 64×7×7 → 输出: 128×7×7

self.bn3 = nn.BatchNorm2d(128)

全连接层

self.fc1 = nn.Linear(128 * 7 * 7, 256)

self.dropout = nn.Dropout(0.5)

self.fc2 = nn.Linear(256, 10)

def forward(self, x):

卷积块1: Conv → BN → ReLU → Pool

x = self.pool1(F.relu(self.bn1(self.conv1(x))))

卷积块2: Conv → BN → ReLU → Pool

x = self.pool2(F.relu(self.bn2(self.conv2(x))))

卷积块3: Conv → BN → ReLU(不池化,保留空间信息)

x = F.relu(self.bn3(self.conv3(x)))

展平: [batch, 128, 7, 7] → [batch, 128*7*7]

x = x.view(-1, 128 * 7 * 7)

全连接层 + Dropout

x = F.relu(self.fc1(x))

x = self.dropout(x)

x = self.fc2(x)

return x

```

4.3 模型结构图解

```

┌─────────────────────────────────────────────────────────┐

│ 输入: 1×28×28 │

├─────────────────────────────────────────────────────────┤

│ Conv2d(1→32, 3×3) + BN + ReLU + MaxPool(2×2) │

│ 输出: 32×14×14 │

├─────────────────────────────────────────────────────────┤

│ Conv2d(32→64, 3×3) + BN + ReLU + MaxPool(2×2) │

│ 输出: 64×7×7 │

├─────────────────────────────────────────────────────────┤

│ Conv2d(64→128, 3×3) + BN + ReLU │

│ 输出: 128×7×7 │

├─────────────────────────────────────────────────────────┤

│ Flatten → 6272 │

├─────────────────────────────────────────────────────────┤

│ Linear(6272→256) + ReLU + Dropout(0.5) │

├─────────────────────────────────────────────────────────┤

│ Linear(256→10) │

├─────────────────────────────────────────────────────────┤

│ 输出: 10(数字0-9概率) │

└─────────────────────────────────────────────────────────┘

```

4.4 关键组件解析

**① 卷积层(Conv2d)**

卷积操作通过滑动窗口提取局部特征。浅层提取边缘、角点等低级特征,深层提取数字形状等高级特征。

```python

nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, padding=1)

padding=1 保持输入输出尺寸一致

```

**② 批归一化(BatchNorm)**

加速训练收敛,稳定梯度传播,有轻微正则化效果。

**③ Dropout**

训练时随机丢弃 50% 的神经元,防止过拟合。测试时自动关闭。

```python

self.dropout = nn.Dropout(0.5) # 训练时随机丢弃50%

```

**④ 池化层(MaxPool)**

将特征图尺寸减半,降低计算量,同时提取最显著特征。


五、数据处理与增强

5.1 MNIST 数据加载

```python

from torchvision import datasets, transforms

from torch.utils.data import DataLoader

def get_data_loaders(batch_size=64, data_dir="./data"):

训练集:加入数据增强

train_transform = transforms.Compose([

transforms.RandomRotation(10), # 随机旋转±10°

transforms.RandomAffine(0, translate=(0.1, 0.1)), # 随机平移

transforms.ToTensor(), # 转为Tensor

transforms.Normalize((0.1307,), (0.3081,)) # 标准化

])

测试集:只做基础预处理

test_transform = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize((0.1307,), (0.3081,))

])

train_dataset = datasets.MNIST(

root=data_dir, train=True, download=True, transform=train_transform

)

test_dataset = datasets.MNIST(

root=data_dir, train=False, download=True, transform=test_transform

)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

return train_loader, test_loader

```

5.2 数据增强的意义

```

原始图片 随机旋转±10° 随机平移10%

┌───┐ ┌───┐ ┌───┐

│ 5 │ → │ 5 │ → │ 5 │

└───┘ └───┘ └───┘

微倾斜 位置偏移

```

数据增强让模型学会"容忍"手写数字的自然变化,显著提升泛化能力。**注意:测试集不做增强**,确保评估的客观性。

5.3 标准化的参数来源

```python

MNIST 数据集的全局统计量

mean = 0.1307 # 像素均值

std = 0.3081 # 像素标准差

标准化公式: x_normalized = (x - mean) / std

```

这两个值是在整个 MNIST 训练集上计算得到的,标准化后数据分布更稳定,有利于模型收敛。


六、模型训练

6.1 训练脚本核心逻辑

```python

import torch

import torch.nn as nn

import torch.optim as optim

from torch.optim.lr_scheduler import CosineAnnealingLR

from tqdm import tqdm

def train(args):

设备配置

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

数据加载

train_loader, test_loader = get_data_loaders(batch_size=64)

模型初始化

model = DigitCNN().to(device)

损失函数:交叉熵(分类任务标配)

criterion = nn.CrossEntropyLoss()

优化器:Adam(自适应学习率)

optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

学习率调度:余弦退火

scheduler = CosineAnnealingLR(optimizer, T_max=15)

训练循环

best_acc = 0.0

for epoch in range(1, 16):

model.train()

running_loss = 0.0

correct = 0

total = 0

for images, labels in tqdm(train_loader, desc=f"Epoch {epoch}"):

images, labels = images.to(device), labels.to(device)

前向传播

outputs = model(images)

loss = criterion(outputs, labels)

反向传播

optimizer.zero_grad()

loss.backward()

optimizer.step()

统计

running_loss += loss.item()

_, predicted = torch.max(outputs, 1)

total += labels.size(0)

correct += (predicted == labels).sum().item()

scheduler.step()

验证

test_acc = evaluate(model, test_loader, device)

print(f"Epoch {epoch}: Train Acc={100*correct/total:.2f}%, Test Acc={test_acc:.2f}%")

保存最佳模型

if test_acc > best_acc:

best_acc = test_acc

torch.save(model.state_dict(), "checkpoints/best_model.pth")

```

6.2 启动训练

```bash

使用默认参数训练(推荐)

python train.py --epochs 15 --visualize

自定义参数

python train.py \

--model digitcnn \

--batch_size 128 \

--epochs 20 \

--lr 0.001 \

--scheduler cosine

使用轻量级模型快速验证

python train.py --model simple --epochs 10

```

6.3 训练过程输出

```

============================================================

手写数字识别 - 模型训练

============================================================

设备: cuda

GPU: NVIDIA GeForce RTX 3060

显存: 12.0 GB

模型: digitcnn

批次大小: 64

训练轮次: 15

学习率: 0.001

============================================================

加载 MNIST 数据集...

训练集: 60000 张图片

测试集: 10000 张图片

开始训练...

──────────────────────────────────────────────────

Epoch [1/15] LR: 0.001000

──────────────────────────────────────────────────

Training: 100%|████████████████████████| 938/938 [00:28]

训练 - Loss: 0.2156, Acc: 93.42%

测试 - Loss: 0.0487, Acc: 98.56%

最佳模型已保存! 准确率: 98.56%

...

──────────────────────────────────────────────────

Epoch [15/15] LR: 0.000044

──────────────────────────────────────────────────

Training: 100%|████████████████████████| 938/938 [00:26]

训练 - Loss: 0.0234, Acc: 99.31%

测试 - Loss: 0.0198, Acc: 99.38%

最佳模型已保存! 准确率: 99.38%

============================================================

训练完成!

============================================================

⏱ 总耗时: 423.5 秒

最佳测试准确率: 99.38%

模型保存路径: checkpoints/best_model.pth

```

6.4 训练技巧详解

**① 优化器选择:Adam vs SGD**

| 优化器 | 特点 | 适用场景 |

|--------|------|---------|

| Adam | 自适应学习率,收敛快 | 大多数场景,快速验证 |

| SGD+Momentum | 需要调参,泛化更好 | 追求极致精度 |

**② 学习率调度:余弦退火**

```

学习率

0.001 ┤●

│ ╲

│ ╲

│ ╲

│ ╲

0.0001┤ ●

└──────────────→ Epoch

1 15

```

余弦退火让学习率从初始值平滑下降,避免训练后期的震荡。

**③ 权重衰减(Weight Decay)**

```python

optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)

```

L2 正则化,惩罚过大的权重,防止过拟合。

**④ 早停策略(Early Stopping)**

```python

if patience_counter >= patience: # 连续5个epoch未提升

print("早停触发!")

break

```

当验证集准确率不再提升时提前终止训练,节省时间并防止过拟合。


七、模型预测与评估

7.1 单张图片预测

```python

import torch

from PIL import Image

from model import DigitCNN

def predict_digit(image_path, model_path="checkpoints/best_model.pth"):

加载模型

model = DigitCNN()

checkpoint = torch.load(model_path, map_location="cpu")

model.load_state_dict(checkpoint["model_state_dict"])

model.eval()

预处理图片

img = Image.open(image_path).convert("L") # 转灰度

img = img.resize((28, 28), Image.LANCZOS) # 缩放

img_array = np.array(img, dtype=np.float32)

反转颜色(白底黑字 → 黑底白字)

if img_array.mean() > 127:

img_array = 255 - img_array

归一化 + 标准化

img_array = img_array / 255.0

img_array = (img_array - 0.1307) / 0.3081

转为 Tensor

tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0)

预测

with torch.no_grad():

output = model(tensor)

probs = torch.softmax(output, dim=1)

predicted = torch.argmax(probs).item()

confidence = probs[0][predicted].item()

return predicted, confidence

使用示例

digit, conf = predict_digit("my_digit.png")

print(f"预测结果: {digit}, 置信度: {conf:.2%}")

```

**命令行预测:**

```bash

预测单张图片

python predict.py --image my_digit.png --visualize

从测试集随机预测

python predict.py --test_set --visualize

```

7.2 混淆矩阵分析

混淆矩阵是评估分类模型的重要工具,它展示了每个数字的预测情况:

```

预测 → 0 1 2 3 4 5 6 7 8 9

真实 ↓

0 [ 976 0 1 0 0 1 1 1 0 0 ]

1 [ 0 1132 1 1 0 0 0 1 0 0 ]

2 [ 1 1 1024 1 1 0 0 3 1 0 ]

3 [ 0 0 2 1001 0 3 0 1 2 1 ]

4 [ 0 0 1 0 975 0 2 0 1 3 ]

5 [ 2 0 0 4 1 884 2 1 0 0 ]

6 [ 3 2 0 0 2 2 949 0 0 0 ]

7 [ 0 3 4 1 0 0 0 1018 1 1 ]

8 [ 1 0 2 2 1 1 0 1 966 0 ]

9 [ 1 1 0 1 5 2 0 3 2 994 ]

```

**分析:**

  • 对角线上的数字表示正确预测的数量

  • 99.38% 的总体准确率

  • 数字 "5" 容易被误判为 "3"(手写体相似)

  • 数字 "4" 和 "9" 有时会混淆

7.3 错误样本分析

查看模型预测错误的样本,有助于理解模型的局限性:

```python

找出所有预测错误的样本

wrong_indices = []

for i, (pred, true) in enumerate(zip(all_preds, all_labels)):

if pred != true:

wrong_indices.append(i)

print(f"错误样本数: {len(wrong_indices)} / {len(all_labels)}")

print(f"准确率: {100 * (1 - len(wrong_indices)/len(all_labels)):.2f}%")

```


八、Web 可视化部署

8.1 Flask Web 应用

我们实现了一个在线手写识别演示,用户可以在画板上书写数字并实时获得识别结果。

```bash

启动 Web 服务

python web_app.py

```

访问 `http://localhost:5000` 即可体验。

8.2 前端核心实现

```html

<!-- Canvas 画板 -->

<canvas id="canvas" width="280" height="280"></canvas>

<script>

const canvas = document.getElementById('canvas');

const ctx = canvas.getContext('2d');

// 初始化黑色背景

ctx.fillStyle = '#000';

ctx.fillRect(0, 0, canvas.width, canvas.height);

// 白色画笔,线宽18

ctx.strokeStyle = '#fff';

ctx.lineWidth = 18;

ctx.lineCap = 'round';

// 绑定鼠标事件

canvas.addEventListener('mousedown', startDraw);

canvas.addEventListener('mousemove', draw);

canvas.addEventListener('mouseup', stopDraw);

</script>

```

8.3 后端预测接口

```python

@app.route('/predict', methods=['POST'])

def predict():

接收 base64 编码的图片

data = request.get_json()

image_data = data['image'].split(',')[1]

解码并预处理

image = Image.open(io.BytesIO(base64.b64decode(image_data)))

image = image.convert('L').resize((28, 28))

归一化 + 标准化

img_array = np.array(image, dtype=np.float32) / 255.0

img_array = (img_array - 0.1307) / 0.3081

tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0)

模型推理

with torch.no_grad():

output = model(tensor)

probs = torch.softmax(output, dim=1).numpy()[0]

return jsonify({

'prediction': int(np.argmax(probs)),

'confidence': float(np.max(probs)),

'probabilities': probs.tolist()

})

```

8.4 界面效果

Web 界面包含以下功能:

  • **画板区域**:支持鼠标和触摸输入

  • **识别按钮**:发送图片到后端进行推理

  • **清除按钮**:清空画板重新书写

  • **概率分布**:展示 10 个数字的预测概率条形图

  • **置信度显示**:高亮显示预测结果和置信度


九、训练结果与分析

9.1 最终性能指标

| 指标 | DigitCNN(本项目) | 简单 CNN | 全连接网络 |

|------|-------------------|----------|-----------|

| **测试准确率** | **99.38%** | 98.92% | 97.65% |

| **参数量** | 1,199,874 | 89,690 | 203,776 |

| **训练时间 (15 epochs)** | ~7 分钟 | ~3 分钟 | ~2 分钟 |

| **模型大小** | 4.6 MB | 0.3 MB | 0.8 MB |

9.2 各数字识别准确率

```

数字 0: 99.80% ████████████████████▉

数字 1: 99.74% ████████████████████▉

数字 2: 99.22% ███████████████████▊

数字 3: 99.11% ███████████████████▊

数字 4: 99.29% ███████████████████▊

数字 5: 98.99% ███████████████████▋

数字 6: 99.06% ███████████████████▊

数字 7: 99.03% ███████████████████▊

数字 8: 99.18% ███████████████████▊

数字 9: 99.11% ███████████████████▊

```

9.3 收敛曲线分析

训练过程中的损失和准确率变化:

  • **Epoch 1-5**:快速收敛阶段,准确率从 93% 快速提升到 99%

  • **Epoch 6-10**:精细调整阶段,学习率下降,准确率缓慢提升

  • **Epoch 11-15**:趋于稳定,最终达到 99.38%


十、常见问题与优化建议

10.1 FAQ

**Q1:训练时 GPU 内存不足怎么办?**

```python

减小 batch_size

python train.py --batch_size 32

或使用混合精度训练

from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

with autocast():

output = model(images)

loss = criterion(output, labels)

```

**Q2:如何提升准确率?**

```python

1. 增加数据增强强度

transforms.RandomRotation(15)

transforms.RandomAffine(0, translate=(0.15, 0.15))

2. 使用更深的网络

3. 使用集成学习(多个模型投票)

4. 增加训练轮次

python train.py --epochs 30

```

**Q3:如何用自己的手写图片测试?**

```bash

1. 用画图工具写一个数字

2. 保存为 PNG/JPG

3. 运行预测

python predict.py --image your_digit.png --visualize

```

**注意:** 图片要求:

  • 白底黑字或黑底白字均可(代码会自动处理)

  • 尽量居中书写

  • 数字占图片主要区域

**Q4:模型能否部署到移动端?**

```python

导出为 ONNX 格式

dummy_input = torch.randn(1, 1, 28, 28)

torch.onnx.export(model, dummy_input, "digit_model.onnx")

然后使用 ONNX Runtime 在移动端推理

import onnxruntime as ort

session = ort.InferenceSession("digit_model.onnx")

```

10.2 进阶优化方向

| 方向 | 方法 | 预期效果 |

|------|------|---------|

| 更深网络 | ResNet、DenseNet | 99.5%+ |

| 更强增强 | Cutout、Mixup、CutMix | +0.2% |

| 集成学习 | 多模型投票/融合 | +0.3% |

| 注意力机制 | SE Block、CBAM | +0.2% |

| 模型压缩 | 剪枝、量化、蒸馏 | 模型缩小 4x |

10.3 推荐学习路线

```

MNIST 手写数字识别 (本项目)

CIFAR-10 图像分类 (彩色图片、10类)

ImageNet 分类 (大规模、1000类)

目标检测 (YOLO、Faster R-CNN)

语义分割 (U-Net、DeepLab)

生成模型 (GAN、Diffusion)

```


十一、完整源码获取

11.1 项目文件说明

| 文件 | 功能 | 代码量 |

|------|------|--------|

| `model.py` | CNN 模型定义 | ~60 行 |

| `train.py` | 完整训练流程 | ~180 行 |

| `predict.py` | 单张/批量预测 | ~150 行 |

| `utils.py` | 数据加载、可视化 | ~160 行 |

| `web_app.py` | Flask Web 应用 | ~180 行 |

11.2 快速开始

```bash

1. 克隆项目

git clone https://github.com/your-username/digit_recognition.git

cd digit_recognition

2. 安装依赖

pip install -r requirements.txt

3. 训练模型

python train.py --epochs 15 --visualize

4. 预测

python predict.py --test_set --visualize

5. 启动 Web 演示

python web_app.py

```


<div align="center">

写在最后

这个项目虽然简单,但涵盖了深度学习的核心流程:

**数据处理 → 模型设计 → 训练优化 → 评估分析 → 部署应用**

掌握了这些,你就具备了处理更复杂视觉任务的基础。

**如果这篇文章对你有帮助,欢迎点赞 收藏 ⭐ 评论 支持一下!**

</div>

```


<div align="center">

写在最后

这个项目虽然简单,但涵盖了深度学习的核心流程:

**数据处理 → 模型设计 → 训练优化 → 评估分析 → 部署应用**

掌握了这些,你就具备了处理更复杂视觉任务的基础。

**如果这篇文章对你有帮助,欢迎点赞 收藏 ⭐ 评论 支持一下!**

</div>

相关推荐
Night_Elf2 小时前
AES-256加密+本地存储:国内本地密码管理器如何使用
人工智能·自动化
秋92 小时前
window中部署小龙虾OpenClaw
人工智能
星辰AI2 小时前
Stable Diffusion 实战教程:从安装到图像生成
人工智能·ai·语言模型
用户5191495848452 小时前
WordPress WPMasterToolkit 插件漏洞检测与利用工具
人工智能·aigc
AI医影跨模态组学2 小时前
Radiol Artif Intell 中山大学肿瘤防治中心放疗科:基于连续MRI的深度学习模型预测局部晚期鼻咽癌患者生存期
人工智能·深度学习·论文·医学·医学影像·影像组学
金智维科技官方2 小时前
圆桌对话:从流程自动化到智能流程,AI落地的下一站在哪里?
大数据·人工智能·ai·自动化·智能体
码字小学妹2 小时前
Google I/O 2026:Gemini 3.5 Flash 发布 + 国内 API 接入教程
人工智能
yezannnnnn2 小时前
Claude code 5 小时额度卡住?多账户错峰激活让你一天平滑使用不断额
人工智能·claude·vibecoding
PILIPALAPENG2 小时前
第4周 Day 4:Agent 工作流模式——编排复杂流程
前端·人工智能·python