使用pytorch创建/训练/推理OCR模型

一、任务描述

从手写数字图像中自动识别出对应的数字(0-9)" 的问题,属于单标签图像分类任务(每张图像仅对应一个类别,即 0-9 中的一个数字)

1、任务的核心定义:输入与输出

  • **输入:**28×28 像素的灰度图像(像素值范围 0-255,0 代表黑色背景,255 代表白色前景),图像内容是人类手写的 0-9 中的某一个数字,例如:一张 28×28 的图像,像素分布呈现 "3" 的形状,就是模型的输入。
  • **输出:**一个 "类别标签",即从 10 个可能的类别(0、1、2、...、9)中选择一个,作为输入图像对应的数字,例如:输入 "3" 的图像,模型输出 "类别 3",即完成一次正确识别。
  • **目标:**让模型在 "未见的手写数字图像" 上,尽可能准确地输出正确类别(通常用 "准确率" 衡量,即正确识别的图像数 / 总图像数)

2、任务的核心挑战

  • **不同人书写习惯差异极大:**有人写的 "4" 带弯钩,有人写的 "7" 带横线,有人字体粗大,有人字体纤细;甚至同一个人不同时间写的同一数字,笔画粗细、倾斜角度也会不同。例如:同样是 "5",可能是 "直笔 5""圆笔 5",也可能是倾斜 10° 或 20° 的 "5"------ 模型需要忽略这些 "风格差异",抓住 "数字的本质特征"(如 "5 有一个上半圆 + 一个竖线")。
  • **图像噪声与干扰:**手写数字图像可能存在噪声,比如纸张上的污渍、书写时的断笔、扫描时的光线不均,这些都会影响像素分布。例如:一张 "0" 的图像,边缘有一小块污渍,模型需要判断 "这是噪声" 而不是 "0 的一部分",避免误判为 "6" 或 "8"。

二、模型训练

1、MNIST数据集

MNIST(Modified National Institute of Standards and Technology database)是由美国国家标准与技术研究院(NIST)整理的手写数字数据集,后经修改(调整图像大小、居中对齐)成为机器学习领域的 "基准数据集",MNIST手写数字识别的核心是 "让计算机从标准化的手写数字灰度图中,自动识别出对应的 0-9 数字",它看似基础,却浓缩了图像分类的核心挑战(风格多样性、噪声鲁棒性、特征自动提取),同时是实际 OCR 场景的技术基础和机器学习入门的经典案例。

  • **数据量适中:**包含 70000 张图像,其中 60000 张用于训练(让模型学习特征),10000 张用于测试(验证模型泛化能力);
  • **图像规格统一:**所有图像都是 28×28 灰度图,无需复杂的预处理(如尺寸缩放、颜色通道处理),降低入门门槛;
  • **标注准确:**每张图像都有明确的 "正确数字标签"(人工标注),无需额外标注成本。

2、代码

  • **数据准备:**使用torchvision.datasets加载 MNIST 数据集,对数据进行转换(转为 Tensor 并标准化),使用DataLoader创建可迭代的数据加载器;

  • **模型定义:**定义了一个简单的两层神经网络SimpleNN,第一层将 28x28 的图像展平后映射到 128 维,第二层将 128 维特征映射到 10 个类别(对应数字 0-9);

  • **训练设置:**使用交叉熵损失函数(CrossEntropyLoss),使用 Adam 优化器,设置批量大小为64,训练轮次为5;

  • **训练过程:**循环多个训练轮次(epoch),每个轮次中迭代所有批次数据,执行前向传播、计算损失、反向传播和参数更新。

    -- coding: utf-8 --

    import torch
    import torch.nn as nn
    import torch.optim as optim
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    设置随机种子,确保结果可复现

    torch.manual_seed(42)

    1. 数据准备

    定义数据变换

    transform = transforms.Compose([
    transforms.ToTensor(), # 转换为Tensor
    transforms.Normalize((0.1307,), (0.3081,)) # 标准化,MNIST数据集的均值和标准差
    ])

    加载MNIST数据集

    train_dataset = datasets.MNIST(
    root='./data', # 数据保存路径
    train=True, # 训练集
    download=True, # 如果数据不存在则下载
    transform=transform
    )

    test_dataset = datasets.MNIST(
    root='./data',
    train=False, # 测试集
    download=True,
    transform=transform
    )

    创建数据加载器

    batch_size = 64
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    2. 定义模型

    class SimpleNN(nn.Module):
    def init(self):
    super(SimpleNN, self).init()
    # 输入层到隐藏层
    self.fc1 = nn.Linear(28 * 28, 128) # MNIST图像大小为28x28
    # 隐藏层到输出层
    self.fc2 = nn.Linear(128, 10) # 10个类别(0-9)

    复制代码
      def forward(self, x):
          # 将图像展平为一维向量
          x = x.view(-1, 28 * 28)
          # 隐藏层,使用ReLU激活函数
          x = torch.relu(self.fc1(x))
          # 输出层,不使用激活函数(因为后面会用CrossEntropyLoss)
          x = self.fc2(x)
          return x

    3. 初始化模型、损失函数和优化器

    model = SimpleNN()
    criterion = nn.CrossEntropyLoss() # 交叉熵损失,适用于分类问题
    optimizer = optim.Adam(model.parameters(), lr=0.001) # Adam优化器

    4. 训练模型

    def train(model, train_loader, criterion, optimizer, epochs=5):
    model.train() # 设置为训练模式
    train_losses = []

    复制代码
      for epoch in range(epochs):
          running_loss = 0.0
          for batch_idx, (data, target) in enumerate(train_loader):
              # 清零梯度
              optimizer.zero_grad()
    
              # 前向传播
              outputs = model(data)
              loss = criterion(outputs, target)
    
              # 反向传播和优化
              loss.backward()
              optimizer.step()
    
              running_loss += loss.item()
    
              # 每100个批次打印一次信息
              if batch_idx % 100 == 99:
                  print(
                      f'Epoch [{epoch + 1}/{epochs}], Batch [{batch_idx + 1}/{len(train_loader)}], Loss: {running_loss / 100:.4f}')
                  running_loss = 0.0
    
          train_losses.append(running_loss / len(train_loader))
    
      return train_losses

    6. 运行训练和测试

    if name == 'main':
    # 训练模型
    print("开始训练模型...")
    train_losses = train(model, train_loader, criterion, optimizer, epochs=5)
    print("模型训练完成...")
    # 保存模型
    torch.save(model.state_dict(), 'mnist_model.pth')
    print("模型已保存为 mnist_model.pth")

三、模型使用测试

复制代码
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
from torchvision import transforms  # 修正transforms的导入方式

# 定义与训练时相同的模型结构
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 10)
        
    def forward(self, x):
        x = x.view(-1, 28*28)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# 加载模型
def load_model(model_path='mnist_model.pth'):
    model = SimpleNN()
    # 加载模型时添加参数以避免潜在的Python 3兼容性问题
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'), weights_only=True))
    model.eval()  # 设置为评估模式
    return model

# 图像预处理(与训练时保持一致)
def preprocess_image(image_path):
    # 打开图像并转换为灰度图
    img = Image.open(image_path).convert('L')  # 'L'表示灰度模式
    # 调整大小为28x28
    img = img.resize((28, 28))
    # 转换为numpy数组并归一化
    img_array = np.array(img) / 255.0
    
    # 定义图像转换(使用torchvision的transforms)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    
    # 注意:这里需要先将numpy数组转换为PIL图像再应用transform
    img_pil = Image.fromarray((img_array * 255).astype(np.uint8))
    img_tensor = transform(img_pil).unsqueeze(0)  # 增加批次维度
    return img_tensor

# 预测函数
def predict_digit(model, image_path):
    # 预处理图像
    img_tensor = preprocess_image(image_path)
    
    # 预测
    with torch.no_grad():  # 不计算梯度
        outputs = model(img_tensor)
        _, predicted = torch.max(outputs.data, 1)
        return predicted.item()  # 返回预测的数字

# 示例使用
if __name__ == '__main__':
    # 加载模型
    model = load_model('mnist_model.pth')
    
    # 预测示例图像
    test_image_path = 'test_digit.png'  # 用户需要提供的测试图像路径
    try:
        predicted_digit = predict_digit(model, test_image_path)
        print(f"预测的数字是: {predicted_digit}")
    except Exception as e:
        print(f"预测出错: {str(e)}")

使用gpu0(第一块gpu)进行训练/推理:

torch.cuda.set_device(0)

model = model.cuda(0)

使用cpu记性训练/推理:

model = model.cpu()


怎么用pytorch训练一个模型-手写数字识别
手把手教你如何跑通一个手写中文汉字识别模型-OCR识别【pytorch】
手把手教你用PyTorch从零训练自己的大模型(非常详细)零基础入门到精通,收藏这一篇就够了
揭秘大模型的训练方法:使用PyTorch进行超大规模深度学习模型训练
全套解决方案:基于pytorch、transformers的中文NLP训练框架,支持大模型训练和文本生成,快速上手,海量训练数据!
用 pytorch 从零开始创建大语言模型(三):编码注意力机制

YOLOv5源码逐行超详细注释与解读(1)------项目目录结构解析

相关推荐
wan5555cn5 小时前
文字生视频的“精准”代码设定的核心原则本质是最小化文本语义与视频内容的KL散度
人工智能·笔记·深度学习·音视频
MediaTea5 小时前
Python 内置函数:pow()
开发语言·python
AndrewHZ5 小时前
【图像处理基石】图像预处理方面有哪些经典的算法?
图像处理·python·opencv·算法·计算机视觉·cv·图像预处理
IT_陈寒5 小时前
Python异步编程的7个致命误区:90%开发者踩过的坑及高效解决方案
前端·人工智能·后端
老猿讲编程5 小时前
存算一体:重构AI计算的革命性技术(1)
人工智能·重构
easy20206 小时前
从 Excel 趋势线到机器学习:拆解 AI 背后的核心框架
人工智能·笔记·机器学习
winfredzhang6 小时前
用Python打造逼真的照片桌面:从拖拽到交互的完整实现
python·拖拽·照片·桌面
天机️灵韵6 小时前
OpenAvatarChat项目在Windows本地运行指南
人工智能·开源项目·openavatarchat
DeeplyMind7 小时前
AMD KFD驱动技术分析16:SVM Aperture
人工智能·机器学习·amdgpu·rocm·kfd