基于 PyTorch 的 MNIST数字图像数据集分类模型训练与评估的简单练习

首先,导入需要用到的包。

Python 复制代码
import numpy as np
import torch
from torch.utils.data import DataLoader
from torchvision.datasets import mnist
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as plt

然后构建MNIST数据集数据转换函数,将图像转换为Pytorch能处理的张量。

Python 复制代码
def data_transform(img):
    img= np.array(img, dtype="float") / 255
    img= (img- 0.5) / 0.5
    img= img.reshape((-1))
    img= torch.Tensor(img)
    return img

通过Python下载MNIST数据,构建训练集与测试集,此处"./data2"为数据的存放位置。

Python 复制代码
train_dataset = mnist.MNIST("./data2", train=True, transform=data_transform, download=True)
test_dataset = mnist.MNIST("./data2", train=False, transform=data_transform, download=True)

构建神经网络。由于所采用的MNIST数据集一张图像的大小为28*28,所以设置输入数据时设置28*28=784个输入值,一共有0-9十个数字,所以最终的输出为10个输出值。通过ReLU函数设置如下。

Python 复制代码
net = nn.Sequential(
    nn.Linear(784, 400),
    nn.ReLU(),
    nn.Linear(400, 200),
    nn.ReLU(),
    nn.Linear(200, 100),
    nn.ReLU(),
    nn.Linear(100, 10)
)

然后,构建损失函数与优化器。这里使用交叉熵设置损失函数。

Python 复制代码
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), 1e-1)

创建四个数组用于存放每次处理后的损失值与准确度,以便于图像显示。

Python 复制代码
losses = []
acces = []
eval_losses = []
eval_acces = []

开始训练模型并测试。设置训练准确度与训练损失参数,通过循环遍历每一批数据。在处理一批数据时,首先将图像与标签数据类型转换为张量,然后通过建立的神经网络训练数据并通过损失函数获取损失。接着,将参数的梯度归零,对损失求导并更新参数。其次,将该批次的损失汇总并计算准确度。最后,在完成内循环后将损失与准确度添加到相关数组中用于图像显示。而测试过程与训练过程类似,只是没有求梯度的过程。

Python 复制代码
for e in range(20):
    train_loss = 0
    train_acc = 0
    for im, label in train_dataset:
        im = Variable(im)
        label = Variable(label)
 
        out = net(im)
        loss = criterion(out, label)
 
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
 
        train_loss += loss
        _, pred = out.max(1)
        num_correct = (pred == label).sum()
        acc = num_correct / im.shape[0]
        train_acc += acc
 
    losses.append(train_loss / len(train_dataset))
    acces.append(train_acc / len(train_dataset))
 
    eval_loss = 0
    eval_acc = 0
    for im, label in test_dataset:
        im = Variable(im)
        label = Variable(label)
        out = net(im)
        loss = criterion(out, label)
        eval_loss += loss
        _, pred = out.max(1)
        num_correct = (pred == label).sum()
        acc = num_correct / im.shape[0]
        eval_acc += acc
    eval_losses.append(eval_loss / len(test_dataset))
    eval_acces.append(eval_acc / len(test_dataset))
print("epoch:{},Train Loss:{:.6f},Train acc:{:.6f},Eval Loss:{:.6f},Eval acc:{:.6f}".format(e + 1, train_loss / len(train_dataset),train_acc / len(train_dataset),eval_loss / len(test_dataset),eval_acc / len(test_dataset)))

接下来,将数组中的数据类型从张量转换为可以处理的numpy数据格式。

Python 复制代码
losses = [item.detach().numpy() for item in losses]
acces = [item.detach().numpy() for item in acces]
eval_acces = [item.detach().numpy() for item in eval_acces]
eval_losses = [item.detach().numpy() for item in eval_losses]

最后构建绘图函数并完成绘图。

Python 复制代码
def make_plt(title, list):
    plt.title(title)
    plt.plot(np.arange(len(list)), list)
    plt.show()
 
make_plt("train loss", losses)
make_plt("tain acc", acces)
make_plt("eval loss", eval_losses)
make_plt("eval acc", eval_acces)

最终得到以下四个图像与输出。

Console 复制代码
......
epoch:14,Train Loss:0.019388,Train acc:0.993687,Eval Loss:0.073528,Eval acc:0.980716
epoch:15,Train Loss:0.018121,Train acc:0.994170,Eval Loss:0.075069,Eval acc:0.979727
epoch:16,Train Loss:0.013503,Train acc:0.995436,Eval Loss:0.078196,Eval acc:0.980617
epoch:17,Train Loss:0.012675,Train acc:0.995919,Eval Loss:0.070875,Eval acc:0.981309
epoch:18,Train Loss:0.014213,Train acc:0.995769,Eval Loss:0.076365,Eval acc:0.981408
epoch:19,Train Loss:0.011078,Train acc:0.996335,Eval Loss:0.068176,Eval acc:0.983683
epoch:20,Train Loss:0.006785,Train acc:0.998118,Eval Loss:0.114188,Eval acc:0.974684
相关推荐
c7695 分钟前
【文献笔记】Automatic Chain of Thought Prompting in Large Language Models
人工智能·笔记·语言模型·论文笔记
Blossom.11833 分钟前
机器学习在智能供应链中的应用:需求预测与物流优化
人工智能·深度学习·神经网络·机器学习·计算机视觉·机器人·语音识别
Gyoku Mint40 分钟前
深度学习×第4卷:Pytorch实战——她第一次用张量去拟合你的轨迹
人工智能·pytorch·python·深度学习·神经网络·算法·聚类
zzywxc78742 分钟前
AI大模型的技术演进、流程重构、行业影响三个维度的系统性分析
人工智能·重构
点控云43 分钟前
智能私域运营中枢:从客户视角看 SCRM 的体验革新与价值重构
大数据·人工智能·科技·重构·外呼系统·呼叫中心
zhaoyi_he1 小时前
多模态大模型的技术应用与未来展望:重构AI交互范式的新引擎
人工智能·重构
葫三生2 小时前
如何评价《论三生原理》在科技界的地位?
人工智能·算法·机器学习·数学建模·量子计算
m0_751336393 小时前
突破性进展:超短等离子体脉冲实现单电子量子干涉,为飞行量子比特奠定基础
人工智能·深度学习·量子计算·材料科学·光子器件·光子学·无线电电子
美狐美颜sdk6 小时前
跨平台直播美颜SDK集成实录:Android/iOS如何适配贴纸功能
android·人工智能·ios·架构·音视频·美颜sdk·第三方美颜sdk
DeepSeek-大模型系统教程6 小时前
推荐 7 个本周 yyds 的 GitHub 项目。
人工智能·ai·语言模型·大模型·github·ai大模型·大模型学习