使用pytorch创建/训练/推理OCR模型

一、任务描述

从手写数字图像中自动识别出对应的数字(0-9)" 的问题,属于单标签图像分类任务(每张图像仅对应一个类别,即 0-9 中的一个数字)

1、任务的核心定义:输入与输出

  • **输入:**28×28 像素的灰度图像(像素值范围 0-255,0 代表黑色背景,255 代表白色前景),图像内容是人类手写的 0-9 中的某一个数字,例如:一张 28×28 的图像,像素分布呈现 "3" 的形状,就是模型的输入。
  • **输出:**一个 "类别标签",即从 10 个可能的类别(0、1、2、...、9)中选择一个,作为输入图像对应的数字,例如:输入 "3" 的图像,模型输出 "类别 3",即完成一次正确识别。
  • **目标:**让模型在 "未见的手写数字图像" 上,尽可能准确地输出正确类别(通常用 "准确率" 衡量,即正确识别的图像数 / 总图像数)

2、任务的核心挑战

  • **不同人书写习惯差异极大:**有人写的 "4" 带弯钩,有人写的 "7" 带横线,有人字体粗大,有人字体纤细;甚至同一个人不同时间写的同一数字,笔画粗细、倾斜角度也会不同。例如:同样是 "5",可能是 "直笔 5""圆笔 5",也可能是倾斜 10° 或 20° 的 "5"------ 模型需要忽略这些 "风格差异",抓住 "数字的本质特征"(如 "5 有一个上半圆 + 一个竖线")。
  • **图像噪声与干扰:**手写数字图像可能存在噪声,比如纸张上的污渍、书写时的断笔、扫描时的光线不均,这些都会影响像素分布。例如:一张 "0" 的图像,边缘有一小块污渍,模型需要判断 "这是噪声" 而不是 "0 的一部分",避免误判为 "6" 或 "8"。

二、模型训练

1、MNIST数据集

MNIST(Modified National Institute of Standards and Technology database)是由美国国家标准与技术研究院(NIST)整理的手写数字数据集,后经修改(调整图像大小、居中对齐)成为机器学习领域的 "基准数据集",MNIST手写数字识别的核心是 "让计算机从标准化的手写数字灰度图中,自动识别出对应的 0-9 数字",它看似基础,却浓缩了图像分类的核心挑战(风格多样性、噪声鲁棒性、特征自动提取),同时是实际 OCR 场景的技术基础和机器学习入门的经典案例。

  • **数据量适中:**包含 70000 张图像,其中 60000 张用于训练(让模型学习特征),10000 张用于测试(验证模型泛化能力);
  • **图像规格统一:**所有图像都是 28×28 灰度图,无需复杂的预处理(如尺寸缩放、颜色通道处理),降低入门门槛;
  • **标注准确:**每张图像都有明确的 "正确数字标签"(人工标注),无需额外标注成本。

2、代码

  • **数据准备:**使用torchvision.datasets加载 MNIST 数据集,对数据进行转换(转为 Tensor 并标准化),使用DataLoader创建可迭代的数据加载器;

  • **模型定义:**定义了一个简单的两层神经网络SimpleNN,第一层将 28x28 的图像展平后映射到 128 维,第二层将 128 维特征映射到 10 个类别(对应数字 0-9);

  • **训练设置:**使用交叉熵损失函数(CrossEntropyLoss),使用 Adam 优化器,设置批量大小为64,训练轮次为5;

  • **训练过程:**循环多个训练轮次(epoch),每个轮次中迭代所有批次数据,执行前向传播、计算损失、反向传播和参数更新。

    -- coding: utf-8 --

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    设置随机种子,确保结果可复现

    torch.manual_seed(42)

    1. 数据准备

    定义数据变换

    transform = transforms.Compose([
    transforms.ToTensor(), # 转换为Tensor
    transforms.Normalize((0.1307,), (0.3081,)) # 标准化,MNIST数据集的均值和标准差
    ])

    加载MNIST数据集

    train_dataset = datasets.MNIST(
    root='./data', # 数据保存路径
    train=True, # 训练集
    download=True, # 如果数据不存在则下载
    transform=transform
    )

    test_dataset = datasets.MNIST(
    root='./data',
    train=False, # 测试集
    download=True,
    transform=transform
    )

    创建数据加载器

    batch_size = 64
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    2. 定义模型

    class SimpleNN(nn.Module):
    def init(self):
    super(SimpleNN, self).init()
    # 输入层到隐藏层
    self.fc1 = nn.Linear(28 * 28, 128) # MNIST图像大小为28x28
    # 隐藏层到输出层
    self.fc2 = nn.Linear(128, 10) # 10个类别(0-9)

    复制代码
      def forward(self, x):
          # 将图像展平为一维向量
          x = x.view(-1, 28 * 28)
          # 隐藏层,使用ReLU激活函数
          x = torch.relu(self.fc1(x))
          # 输出层,不使用激活函数(因为后面会用CrossEntropyLoss)
          x = self.fc2(x)
          return x

    3. 初始化模型、损失函数和优化器

    model = SimpleNN()
    criterion = nn.CrossEntropyLoss() # 交叉熵损失,适用于分类问题
    optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam优化器

    4. 训练模型

    def train(model, train_loader, criterion, optimizer, epochs=5):
    model.train() # 设置为训练模式
    train_losses = []

    复制代码
      for epoch in range(epochs):
          running_loss = 0.0
          for batch_idx, (data, target) in enumerate(train_loader):
              # 清零梯度
              optimizer.zero_grad()
    
              # 前向传播
              outputs = model(data)
              loss = criterion(outputs, target)
    
              # 反向传播和优化
              loss.backward()
              optimizer.step()
    
              running_loss += loss.item()
    
              # 每100个批次打印一次信息
              if batch_idx % 100 == 99:
                  print(
                      f'Epoch [{epoch + 1}/{epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {running_loss / 100:.4f}')
                  running_loss = 0.0
    
          train_losses.append(running_loss / len(train_loader))
    
      return train_losses

    6. 运行训练和测试

    if name == 'main':
    # 训练模型
    print("开始训练模型...")
    train_losses = train(model, train_loader, criterion, optimizer, epochs=5)
    print("模型训练完成...")
    # 保存模型
    torch.save(model.state_dict(), 'mnist_model.pth')
    print("模型已保存为 mnist_model.pth")

三、模型使用测试

复制代码
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from torchvision import transforms  # 修正transforms的导入方式

# 定义与训练时相同的模型结构
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 加载模型
def load_model(model_path='mnist_model.pth'):
    model = SimpleNN()
    # 加载模型时添加参数以避免潜在的Python 3兼容性问题
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
    model.eval()  # 设置为评估模式
    return model

# 图像预处理(与训练时保持一致)
def preprocess_image(image_path):
    # 打开图像并转换为灰度图
    img = Image.open(image_path).convert('L')  # 'L'表示灰度模式
    # 调整大小为28x28
    img = img.resize((28, 28))
    # 转换为numpy数组并归一化
    img_array = np.array(img) / 255.0
    
    # 定义图像转换(使用torchvision的transforms)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    # 注意:这里需要先将numpy数组转换为PIL图像再应用transform
    img_pil = Image.fromarray((img_array * 255).astype(np.uint8))
    img_tensor = transform(img_pil).unsqueeze(0)  # 增加批次维度
    return img_tensor

# 预测函数
def predict_digit(model, image_path):
    # 预处理图像
    img_tensor = preprocess_image(image_path)
    
    # 预测
    with torch.no_grad():  # 不计算梯度
        outputs = model(img_tensor)
        _, predicted = torch.max(outputs.data, 1)
        return predicted.item()  # 返回预测的数字

# 示例使用
if __name__ == '__main__':
    # 加载模型
    model = load_model('mnist_model.pth')
    
    # 预测示例图像
    test_image_path = 'test_digit.png'  # 用户需要提供的测试图像路径
    try:
        predicted_digit = predict_digit(model, test_image_path)
        print(f"预测的数字是: {predicted_digit}")
    except Exception as e:
        print(f"预测出错: {str(e)}")

使用gpu0(第一块gpu)进行训练/推理:

torch.cuda.set_device(0)

model = model.cuda(0)

使用cpu记性训练/推理:

model = model.cpu()


怎么用pytorch训练一个模型-手写数字识别
手把手教你如何跑通一个手写中文汉字识别模型-OCR识别【pytorch】
手把手教你用PyTorch从零训练自己的大模型(非常详细)零基础入门到精通,收藏这一篇就够了
揭秘大模型的训练方法:使用PyTorch进行超大规模深度学习模型训练
全套解决方案:基于pytorch、transformers的中文NLP训练框架,支持大模型训练和文本生成,快速上手,海量训练数据!
用 pytorch 从零开始创建大语言模型(三):编码注意力机制

YOLOv5源码逐行超详细注释与解读(1)------项目目录结构解析

相关推荐
北京耐用通信8 分钟前
协议不通,数据何通?耐达讯自动化Modbus TCP与Profibus网关技术破解建筑自动化最大瓶颈
网络·人工智能·网络协议·自动化·信息与通信
IT_陈寒11 分钟前
Redis 性能提升秘籍:这5个被低估的命令让你的QPS飙升200%
前端·人工智能·后端
CodeCraft Studio22 分钟前
借助Aspose.Email,使用 Python 将 EML 转换为 MHTML
linux·服务器·python·aspose·email·mhtml·eml
SkylerHu22 分钟前
MacOS 使用ssh2-python报错ImportError: dlopen ... _libssh2_channel_direct_tcpip_ex
python·macos·ssh2
victory043129 分钟前
从机器学习到RLHF的完整学科分支脉络与赛道分析
人工智能·机器学习
京东零售技术5 小时前
京东零售胡浩:智能供应链从运筹到大模型到超级智能体的演进
大数据·人工智能
榕壹云5 小时前
GEO正在通过大模型技术重构企业数字营销生态
人工智能·重构·geo
计算机软件程序设计8 小时前
基于Python的二手车价格数据分析与预测系统的设计与实现
开发语言·python·数据分析·预测系统
K姐研究社8 小时前
通义万相Wan2.5模型实测,可生成音画同步视频
人工智能·aigc·音视频
mortimer8 小时前
Traceback 模块:`format_exc` 和 `format_exception` 傻傻分不清
python