基于PyTorch的MNIST手写数字识别系统 - 从零到实战
前言
手写数字识别是深度学习入门最经典的案例之一。今天,我将带大家从零开始,使用PyTorch构建一个完整的MNIST手写数字识别系统。这个项目不仅包含基础的模型训练,还实现了对真实图片中多个数字的识别功能,非常适合初学者学习。
目录
项目概述
什么是MNIST?
MNIST(Modified National Institute of Standards and Technology)是一个包含60,000个训练样本和10,000个测试样本的手写数字数据集。每个样本都是28x28像素的灰度图像,包含0-9十个数字类别。
项目目标
本项目旨在实现:
- ✅ 使用卷积神经网络(CNN)识别手写数字
- ✅ 支持对真实拍摄的图片进行数字识别
- ✅ 能够识别一张图片中的多个数字
- ✅ 处理数字变形问题(如细长数字'1')
环境准备
必需库安装
bash
pip install torch torchvision opencv-python numpy matplotlib
验证安装
python
import torch
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")
模型架构设计
CNN模型定义
我们使用经典的卷积神经网络结构,定义在 model.py 中:
python
import torch
import torch.nn as nn
class HandWriteCNN(nn.Module):
def __init__(self):
super(HandWriteCNN, self).__init__()
# 第一层卷积:1通道→32通道
self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
self.relu = nn.ReLU()
self.pool = nn.MaxPool2d(2, 2)
# 第二层卷积:32通道→64通道
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
# 全连接层
self.fc1 = nn.Linear(64 * 7 * 7, 128)
self.fc2 = nn.Linear(128, 10) # 输出10个类别
def forward(self, x):
# 第一层卷积+池化:28×28 → 14×14
x = self.pool(self.relu(self.conv1(x)))
# 第二层卷积+池化:14×14 → 7×7
x = self.pool(self.relu(self.conv2(x)))
# 展平:64×7×7 = 3136
x = x.view(-1, 64 * 7 * 7)
# 全连接层
x = self.relu(self.fc1(x))
x = self.fc2(x)
return x
架构说明
- 卷积层1: 提取底层特征(边缘、线条)
- 卷积层2: 提取更高层特征(形状、组合)
- 全连接层1: 特征融合
- 全连接层2: 输出10个类别的概率分布
为什么使用CNN?
- 卷积操作具有平移不变性
- 参数共享,减少参数量
- 能够有效提取图像的局部特征
数据集准备
PyTorch的torchvision库提供了便捷的MNIST数据加载:
python
import torchvision
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.ToTensor(), # 转为Tensor并归一化到[0,1]
transforms.Normalize((0.5,), (0.5,)) # 归一化到[-1,1]
])
trainset = torchvision.datasets.MNIST(
root='./data',
train=True,
download=True, # 自动下载
transform=transform
)
trainloader = torch.utils.data.DataLoader(
trainset,
batch_size=64,
shuffle=True
)
模型训练
完整的训练代码在 train.py 中:
python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from model import HandWriteCNN
def train_model():
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用设备: {device}")
# 数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 加载数据集
trainset = torchvision.datasets.MNIST(
root='./data', train=True, download=True, transform=transform
)
trainloader = torch.utils.data.DataLoader(
trainset, batch_size=64, shuffle=True
)
# 创建模型
model = HandWriteCNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
print("开始训练模型...")
epochs = 3
for epoch in range(epochs):
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
inputs, labels = inputs.to(device), labels.to(device)
# 前向传播
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
# 反向传播
loss.backward()
optimizer.step()
running_loss += loss.item()
print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(trainloader):.4f}")
# 保存模型
torch.save(model.state_dict(), 'mnist_cnn_model.pth')
print("模型已保存到 'mnist_cnn_model.pth'")
if __name__ == "__main__":
train_model()
训练要点
- 损失函数: CrossEntropyLoss(多分类问题常用)
- 优化器: Adam(自适应学习率)
- 学习率: 0.001(可调整)
- 批次大小: 64(根据GPU内存调整)
- 训练轮数: 3(可增加以提高准确率)
运行训练:
bash
python train.py
图像预处理技巧
问题:数字变形
直接使用cv2.resize()缩放图像会导致细长数字(如'1')被拉伸变形,影响识别准确率。
解决方案:保持长宽比缩放
在 utils.py 中实现的函数:
python
import cv2
import numpy as np
def resize_pad_maintain_aspect_ratio(image, target_size=28):
"""
保持长宽比将图像缩放到 target_size,并填充黑边。
解决数字 '1' 被拉伸变形的问题。
"""
h, w = image.shape
# 1. 计算缩放比例,让最长边缩放到 20 (留出边距)
scale = 20.0 / max(h, w)
new_h, new_w = int(h * scale), int(w * scale)
# 2. 缩放图像
resized_image = cv2.resize(
image, (new_w, new_h),
interpolation=cv2.INTER_AREA
)
# 3. 创建28x28的黑色画布
canvas = np.zeros((target_size, target_size), dtype=np.uint8)
# 4. 将缩放后的图像居中放置
top = (target_size - new_h) // 2
left = (target_size - new_w) // 2
canvas[top:top+new_h, left:left+new_w] = resized_image
return canvas
关键点
- 保持长宽比: 只缩放最长边到20像素
- 居中填充: 将缩放后的图像放在28×28画布中心
- 留白边距: 预留边距有助于模型识别
多数字识别实现
数字分割流程
完整代码在 main.py 中:
python
import cv2
import torch
import torchvision.transforms as transforms
from model import HandWriteCNN
from utils import resize_pad_maintain_aspect_ratio
def predict_two_digit_image(image_path):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 1. 加载模型
model = HandWriteCNN().to(device)
model.load_state_dict(
torch.load('mnist_cnn_model.pth', map_location=device)
)
model.eval()
# 2. 读取图片
img_original = cv2.imread(image_path)
gray = cv2.cvtColor(img_original, cv2.COLOR_BGR2GRAY)
# 3. 二值化
_, thresh = cv2.threshold(
gray, 0, 255,
cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
)
# 4. 轮廓检测(找到所有数字)
contours, _ = cv2.findContours(
thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
)
# 5. 筛选并排序轮廓
digit_rects = []
for cnt in contours:
x, y, w, h = cv2.boundingRect(cnt)
if w > 5 and h > 10: # 过滤噪声
digit_rects.append((x, y, w, h))
digit_rects.sort(key=lambda x: x[0]) # 按x坐标排序
# 6. 逐个识别
results = []
for x, y, w, h in digit_rects:
# 提取数字区域
pad = 10
roi = thresh[
max(0, y-pad):min(thresh.shape[0], y+h+pad),
max(0, x-pad):min(thresh.shape[1], x+w+pad)
]
# 预处理:保持长宽比缩放
roi_processed = resize_pad_maintain_aspect_ratio(roi)
# 转为Tensor
roi_tensor = transforms.ToTensor()(roi_processed)
roi_tensor = transforms.Normalize((0.5,), (0.5,))(roi_tensor)
roi_tensor = roi_tensor.unsqueeze(0).to(device)
# 预测
with torch.no_grad():
output = model(roi_tensor)
_, predicted = torch.max(output, 1)
digit = predicted.item()
results.append(str(digit))
final_result = "".join(results)
print(f"识别结果: {final_result}")
return final_result
识别步骤详解
- 灰度转换 :
COLOR_BGR2GRAY - 二值化: OTSU自适应阈值(自动确定最佳阈值)
- 轮廓检测 :
findContours找到所有数字边界 - 区域提取: 提取每个数字的矩形区域
- 预处理: 保持长宽比缩放
- 模型推理: 使用训练好的模型预测
- 结果拼接: 按从左到右顺序组合结果
实战测试
测试代码
python
if __name__ == "__main__":
test_images = ['Test1.jpg', 'Test2.png']
print("开始进行手写测试...")
for img_file in test_images:
predict_two_digit_image(img_file)
运行测试
bash
python main.py
预期输出
开始进行手写测试...
图片 Test1.jpg 识别结果: 23
图片 Test2.png 识别结果: 45
常见问题与优化
Q1: 识别准确率不高
解决方案:
- 增加训练轮数(epochs = 5 或 10)
- 调整学习率(尝试0.0001或0.01)
- 数据增强(旋转、缩放、平移)
- 调整模型架构(增加卷积层或全连接层)
Q2: 无法检测到数字
可能原因:
- 图片对比度太低
- 数字太小或太大
- 背景干扰
解决方案:
python
# 调整二值化阈值
_, thresh = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV)
# 调整轮廓筛选条件
if w > 10 and h > 15: # 根据实际情况调整
Q3: 数字顺序识别错误
解决方案:
python
# 如果数字是垂直排列,按y坐标排序
digit_rects.sort(key=lambda x: x[1]) # 按y坐标排序
# 如果数字是两行,需要更复杂的排序逻辑
digit_rects.sort(key=lambda x: (x[1]//30, x[0])) # 先按行,再按列
Q4: 处理速度慢
优化方案:
- 使用GPU加速(自动检测)
- 减少图片分辨率
- 批量处理多个数字
项目扩展建议
1. 添加数据增强
python
transform = transforms.Compose([
transforms.RandomRotation(10),
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
2. 实现验证集评估
python
testset = torchvision.datasets.MNIST(
root='./data', train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(testset, batch_size=64)
# 计算准确率
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'准确率: {100 * correct / total}%')
3. 支持手写中文数字
扩展数据集,训练识别"零一二三四五六七八九"
4. Web应用开发
使用Flask或FastAPI创建Web接口,支持上传图片识别
总结与展望
本项目实现了:
✅ 完整的深度学习流程 : 从数据加载到模型训练再到实际应用
✅ 实用的图像处理技巧 : 保持长宽比的预处理方法
✅ 多数字识别功能 : 自动分割和识别多个数字
✅ 可扩展的代码结构: 模块化设计,易于改进
学习收获:
- PyTorch基础: 模型定义、训练循环、保存加载
- CNN原理: 卷积、池化、全连接层的作用
- 图像处理: OpenCV的轮廓检测、二值化等操作
- 实际问题解决: 处理数字变形、多目标识别等
下一步学习方向:
- 更复杂的网络架构(ResNet、DenseNet)
- 迁移学习(使用预训练模型)
- 模型部署(ONNX、TensorRT)
- 其他计算机视觉任务(目标检测、图像分割)
结语
通过这个项目,我们不仅学会了如何使用PyTorch构建CNN模型,更重要的是理解了如何将模型应用到实际问题中。希望这篇文章对大家有帮助!
完整的项目代码已上传到GitHub,欢迎大家Star和Fork!
如果文章对你有帮助,别忘了点赞👍、收藏⭐、关注❤️!
参考资料
作者简介: 深度学习爱好者,专注于计算机视觉和机器学习应用。
联系方式: [你的联系方式]
本文由CSDN博主原创,转载请注明出处。