基于PyTorch的MNIST手写数字识别系统 - 从零到实战

基于PyTorch的MNIST手写数字识别系统 - 从零到实战

前言

手写数字识别是深度学习入门最经典的案例之一。今天,我将带大家从零开始,使用PyTorch构建一个完整的MNIST手写数字识别系统。这个项目不仅包含基础的模型训练,还实现了对真实图片中多个数字的识别功能,非常适合初学者学习。

目录

  1. 项目概述
  2. 环境准备
  3. 模型架构设计
  4. 数据集准备
  5. 模型训练
  6. 图像预处理技巧
  7. 多数字识别实现
  8. 实战测试
  9. 常见问题与优化
  10. 总结与展望

项目概述

什么是MNIST?

MNIST(Modified National Institute of Standards and Technology)是一个包含60,000个训练样本和10,000个测试样本的手写数字数据集。每个样本都是28x28像素的灰度图像,包含0-9十个数字类别。

项目目标

本项目旨在实现:

  • ✅ 使用卷积神经网络(CNN)识别手写数字
  • ✅ 支持对真实拍摄的图片进行数字识别
  • ✅ 能够识别一张图片中的多个数字
  • ✅ 处理数字变形问题(如细长数字'1')

环境准备

必需库安装

bash 复制代码
pip install torch torchvision opencv-python numpy matplotlib

验证安装

python 复制代码
import torch
print(f"PyTorch版本: {torch.__version__}")
print(f"CUDA可用: {torch.cuda.is_available()}")

模型架构设计

CNN模型定义

我们使用经典的卷积神经网络结构,定义在 model.py 中:

python 复制代码
import torch
import torch.nn as nn

class HandWriteCNN(nn.Module):
    def __init__(self):
        super(HandWriteCNN, self).__init__()
        # 第一层卷积:1通道→32通道
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2) 
        
        # 第二层卷积:32通道→64通道
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        
        # 全连接层
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, 10)  # 输出10个类别

    def forward(self, x):
        # 第一层卷积+池化:28×28 → 14×14
        x = self.pool(self.relu(self.conv1(x)))
        # 第二层卷积+池化:14×14 → 7×7
        x = self.pool(self.relu(self.conv2(x)))
        # 展平:64×7×7 = 3136
        x = x.view(-1, 64 * 7 * 7)
        # 全连接层
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

架构说明

  1. 卷积层1: 提取底层特征(边缘、线条)
  2. 卷积层2: 提取更高层特征(形状、组合)
  3. 全连接层1: 特征融合
  4. 全连接层2: 输出10个类别的概率分布

为什么使用CNN?

  • 卷积操作具有平移不变性
  • 参数共享,减少参数量
  • 能够有效提取图像的局部特征

数据集准备

PyTorch的torchvision库提供了便捷的MNIST数据加载:

python 复制代码
import torchvision
import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.ToTensor(),  # 转为Tensor并归一化到[0,1]
    transforms.Normalize((0.5,), (0.5,))  # 归一化到[-1,1]
])

trainset = torchvision.datasets.MNIST(
    root='./data', 
    train=True, 
    download=True,  # 自动下载
    transform=transform
)

trainloader = torch.utils.data.DataLoader(
    trainset, 
    batch_size=64, 
    shuffle=True
)

模型训练

完整的训练代码在 train.py 中:

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from model import HandWriteCNN

def train_model():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"使用设备: {device}")

    # 数据预处理
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # 加载数据集
    trainset = torchvision.datasets.MNIST(
        root='./data', train=True, download=True, transform=transform
    )
    trainloader = torch.utils.data.DataLoader(
        trainset, batch_size=64, shuffle=True
    )

    # 创建模型
    model = HandWriteCNN().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    print("开始训练模型...")
    epochs = 3
    for epoch in range(epochs):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)

            # 前向传播
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            # 反向传播
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
        
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {running_loss / len(trainloader):.4f}")

    # 保存模型
    torch.save(model.state_dict(), 'mnist_cnn_model.pth')
    print("模型已保存到 'mnist_cnn_model.pth'")

if __name__ == "__main__":
    train_model()

训练要点

  • 损失函数: CrossEntropyLoss(多分类问题常用)
  • 优化器: Adam(自适应学习率)
  • 学习率: 0.001(可调整)
  • 批次大小: 64(根据GPU内存调整)
  • 训练轮数: 3(可增加以提高准确率)

运行训练:

bash 复制代码
python train.py

图像预处理技巧

问题:数字变形

直接使用cv2.resize()缩放图像会导致细长数字(如'1')被拉伸变形,影响识别准确率。

解决方案:保持长宽比缩放

utils.py 中实现的函数:

python 复制代码
import cv2
import numpy as np

def resize_pad_maintain_aspect_ratio(image, target_size=28):
    """
    保持长宽比将图像缩放到 target_size,并填充黑边。
    解决数字 '1' 被拉伸变形的问题。
    """
    h, w = image.shape
    
    # 1. 计算缩放比例,让最长边缩放到 20 (留出边距)
    scale = 20.0 / max(h, w)
    new_h, new_w = int(h * scale), int(w * scale)
    
    # 2. 缩放图像
    resized_image = cv2.resize(
        image, (new_w, new_h), 
        interpolation=cv2.INTER_AREA
    )
    
    # 3. 创建28x28的黑色画布
    canvas = np.zeros((target_size, target_size), dtype=np.uint8)
    
    # 4. 将缩放后的图像居中放置
    top = (target_size - new_h) // 2
    left = (target_size - new_w) // 2
    canvas[top:top+new_h, left:left+new_w] = resized_image
    
    return canvas

关键点

  1. 保持长宽比: 只缩放最长边到20像素
  2. 居中填充: 将缩放后的图像放在28×28画布中心
  3. 留白边距: 预留边距有助于模型识别

多数字识别实现

数字分割流程

完整代码在 main.py 中:

python 复制代码
import cv2
import torch
import torchvision.transforms as transforms
from model import HandWriteCNN
from utils import resize_pad_maintain_aspect_ratio

def predict_two_digit_image(image_path):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # 1. 加载模型
    model = HandWriteCNN().to(device)
    model.load_state_dict(
        torch.load('mnist_cnn_model.pth', map_location=device)
    )
    model.eval()
    
    # 2. 读取图片
    img_original = cv2.imread(image_path)
    gray = cv2.cvtColor(img_original, cv2.COLOR_BGR2GRAY)
    
    # 3. 二值化
    _, thresh = cv2.threshold(
        gray, 0, 255, 
        cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU
    )
    
    # 4. 轮廓检测(找到所有数字)
    contours, _ = cv2.findContours(
        thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    
    # 5. 筛选并排序轮廓
    digit_rects = []
    for cnt in contours:
        x, y, w, h = cv2.boundingRect(cnt)
        if w > 5 and h > 10:  # 过滤噪声
            digit_rects.append((x, y, w, h))
    
    digit_rects.sort(key=lambda x: x[0])  # 按x坐标排序
    
    # 6. 逐个识别
    results = []
    for x, y, w, h in digit_rects:
        # 提取数字区域
        pad = 10
        roi = thresh[
            max(0, y-pad):min(thresh.shape[0], y+h+pad), 
            max(0, x-pad):min(thresh.shape[1], x+w+pad)
        ]
        
        # 预处理:保持长宽比缩放
        roi_processed = resize_pad_maintain_aspect_ratio(roi)
        
        # 转为Tensor
        roi_tensor = transforms.ToTensor()(roi_processed)
        roi_tensor = transforms.Normalize((0.5,), (0.5,))(roi_tensor)
        roi_tensor = roi_tensor.unsqueeze(0).to(device)
        
        # 预测
        with torch.no_grad():
            output = model(roi_tensor)
            _, predicted = torch.max(output, 1)
            digit = predicted.item()
            results.append(str(digit))
    
    final_result = "".join(results)
    print(f"识别结果: {final_result}")
    return final_result

识别步骤详解

  1. 灰度转换 : COLOR_BGR2GRAY
  2. 二值化: OTSU自适应阈值(自动确定最佳阈值)
  3. 轮廓检测 : findContours找到所有数字边界
  4. 区域提取: 提取每个数字的矩形区域
  5. 预处理: 保持长宽比缩放
  6. 模型推理: 使用训练好的模型预测
  7. 结果拼接: 按从左到右顺序组合结果

实战测试

测试代码

python 复制代码
if __name__ == "__main__":
    test_images = ['Test1.jpg', 'Test2.png']
    
    print("开始进行手写测试...")
    for img_file in test_images:
        predict_two_digit_image(img_file)

运行测试

bash 复制代码
python main.py

预期输出

复制代码
开始进行手写测试...
图片 Test1.jpg 识别结果: 23
图片 Test2.png 识别结果: 45

常见问题与优化

Q1: 识别准确率不高

解决方案:

  1. 增加训练轮数(epochs = 5 或 10)
  2. 调整学习率(尝试0.0001或0.01)
  3. 数据增强(旋转、缩放、平移)
  4. 调整模型架构(增加卷积层或全连接层)

Q2: 无法检测到数字

可能原因:

  • 图片对比度太低
  • 数字太小或太大
  • 背景干扰

解决方案:

python 复制代码
# 调整二值化阈值
_, thresh = cv2.threshold(gray, 127, 255, cv2.THRESH_BINARY_INV)

# 调整轮廓筛选条件
if w > 10 and h > 15:  # 根据实际情况调整

Q3: 数字顺序识别错误

解决方案:

python 复制代码
# 如果数字是垂直排列,按y坐标排序
digit_rects.sort(key=lambda x: x[1])  # 按y坐标排序

# 如果数字是两行,需要更复杂的排序逻辑
digit_rects.sort(key=lambda x: (x[1]//30, x[0]))  # 先按行,再按列

Q4: 处理速度慢

优化方案:

  1. 使用GPU加速(自动检测)
  2. 减少图片分辨率
  3. 批量处理多个数字

项目扩展建议

1. 添加数据增强

python 复制代码
transform = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

2. 实现验证集评估

python 复制代码
testset = torchvision.datasets.MNIST(
    root='./data', train=False, download=True, transform=transform
)
testloader = torch.utils.data.DataLoader(testset, batch_size=64)

# 计算准确率
correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = model(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'准确率: {100 * correct / total}%')

3. 支持手写中文数字

扩展数据集,训练识别"零一二三四五六七八九"

4. Web应用开发

使用Flask或FastAPI创建Web接口,支持上传图片识别


总结与展望

本项目实现了:

完整的深度学习流程 : 从数据加载到模型训练再到实际应用

实用的图像处理技巧 : 保持长宽比的预处理方法

多数字识别功能 : 自动分割和识别多个数字

可扩展的代码结构: 模块化设计,易于改进

学习收获:

  1. PyTorch基础: 模型定义、训练循环、保存加载
  2. CNN原理: 卷积、池化、全连接层的作用
  3. 图像处理: OpenCV的轮廓检测、二值化等操作
  4. 实际问题解决: 处理数字变形、多目标识别等

下一步学习方向:

  • 更复杂的网络架构(ResNet、DenseNet)
  • 迁移学习(使用预训练模型)
  • 模型部署(ONNX、TensorRT)
  • 其他计算机视觉任务(目标检测、图像分割)

结语

通过这个项目,我们不仅学会了如何使用PyTorch构建CNN模型,更重要的是理解了如何将模型应用到实际问题中。希望这篇文章对大家有帮助!

完整的项目代码已上传到GitHub,欢迎大家Star和Fork!

如果文章对你有帮助,别忘了点赞👍、收藏⭐、关注❤️!


参考资料


作者简介: 深度学习爱好者,专注于计算机视觉和机器学习应用。

联系方式: [你的联系方式]


本文由CSDN博主原创,转载请注明出处。

相关推荐
ar01232 小时前
AR远程协助工具有哪些
人工智能·ar
冰西瓜6002 小时前
国科大高级人工智能期末复习(五)行为主义
人工智能
zch不会敲代码2 小时前
机器学习之线性回归简单案例(代码逐句解释)
人工智能·机器学习·线性回归
想搞艺术的程序员2 小时前
AI 编程 - 提示词技巧
人工智能·ai编程
Das12 小时前
【机器学习】06_集成学习
人工智能·机器学习·集成学习
one____dream2 小时前
【算法】大整数数组连续进位
python·算法
one____dream2 小时前
【算法】合并两个有序链表
数据结构·python·算法·链表
大江东去浪淘尽千古风流人物2 小时前
【Project Aria】Meta新一代的AR眼镜及其数据集
人工智能·嵌入式硬件·算法·性能优化·ar·dsp开发
Java后端的Ai之路2 小时前
【AI应用开发工程师】-分享Java 转 AI正确思路
java·开发语言·人工智能·java转行·程序员转型