使用pytorch创建/训练/推理OCR模型

一、任务描述

从手写数字图像中自动识别出对应的数字(0-9)" 的问题,属于单标签图像分类任务(每张图像仅对应一个类别,即 0-9 中的一个数字)

1、任务的核心定义:输入与输出

  • **输入:**28×28 像素的灰度图像(像素值范围 0-255,0 代表黑色背景,255 代表白色前景),图像内容是人类手写的 0-9 中的某一个数字,例如:一张 28×28 的图像,像素分布呈现 "3" 的形状,就是模型的输入。
  • **输出:**一个 "类别标签",即从 10 个可能的类别(0、1、2、...、9)中选择一个,作为输入图像对应的数字,例如:输入 "3" 的图像,模型输出 "类别 3",即完成一次正确识别。
  • **目标:**让模型在 "未见的手写数字图像" 上,尽可能准确地输出正确类别(通常用 "准确率" 衡量,即正确识别的图像数 / 总图像数)

2、任务的核心挑战

  • **不同人书写习惯差异极大:**有人写的 "4" 带弯钩,有人写的 "7" 带横线,有人字体粗大,有人字体纤细;甚至同一个人不同时间写的同一数字,笔画粗细、倾斜角度也会不同。例如:同样是 "5",可能是 "直笔 5""圆笔 5",也可能是倾斜 10° 或 20° 的 "5"------ 模型需要忽略这些 "风格差异",抓住 "数字的本质特征"(如 "5 有一个上半圆 + 一个竖线")。
  • **图像噪声与干扰:**手写数字图像可能存在噪声,比如纸张上的污渍、书写时的断笔、扫描时的光线不均,这些都会影响像素分布。例如:一张 "0" 的图像,边缘有一小块污渍,模型需要判断 "这是噪声" 而不是 "0 的一部分",避免误判为 "6" 或 "8"。

二、模型训练

1、MNIST数据集

MNIST(Modified National Institute of Standards and Technology database)是由美国国家标准与技术研究院(NIST)整理的手写数字数据集,后经修改(调整图像大小、居中对齐)成为机器学习领域的 "基准数据集",MNIST手写数字识别的核心是 "让计算机从标准化的手写数字灰度图中,自动识别出对应的 0-9 数字",它看似基础,却浓缩了图像分类的核心挑战(风格多样性、噪声鲁棒性、特征自动提取),同时是实际 OCR 场景的技术基础和机器学习入门的经典案例。

  • **数据量适中:**包含 70000 张图像,其中 60000 张用于训练(让模型学习特征),10000 张用于测试(验证模型泛化能力);
  • **图像规格统一:**所有图像都是 28×28 灰度图,无需复杂的预处理(如尺寸缩放、颜色通道处理),降低入门门槛;
  • **标注准确:**每张图像都有明确的 "正确数字标签"(人工标注),无需额外标注成本。

2、代码

  • **数据准备:**使用torchvision.datasets加载 MNIST 数据集,对数据进行转换(转为 Tensor 并标准化),使用DataLoader创建可迭代的数据加载器;

  • **模型定义:**定义了一个简单的两层神经网络SimpleNN,第一层将 28x28 的图像展平后映射到 128 维,第二层将 128 维特征映射到 10 个类别(对应数字 0-9);

  • **训练设置:**使用交叉熵损失函数(CrossEntropyLoss),使用 Adam 优化器,设置批量大小为64,训练轮次为5;

  • **训练过程:**循环多个训练轮次(epoch),每个轮次中迭代所有批次数据,执行前向传播、计算损失、反向传播和参数更新。

    -- coding: utf-8 --

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    设置随机种子,确保结果可复现

    torch.manual_seed(42)

    1. 数据准备

    定义数据变换

    transform = transforms.Compose([
    transforms.ToTensor(), # 转换为Tensor
    transforms.Normalize((0.1307,), (0.3081,)) # 标准化,MNIST数据集的均值和标准差
    ])

    加载MNIST数据集

    train_dataset = datasets.MNIST(
    root='./data', # 数据保存路径
    train=True, # 训练集
    download=True, # 如果数据不存在则下载
    transform=transform
    )

    test_dataset = datasets.MNIST(
    root='./data',
    train=False, # 测试集
    download=True,
    transform=transform
    )

    创建数据加载器

    batch_size = 64
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    2. 定义模型

    class SimpleNN(nn.Module):
    def init(self):
    super(SimpleNN, self).init()
    # 输入层到隐藏层
    self.fc1 = nn.Linear(28 * 28, 128) # MNIST图像大小为28x28
    # 隐藏层到输出层
    self.fc2 = nn.Linear(128, 10) # 10个类别(0-9)

    复制代码
      def forward(self, x):
          # 将图像展平为一维向量
          x = x.view(-1, 28 * 28)
          # 隐藏层,使用ReLU激活函数
          x = torch.relu(self.fc1(x))
          # 输出层,不使用激活函数(因为后面会用CrossEntropyLoss)
          x = self.fc2(x)
          return x

    3. 初始化模型、损失函数和优化器

    model = SimpleNN()
    criterion = nn.CrossEntropyLoss() # 交叉熵损失,适用于分类问题
    optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam优化器

    4. 训练模型

    def train(model, train_loader, criterion, optimizer, epochs=5):
    model.train() # 设置为训练模式
    train_losses = []

    复制代码
      for epoch in range(epochs):
          running_loss = 0.0
          for batch_idx, (data, target) in enumerate(train_loader):
              # 清零梯度
              optimizer.zero_grad()
    
              # 前向传播
              outputs = model(data)
              loss = criterion(outputs, target)
    
              # 反向传播和优化
              loss.backward()
              optimizer.step()
    
              running_loss += loss.item()
    
              # 每100个批次打印一次信息
              if batch_idx % 100 == 99:
                  print(
                      f'Epoch [{epoch + 1}/{epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {running_loss / 100:.4f}')
                  running_loss = 0.0
    
          train_losses.append(running_loss / len(train_loader))
    
      return train_losses

    6. 运行训练和测试

    if name == 'main':
    # 训练模型
    print("开始训练模型...")
    train_losses = train(model, train_loader, criterion, optimizer, epochs=5)
    print("模型训练完成...")
    # 保存模型
    torch.save(model.state_dict(), 'mnist_model.pth')
    print("模型已保存为 mnist_model.pth")

三、模型使用测试

复制代码
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from torchvision import transforms  # 修正transforms的导入方式

# 定义与训练时相同的模型结构
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 加载模型
def load_model(model_path='mnist_model.pth'):
    model = SimpleNN()
    # 加载模型时添加参数以避免潜在的Python 3兼容性问题
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
    model.eval()  # 设置为评估模式
    return model

# 图像预处理(与训练时保持一致)
def preprocess_image(image_path):
    # 打开图像并转换为灰度图
    img = Image.open(image_path).convert('L')  # 'L'表示灰度模式
    # 调整大小为28x28
    img = img.resize((28, 28))
    # 转换为numpy数组并归一化
    img_array = np.array(img) / 255.0
    
    # 定义图像转换(使用torchvision的transforms)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    # 注意:这里需要先将numpy数组转换为PIL图像再应用transform
    img_pil = Image.fromarray((img_array * 255).astype(np.uint8))
    img_tensor = transform(img_pil).unsqueeze(0)  # 增加批次维度
    return img_tensor

# 预测函数
def predict_digit(model, image_path):
    # 预处理图像
    img_tensor = preprocess_image(image_path)
    
    # 预测
    with torch.no_grad():  # 不计算梯度
        outputs = model(img_tensor)
        _, predicted = torch.max(outputs.data, 1)
        return predicted.item()  # 返回预测的数字

# 示例使用
if __name__ == '__main__':
    # 加载模型
    model = load_model('mnist_model.pth')
    
    # 预测示例图像
    test_image_path = 'test_digit.png'  # 用户需要提供的测试图像路径
    try:
        predicted_digit = predict_digit(model, test_image_path)
        print(f"预测的数字是: {predicted_digit}")
    except Exception as e:
        print(f"预测出错: {str(e)}")

使用gpu0(第一块gpu)进行训练/推理:

torch.cuda.set_device(0)

model = model.cuda(0)

使用cpu记性训练/推理:

model = model.cpu()


怎么用pytorch训练一个模型-手写数字识别
手把手教你如何跑通一个手写中文汉字识别模型-OCR识别【pytorch】
手把手教你用PyTorch从零训练自己的大模型(非常详细)零基础入门到精通,收藏这一篇就够了
揭秘大模型的训练方法:使用PyTorch进行超大规模深度学习模型训练
全套解决方案:基于pytorch、transformers的中文NLP训练框架,支持大模型训练和文本生成,快速上手,海量训练数据!
用 pytorch 从零开始创建大语言模型(三):编码注意力机制

YOLOv5源码逐行超详细注释与解读(1)------项目目录结构解析

相关推荐
风象南20 分钟前
Claude Code这个隐藏技能,让我告别PPT焦虑
人工智能·后端
曲幽42 分钟前
FastAPI压力测试实战:Locust模拟真实用户并发及优化建议
python·fastapi·web·locust·asyncio·test·uvicorn·workers
Mintopia1 小时前
OpenClaw 对软件行业产生的影响
人工智能
陈广亮2 小时前
构建具有长期记忆的 AI Agent:从设计模式到生产实践
人工智能
会写代码的柯基犬2 小时前
DeepSeek vs Kimi vs Qwen —— AI 生成俄罗斯方块代码效果横评
人工智能·llm
Mintopia2 小时前
OpenClaw 是什么?为什么节后热度如此之高?
人工智能
爱可生开源社区2 小时前
DBA 的未来?八位行业先锋的年度圆桌讨论
人工智能·dba
叁两5 小时前
用opencode打造全自动公众号写作流水线,AI 代笔太香了!
前端·人工智能·agent
敏编程5 小时前
一天一个Python库:jsonschema - JSON 数据验证利器
python