手写数字识别

python 复制代码
# 使用pytorch完成手写数字识别
# 准备dataloader
import os
import torch
from torch.utils.data import DataLoader
# 这里使用torch自带的MNIST数据集
from torchvision.datasets import MNIST
# 这里是为了将图片进行处理,调用这三个函数
from torchvision.transforms import Compose, ToTensor, Normalize
from torch import nn
import torch.nn.functional as F
from torch.optim import Adam
import numpy as np

BATCH_SIZE = 128
TEST_BATCH_SIZE = 1000


# 1. 准备数据集
def get_dataloader(train=True,batch_size=BATCH_SIZE):
    transform_fn = Compose([
        ToTensor(),     # 因为图片本身不是一个tensor,所以要to一下
        Normalize(mean=(0.1307,),std=(0.3081,))        # mean和std的形状和通道数相同,里面的数据mean和std都是给定的,直接搜

    ])
    dataset = MNIST(root="data",train=train,transform=transform_fn)
    data_loader = DataLoader(dataset,batch_size=BATCH_SIZE,shuffle=True)
    return data_loader

# 2. 构建模型
class MnistModel(nn.Module):
    def __init__(self):
        super(MnistModel, self).__init__()
        self.fc1 = nn.Linear(1*28*28,28) #后面这个28表示经过变换后的形状为28而不是1*28*28了
        self.fc2 = nn.Linear(28,10) #最后只有10个数所以为10

    def forward(self,input):

        # 1.修改形状
        x = input.view(-1,28*28*1)   #对数据形状变形,-1表示该位置根据后面的形状自动调整
        # 2.进行全连接操作
        x = self.fc1(x)     #[batch_size,28]
        # 3. 进行激活函数处理,形状不发生变化
        x = F.relu(x)   #[batch_size,28]
        # 4. 输出层
        out = self.fc2(x)     #[batch_size,10]

        return F.log_softmax(out,dim=-1)


# 实例化模型
model = MnistModel()
# 优化器类
optimizer = Adam(model.parameters(),lr=0.001)
# 判断模型是否存在
# 模型的加载,用于断连之后快速重启,以低损失运行
if os.path.exists("./model/model.pkl"):
    model.load_state_dict(torch.load("./model/model.pkl"))
    optimizer.load_state_dict(torch.load("./model/optimizer.pkl"))


# 3. 实现训练的过程
def train(epoch):

    data_loader = get_dataloader()
    for idx,(input,traget) in enumerate(data_loader):
        optimizer.zero_grad()   # 梯度归零
        output = model(input)    # 调用模型得到预测值
        loss = F.nll_loss(output,traget)   # 得到损失
        loss.backward()   # 反向传播
        optimizer.step()   # 梯度更新
        if idx%10==0:
            print(epoch,idx,loss.item())

        # 模型的保存
        if idx%100==0:
            torch.save(model.state_dict(),"./model/model.pkl")
            torch.save(optimizer.state_dict(),"./model/optimizer.pkl")


def test():
    loss_list = []
    acc_list = []
    test_dataloader = get_dataloader(train=False,batch_size=TEST_BATCH_SIZE)
    for idx,(input,target) in enumerate(test_dataloader):
        # 测试不需要梯度
        with torch.no_grad():
            output = model(input)
            cur_loss = F.nll_loss(output,target)
            loss_list.append(cur_loss)
            # 计算准确率
            # -1代表每一行的最大值,0代表每一列的最大值。里面有两组数,一组是值,一组是那个值的坐标.我们这里要坐标
            pred = output.max(dim=-1)[-1]
            # 比较两组数,eq方法返回的是bool值所以要变为float,再求均值
            cur_acc = pred.eq(target).float().mean()
            acc_list.append(cur_acc)
    print("平均准确率,平均损失:",np.mean(acc_list),np.mean(loss_list))



if __name__ == '__main__':
    # for i in range(3): # 训练3轮
    #     train(i)
    #
    test()

test和train是两个方向,想要走test就是现在代码写的,想运行train就是把main函数里面的注释取消

相关推荐
查无此人byebye2 小时前
【保姆级教程】从零实现模块化Transformer对话生成模型(PyTorch完整代码)
pytorch·深度学习·transformer
红茶川4 小时前
[ExecuTorch 系列] 2. 导出官方支持的大语言模型
人工智能·pytorch·ai·端侧ai
shy^-^cky5 小时前
TensorFlow、PyTorch、PaddlePaddle 三大深度学习框架全维度对比表
pytorch·深度学习·tensorflow·paddlepaddle·飞桨
兜兜风d'6 小时前
PyTorch深度学习实践——卷积神经网络高级篇
人工智能·pytorch·深度学习
zhangfeng11338 小时前
unsloth 安装的时候会 自动升级torch版本,解决办法
人工智能·pytorch
Narrastory13 小时前
明日香 - Pytorch 快速入门保姆级教程(五)
人工智能·pytorch·深度学习
如若1231 天前
flash-attn 安装失败?从报错到成功的完整排雷指南(CUDA 12.8 + PyTorch 2.7)
人工智能·pytorch·python
love530love1 天前
Windows 11 源码编译 vLLM 0.16 完全指南(CUDA 12.6 / PyTorch 2.7.1+cu126)
人工智能·pytorch·windows·python·深度学习·comfyui·vllm
兜兜风d'1 天前
PyTorch 深度学习实践——加载数据集
人工智能·pytorch·深度学习
一碗姜汤1 天前
torch.autograd.Function的apply()方法作用
人工智能·pytorch·深度学习