PyTorch手写体数字识别实例

MNIST数据集的准备

"HelloWorld"是所有编程语言入门的基础程序,在开始编程学习时,我们打印的第一句话通常就是这个"HelloWorld"。本书也不例外,在深度学习编程中也有其特有的"HelloWorld",一般就是采用MNIST完成一项特定的深度学习项目。

MNIST是一个手写数字图像数据库,如图2-21所示,它有60 000个训练样本集和10 000个测试样本集。读者可直接使用本书源码库提供的MNIST数据集,它位于配套源码的dataset文件夹中,如图2-22所示。

然后使用NumPy数据库进行数据的读取,代码如下:

import numpy as np

x_train = np.load("./dataset/mnist/x_train.npy")

y_train_label = np.load("./dataset/mnist/y_train_label.npy")

或者读者也可以在网上搜索MNIST的下载地址,下载MNIST文件中包含的数据集train-images-idx3-ubyte.gz(训练图片集)、train-labels-idx1-ubyte.gz(训练标签集)、t10k-images-idx3-ubyte.gz(测试图片集)和t10k-labels-idx1-ubyte.gz(测试标签集),如图2-23所示。

图2-23 MNIST文件中包含的数据集

将下载的4个文件进行解压缩。解压缩后,会发现这些文件并不是标准的图像格式,而是二进制文件,将文件保存到源码可以访问到的目录下。

基于PyTorch的手写体识别

下面我们开始基于PyTorch的手写体识别。通过2.3.4小节的介绍可知,我们还需要定义的一个内容就是深度学习的优化器部分,在这里采用Adam优化器,这部分代码如下:

model = NeuralNetwork()

optimizer = torch.optim.Adam(model.parameters(), lr=2e-5) #设定优化函数

完整的手写体识别首先需要定义模型,然后将模型参数传入优化器中。其中lr是对学习率的设定,根据设定的学习率进行模型计算。完整的手写体识别模型代码如下:

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #指定GPU编号
import torch
import numpy as np
from tqdm import tqdm

batch_size = 320            	#设定每次训练的批次数
epochs = 1024              	#设定训练次数

#device = "cpu"    	#PyTorch的特性,需要指定计算的硬件,如果没有GPU的存在,就使用CPU进行计算
device = "cuda"   	#这里默认使用GPU,如果出现运行问题,可以将其改成CPU模式


#设定的多层感知机网络模型
class NeuralNetwork(torch.nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = torch.nn.Flatten()
        self.linear_relu_stack = torch.nn.Sequential(
            torch.nn.Linear(28*28,312),
            torch.nn.ReLU(),
            torch.nn.Linear(312, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 10)
        )
    def forward(self, input):
        x = self.flatten(input)
        logits = self.linear_relu_stack(x)

        return logits

model = NeuralNetwork()
model = model.to(device)          	#将计算模型传入GPU硬件等待计算
#model = torch.compile(model)     	#PyTorch 2.0的特性,加速计算速度,选择性使用
loss_fu = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)   #设定优化函数

#载入数据
x_train = np.load("../../dataset/mnist/x_train.npy")
y_train_label = np.load("../../dataset/mnist/y_train_label.npy")

train_num = len(x_train)//batch_size

#开始计算
for epoch in range(20):
    train_loss = 0
    for i in range(train_num):
        start = i * batch_size
        end = (i + 1) * batch_size

        train_batch = torch.tensor(x_train[start:end]).to(device)
        label_batch = torch.tensor(y_train_label[start:end]).to(device)

        pred = model(train_batch)
        loss = loss_fu(pred,label_batch)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_loss += loss.item()  #记录每个批次的损失值

    #计算并打印损失值
    train_loss /= train_num
    accuracy = (pred.argmax(1) == label_batch).type(torch.float32).sum().item() / batch_size
    print("train_loss:", round(train_loss,2),"accuracy:",round(accuracy,2))

模型的训练结果如图3-5所示。

图3-5 训练结果

可以看到,随着模型循环次数的增加,模型的损失值在降低,而准确率在提高,具体请读者自行验证测试。

《PyTorch深度学习与计算机视觉实践(人工智能技术丛书)》(王晓华)【摘要 书评 试读】- 京东图书 (jd.com)

相关推荐
geovindu2 分钟前
python: SQLAlchemy (ORM) Simple example using mysql in Ubuntu 24.04
python·mysql·ubuntu
nuclear20114 分钟前
Python 将PPT幻灯片和形状转换为多种图片格式(JPG, PNG, BMP, SVG, TIFF)
python·ppt转图片·ppt转png·ppt转jpg·ppt转svg·ppt转tiff·ppt转bmp
小白狮ww15 分钟前
国产超强开源大语言模型 DeepSeek-R1-70B 一键部署教程
人工智能·深度学习·机器学习·语言模型·自然语言处理·开源·deepseek
风口猪炒股指标21 分钟前
想象一个AI保姆机器人使用场景分析
人工智能·机器人·deepseek·深度思考
没有晚不了安24 分钟前
1.13作业
开发语言·python
Blankspace空白34 分钟前
【小白学AI系列】NLP 核心知识点(八)多头自注意力机制
人工智能·自然语言处理
刀客12334 分钟前
python小项目编程-中级(1、图像处理)
开发语言·图像处理·python
Sodas(填坑中....)42 分钟前
SVM对偶问题
人工智能·机器学习·支持向量机·数据挖掘
信阳农夫1 小时前
python 3.6.8支持的Django版本是多少?
python·django·sqlite
forestsea1 小时前
DeepSeek 提示词:定义、作用、分类与设计原则
人工智能·prompt·deepseek