pytorch学习3(pytorch手写数字识别练习)

网络模型

设置三层网络,一般最后一层激活函数不选择relu

任务步骤

复制代码
手写数字识别任务共有四个步骤:
1、数据加载--Load Data
2、构建网络--Build Model
3、训练--Train
4、测试--Test

实战

1、导入各种需要的包

python 复制代码
import torch
from torch import nn
from torch.nn import functional as F
from torch import optim

import torchvision

from matplotlib import pyplot as plt

from minist_utils import plot_image, plot_curve, one_hot ##自写文件

minist_utils:

2、加载数据

python 复制代码
batch_size = 512

train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=True, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081, ))
                               ])),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST('mnist_data', train=False, download=True,
                               transform=torchvision.transforms.Compose([
                                   torchvision.transforms.ToTensor(),
                                   torchvision.transforms.Normalize(
                                       (0.1307,), (0.3081, ))
                               ])),
    batch_size=batch_size, shuffle=False

取一些样本看数据的shape以及图片内容

python 复制代码
x, y = next(iter(train_loader))
print(x.shape, y.shape, x.min(), x.max())
plot_image(x, y, 'image sample')

复制代码
注:经过load加载处理后的数据集包含x(图像信息)和y(标签信息)
next(iter())的用法是取一组样本,重复运行可以依次顺序取样,直到样本被取完
可在csdn自行搜索学习了解

3、网络构建

按之前设想的三层线性模型嵌套的思想搭建模型,为了模型简单,第三层不加激活函数。

python 复制代码
class Net(nn.Module):

    def __init__(self):
        super(Net, self).__init__()

        # xw+b
        self.fc1 = nn.Linear(28*28, 256) #输入特征数,输出特征数
        self.fc2 = nn.Linear(256, 64)  #256,64是根据经验判断
        self.fc3 = nn.Linear(64, 10)  #最开始的28*28和输出的10是一定的

    def forward(self, x):
        # x: [b, 1, 28, 28]
        # h1 = relu(xw1 + b1)
        x = F.relu(self.fc1(x)) #输入x后第一次线性模型得到H1作第二层输入
        # h2 = relu(h1w2 + b2)
        x = F.relu(self.fc2(x)) #输入H1得到H2作第三层输入
        # h3 = h2w3 + b3
        x = self.fc3(x)	#输入H3得到最终结果,维度为10

        return x

4、模型训练

python 复制代码
net = Net()

# [w1, b1, w2, b2, w3, b3]
optimizer = optim.SGD(net.parameters(), lr=0.01, momentum=0.9)

train_loss = []

for epoch in range(3):

    for batch_idx, (x, y) in enumerate(train_loader):

        # x: [b, 1, 28, 28], y: [512]
        # [b, 1, 28, 28] => [b, feature] 全连接层只能接受这样的数据
        x = x.view(x.size(0), 28*28)
        # => [b, 10]
        out = net(x)
        # [b, 10]
        y_onehot = one_hot(y)
        # loss = mse(out, y_onehot)
        loss = F.mse_loss(out, y_onehot)

        optimizer.zero_grad()
        loss.backward() # 梯度计算过程
        # w` = w - lr * grad
        optimizer.step() # 优化更新w,b

        train_loss.append(loss.item())

        if batch_idx % 10 == 0:
            print(epoch, batch_idx, loss.item())

plot_curve(train_loss)

5、测试

1、计算准确率acc

python 复制代码
total_correct = 0
for x, y in test_loader:
    x = x.view(x.size(0), 28*28)
    out = net(x)
    # out: [b, 10] => pred: [b]
    pred = out.argmax(dim=1)
    correct = pred.eq(y).sum().float().item()
    total_correct += correct

total_num = len(test_loader.dataset)
acc = total_correct / total_num
print(("acc:", acc))

2、展示部分测试样本原图以及预测标签结果

python 复制代码
x, y =next(iter(test_loader))
out = net(x.view(x.size(0), 28*28))
pred = out.argmax(dim=1)
plot_image(x, pred, 'test')
相关推荐
TL滕几秒前
从0开始学算法——第二十一天(复杂链表问题)
笔记·学习·算法
X_Eartha_8155 分钟前
前端学习—HTML基础语法(1)
前端·学习·html
不被AI替代的BOT9 分钟前
AgentScope深入分析-LLM&MCP
人工智能·后端
Jorunk12 分钟前
状态对齐是连接 GMM-HMM 和 DNN-HMM 的核心桥梁
人工智能·神经网络·dnn
秋深枫叶红13 分钟前
嵌入式第三十八篇——linux系统编程——IPC进程间通信
linux·服务器·网络·学习
YJlio15 分钟前
FindLinks 学习笔记(12.4):NTFS 硬链接扫描与文件“多重身份”排查
笔记·学习·intellij-idea
程序员大辉18 分钟前
新人学习Flutter,如何搭建开发环境(附所有安装包)
学习·flutter
袋鼠云数栈21 分钟前
媒体专访丨袋鼠云 CEO 宁海元:Agent元年之后,产业需回到“数据+智能”的长期结构
大数据·人工智能
Ahtacca26 分钟前
保姆级教程:Obsidian + PicGo + Gitee 搭建免费稳定的自动化图床
运维·笔记·学习·gitee·自动化
Wishell201528 分钟前
日拱一卒之Python与matlab的内存读取区别
pytorch