PyTorch手写体数字识别实例

MNIST数据集的准备

"HelloWorld"是所有编程语言入门的基础程序,在开始编程学习时,我们打印的第一句话通常就是这个"HelloWorld"。本书也不例外,在深度学习编程中也有其特有的"HelloWorld",一般就是采用MNIST完成一项特定的深度学习项目。

MNIST是一个手写数字图像数据库,如图2-21所示,它有60 000个训练样本集和10 000个测试样本集。读者可直接使用本书源码库提供的MNIST数据集,它位于配套源码的dataset文件夹中,如图2-22所示。

然后使用NumPy数据库进行数据的读取,代码如下:

import numpy as np

x_train = np.load("./dataset/mnist/x_train.npy")

y_train_label = np.load("./dataset/mnist/y_train_label.npy")

或者读者也可以在网上搜索MNIST的下载地址,下载MNIST文件中包含的数据集train-images-idx3-ubyte.gz(训练图片集)、train-labels-idx1-ubyte.gz(训练标签集)、t10k-images-idx3-ubyte.gz(测试图片集)和t10k-labels-idx1-ubyte.gz(测试标签集),如图2-23所示。

图2-23 MNIST文件中包含的数据集

将下载的4个文件进行解压缩。解压缩后,会发现这些文件并不是标准的图像格式,而是二进制文件,将文件保存到源码可以访问到的目录下。

基于PyTorch的手写体识别

下面我们开始基于PyTorch的手写体识别。通过2.3.4小节的介绍可知,我们还需要定义的一个内容就是深度学习的优化器部分,在这里采用Adam优化器,这部分代码如下:

model = NeuralNetwork()

optimizer = torch.optim.Adam(model.parameters(), lr=2e-5) #设定优化函数

完整的手写体识别首先需要定义模型,然后将模型参数传入优化器中。其中lr是对学习率的设定,根据设定的学习率进行模型计算。完整的手写体识别模型代码如下:

复制代码
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #指定GPU编号
import torch
import numpy as np
from tqdm import tqdm

batch_size = 320            	#设定每次训练的批次数
epochs = 1024              	#设定训练次数

#device = "cpu"    	#PyTorch的特性,需要指定计算的硬件,如果没有GPU的存在,就使用CPU进行计算
device = "cuda"   	#这里默认使用GPU,如果出现运行问题,可以将其改成CPU模式


#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = torch.nn.Flatten()
        self.linear_relu_stack = torch.nn.Sequential(
            torch.nn.Linear(28*28,312),
            torch.nn.ReLU(),
            torch.nn.Linear(312, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 10)
        )
    def forward(self, input):
        x = self.flatten(input)
        logits = self.linear_relu_stack(x)

        return logits

model = NeuralNetwork()
model = model.to(device)          	#将计算模型传入GPU硬件等待计算
#model = torch.compile(model)     	#PyTorch 2.0的特性,加速计算速度,选择性使用
loss_fu = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数

#载入数据
x_train = np.load("../../dataset/mnist/x_train.npy")
y_train_label = np.load("../../dataset/mnist/y_train_label.npy")

train_num = len(x_train)//batch_size

#开始计算
for epoch in range(20):
    train_loss = 0
    for i in range(train_num):
        start = i * batch_size
        end = (i + 1) * batch_size

        train_batch = torch.tensor(x_train[start:end]).to(device)
        label_batch = torch.tensor(y_train_label[start:end]).to(device)

        pred = model(train_batch)
        loss = loss_fu(pred,label_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()  #记录每个批次的损失值

    #计算并打印损失值
    train_loss /= train_num
    accuracy = (pred.argmax(1) == label_batch).type(torch.float32).sum().item() / batch_size
    print("train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

模型的训练结果如图3-5所示。

图3-5 训练结果

可以看到,随着模型循环次数的增加,模型的损失值在降低,而准确率在提高,具体请读者自行验证测试。

《PyTorch深度学习与计算机视觉实践(人工智能技术丛书)》(王晓华)【摘要 书评 试读】- 京东图书 (jd.com)

相关推荐
Hy行者勇哥25 分钟前
多源数据抽取与推送模块架构设计
人工智能·个人开发
星空的资源小屋31 分钟前
Text Grab,一款OCR 截图文字识别工具
python·django·ocr·scikit-learn
寒秋丶32 分钟前
Milvus:Json字段详解(十)
数据库·人工智能·python·ai·milvus·向量数据库·rag
长桥夜波1 小时前
机器学习日报07
人工智能·机器学习
长桥夜波1 小时前
机器学习日报11
人工智能·机器学习
一个处女座的程序猿4 小时前
LLMs之SLMs:《Small Language Models are the Future of Agentic AI》的翻译与解读
人工智能·自然语言处理·小语言模型·slms
自由随风飘4 小时前
python 题目练习1~5
开发语言·python
fl1768317 小时前
基于python的天气预报系统设计和可视化数据分析源码+报告
开发语言·python·数据分析
档案宝档案管理7 小时前
档案宝:企业合同档案管理的“安全保险箱”与“效率加速器”
大数据·数据库·人工智能·安全·档案·档案管理
闲人编程7 小时前
Python与区块链:如何用Web3.py与以太坊交互
python·安全·区块链·web3.py·以太坊·codecapsule