pytorch学习3(pytorch手写数字识别练习)

网络模型

设置三层网络,一般最后一层激活函数不选择relu

任务步骤

复制代码
手写数字识别任务共有四个步骤:
1、数据加载--Load Data
2、构建网络--Build Model
3、训练--Train
4、测试--Test

实战

1、导入各种需要的包

python 复制代码
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

import torchvision

from matplotlib import pyplot as plt

from minist_utils import plot_image, plot_curve, one_hot ##自写文件

minist_utils:

2、加载数据

python 复制代码
batch_size = 512

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081, ))
                               ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081, ))
                               ])),
    batch_size=batch_size, shuffle=False

取一些样本看数据的shape以及图片内容

python 复制代码
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')

复制代码
注:经过load加载处理后的数据集包含x(图像信息)和y(标签信息)
next(iter())的用法是取一组样本,重复运行可以依次顺序取样,直到样本被取完
可在csdn自行搜索学习了解

3、网络构建

按之前设想的三层线性模型嵌套的思想搭建模型,为了模型简单,第三层不加激活函数。

python 复制代码
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        # xw+b
        self.fc1 = nn.Linear(28*28, 256) #输入特征数,输出特征数
        self.fc2 = nn.Linear(256, 64)  #256,64是根据经验判断
        self.fc3 = nn.Linear(64, 10)  #最开始的28*28和输出的10是一定的

    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1 + b1)
        x = F.relu(self.fc1(x)) #输入x后第一次线性模型得到H1作第二层输入
        # h2 = relu(h1w2 + b2)
        x = F.relu(self.fc2(x)) #输入H1得到H2作第三层输入
        # h3 = h2w3 + b3
        x = self.fc3(x)	#输入H3得到最终结果,维度为10

        return x

4、模型训练

python 复制代码
net = Net()

# [w1, b1, w2, b2, w3, b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []

for epoch in range(3):

    for batch_idx, (x, y) in enumerate(train_loader):

        # x: [b, 1, 28, 28], y: [512]
        # [b, 1, 28, 28] => [b, feature] 全连接层只能接受这样的数据
        x = x.view(x.size(0), 28*28)
        # => [b, 10]
        out = net(x)
        # [b, 10]
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward() # 梯度计算过程
        # w` = w - lr * grad
        optimizer.step() # 优化更新w,b

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)

5、测试

1、计算准确率acc

python 复制代码
total_correct = 0
for x, y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10] => pred: [b]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print(("acc:", acc))

2、展示部分测试样本原图以及预测标签结果

python 复制代码
x, y =next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')
相关推荐
云边有个稻草人几秒前
基于CANN ops-nn的AIGC神经网络算子优化与落地实践
人工智能·神经网络·aigc
chian-ocean2 分钟前
视觉新范式:基于 `ops-transformer` 的 Vision Transformer 高效部署
人工智能·深度学习·transformer
程序猿追5 分钟前
探索 CANN Graph 引擎的计算图编译优化策略:深度技术解读
人工智能·目标跟踪
哈__5 分钟前
CANN加速语音识别ASR推理:声学模型与语言模型融合优化
人工智能·语言模型·语音识别
慢半拍iii15 分钟前
CANN算子开发实战:手把手教你基于ops-nn仓库编写Broadcast广播算子
人工智能·计算机网络·ai
User_芊芊君子28 分钟前
CANN数学计算基石ops-math深度解析:高性能科学计算与AI模型加速的核心引擎
人工智能·深度学习·神经网络·ai
小白|31 分钟前
CANN与联邦学习融合:构建隐私安全的分布式AI推理与训练系统
人工智能·机器学习·自动驾驶
艾莉丝努力练剑39 分钟前
hixl vs NCCL:昇腾生态通信库的独特优势分析
运维·c++·人工智能·cann
梦帮科技40 分钟前
Node.js配置生成器CLI工具开发实战
前端·人工智能·windows·前端框架·node.js·json
程序员泠零澪回家种桔子41 分钟前
Spring AI框架全方位详解
java·人工智能·后端·spring·ai·架构