PyTorch手写体数字识别实例

MNIST数据集的准备

"HelloWorld"是所有编程语言入门的基础程序,在开始编程学习时,我们打印的第一句话通常就是这个"HelloWorld"。本书也不例外,在深度学习编程中也有其特有的"HelloWorld",一般就是采用MNIST完成一项特定的深度学习项目。

MNIST是一个手写数字图像数据库,如图2-21所示,它有60 000个训练样本集和10 000个测试样本集。读者可直接使用本书源码库提供的MNIST数据集,它位于配套源码的dataset文件夹中,如图2-22所示。

然后使用NumPy数据库进行数据的读取,代码如下:

import numpy as np

x_train = np.load("./dataset/mnist/x_train.npy")

y_train_label = np.load("./dataset/mnist/y_train_label.npy")

或者读者也可以在网上搜索MNIST的下载地址,下载MNIST文件中包含的数据集train-images-idx3-ubyte.gz(训练图片集)、train-labels-idx1-ubyte.gz(训练标签集)、t10k-images-idx3-ubyte.gz(测试图片集)和t10k-labels-idx1-ubyte.gz(测试标签集),如图2-23所示。

图2-23 MNIST文件中包含的数据集

将下载的4个文件进行解压缩。解压缩后,会发现这些文件并不是标准的图像格式,而是二进制文件,将文件保存到源码可以访问到的目录下。

基于PyTorch的手写体识别

下面我们开始基于PyTorch的手写体识别。通过2.3.4小节的介绍可知,我们还需要定义的一个内容就是深度学习的优化器部分,在这里采用Adam优化器,这部分代码如下:

model = NeuralNetwork()

optimizer = torch.optim.Adam(model.parameters(), lr=2e-5) #设定优化函数

完整的手写体识别首先需要定义模型,然后将模型参数传入优化器中。其中lr是对学习率的设定,根据设定的学习率进行模型计算。完整的手写体识别模型代码如下:

复制代码
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #指定GPU编号
import torch
import numpy as np
from tqdm import tqdm

batch_size = 320            	#设定每次训练的批次数
epochs = 1024              	#设定训练次数

#device = "cpu"    	#PyTorch的特性,需要指定计算的硬件,如果没有GPU的存在,就使用CPU进行计算
device = "cuda"   	#这里默认使用GPU,如果出现运行问题,可以将其改成CPU模式


#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = torch.nn.Flatten()
        self.linear_relu_stack = torch.nn.Sequential(
            torch.nn.Linear(28*28,312),
            torch.nn.ReLU(),
            torch.nn.Linear(312, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 10)
        )
    def forward(self, input):
        x = self.flatten(input)
        logits = self.linear_relu_stack(x)

        return logits

model = NeuralNetwork()
model = model.to(device)          	#将计算模型传入GPU硬件等待计算
#model = torch.compile(model)     	#PyTorch 2.0的特性,加速计算速度,选择性使用
loss_fu = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数

#载入数据
x_train = np.load("../../dataset/mnist/x_train.npy")
y_train_label = np.load("../../dataset/mnist/y_train_label.npy")

train_num = len(x_train)//batch_size

#开始计算
for epoch in range(20):
    train_loss = 0
    for i in range(train_num):
        start = i * batch_size
        end = (i + 1) * batch_size

        train_batch = torch.tensor(x_train[start:end]).to(device)
        label_batch = torch.tensor(y_train_label[start:end]).to(device)

        pred = model(train_batch)
        loss = loss_fu(pred,label_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()  #记录每个批次的损失值

    #计算并打印损失值
    train_loss /= train_num
    accuracy = (pred.argmax(1) == label_batch).type(torch.float32).sum().item() / batch_size
    print("train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

模型的训练结果如图3-5所示。

图3-5 训练结果

可以看到,随着模型循环次数的增加,模型的损失值在降低,而准确率在提高,具体请读者自行验证测试。

《PyTorch深度学习与计算机视觉实践(人工智能技术丛书)》(王晓华)【摘要 书评 试读】- 京东图书 (jd.com)

相关推荐
创可贴治愈心灵2 分钟前
AI浪潮下C#就业前景剖析:深耕C#为主,按需选修Java与Python
java·人工智能·c#
子非鱼@Itfuture2 分钟前
端侧AI(On-Device AI / Edge AI)|边缘 AI|云端 AI 探索报告
人工智能·ai·agi·端侧ai
MageGojo5 分钟前
基于 API Zero 平台集成 TTS 语音合成服务的技术实践
python·语音合成·tts·restful api·api集成
愚公搬代码10 分钟前
【愚公系列】《移动端AI应用开发》014-DeepSeek API开发与集成(处理多轮对话与动态请求)
人工智能·中间件·架构
真上帝的左手13 分钟前
19. 大数据- BI - AI 应用1-融合场景解析
大数据·人工智能·ai·bi
wgc2k17 分钟前
Oops Framework-6-项目中如何使用AI的思路
人工智能·游戏·cocos2d
Jump 不二23 分钟前
Memory-os 7 层记忆架构深度解析:让 Hermes Agent 真正 “记住并使用“ 知识
人工智能·语言模型·系统架构
程序猿阿伟24 分钟前
《无需额外付费的OpenClaw Agent部署指南》
人工智能
DS随心转APP27 分钟前
AI导出鸭:AI 文档排版与一键导出实战指南
人工智能·ai·chatgpt·deepseek·ai导出鸭
geneculture29 分钟前
语(暨各级各类字组)对接外来的词和句以及本土的言和语:言和语的关系及双重形式化彻底解决问题
人工智能·语言学·融智学应用场景·哲学与科学统一性·融智时代(杂志)