深度学习-2.6在MINST-FASHION上实现神经网络的学习流程

文章目录

在MINST-FASHION上实现神经网络的学习流程

现在我们要整合本节课中所有的代码实现一个完整的训练流程。

首先要梳理一下整个流程:

  • 1)设置步长lr,动量值 g a m m a gamma gamma ,迭代次数 e p o c h s epochs epochs , b a t c h _ s i z e batch\_size batch_size等信息,(如果需要)设置初始权重 w 0 w_0 w0

  • 2)导入数据,将数据切分成 b a t c h _ s i z e batch\_size batch_size

  • 3)定义神经网络架构

  • 4)定义损失函数 L ( w ) L(w) L(w),如果需要的话,将损失函数调整成凸函数,以便求解最小值

  • 5)定义所使用的优化算法

  • 6)开始在 e p o c h e s epoches epoches 和 b a t c h batch batch上循环,执行优化算法:

    • 6.1)调整数据结构,确定数据能够在神经网络、损失函数和优化算法中顺利运行;
    • 6.2)完成向前传播,计算初始损失
    • 6.3)利用反向传播,在损失函数 L ( w ) L(w) L(w)上对每一个 w w w求偏导数
    • 6.4)迭代当前权重
    • 6.5)清空本轮梯度
    • 6.6)完成模型进度与效果监控
  • 7)输出结果

1. 导库

这次我们要使用PyTorch中自带的数据,MINST-FATION。

python 复制代码
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
#确定数据、确定优先需要设置的值
lr = 0.15
gamma = 0
epochs = 10
bs = 128

2. 导入数据,分割小批量

python 复制代码
import torchvision
import torchvision.transforms as transforms#初次运行时会下载,需要等待较长时间
mnist = torchvision.datasets.FashionMNIST(root='C:\Pythonwork\DEEP LEARNING\Datasets\FashionMNIST',
										train=True, 
										download=True, 
										transform=transforms.ToTensor())

len(mnist)#查看特征张量

mnist.data#这个张量结构看起来非常常规,可惜的是它与我们要输入到模型的数据结构有差异

#查看标签
mnist.targets

#查看标签的类别
mnist.classes
#查看图像的模样
import matplotlib.pyplot as plt
plt.imshow(mnist[0][0].view((28, 28)).numpy());

plt.imshow(mnist[1][0].view((28, 28)).numpy());

#分割batch
batchdata = DataLoader(mnist,batch_size=bs, shuffle = True)
#总共多少个batch?
len(batchdata)
#查看会放入进行迭代的数据结构
for x,y in batchdata:
    print(x.shape)
    print(y.shape)
    break

input_ = mnist.data[0].numel() #特征的数目,一般是第一维之外的所有维度相乘的数
output_ = len(mnist.targets.unique()) #分类的数目#最好确认一下没有错误

input_output_


#========================
import torchvision
import torchvision.transforms as transforms
mnist = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=False, transform=transforms.ToTensor())
batchdata = DataLoader(mnist,batch_size=bs, shuffle = True)
input_ = mnist.data[0].numel()
output_ = len(mnist.targets.unique()

3. 定义神经网络

python 复制代码
class Model(nn.Module):
    def __init__(self,in_features=10,out_features=2):
        super().__init__()
        #self.normalize = nn.BatchNorm2d(num_features=1)
        self.linear1 = nn.Linear(in_features,128,bias=False)
        self.output = nn.Linear(128,out_features,bias=False)
    def forward(self, x):
        #x = self.normalize(x)
        x = x.view(-1, 28*28)
        #需要对数据的结构进行一个改变,这里的"-1"代表,我不想算,请pytorch帮我计算
        sigma1 = torch.relu(self.linear1(x))
        z2 = self.output(sigma1)
        sigma2 = F.log_softmax(z2,dim=1)
        return sigma2

4.定义训练函数

python 复制代码
def fit(net,batchdata,lr=0.01,epochs=5,gamma=0):
    criterion = nn.NLLLoss() #定义损失函数
    opt = optim.SGD(net.parameters(), lr=lr,momentum=gamma) #定义优化算法
    correct = 0
    samples = 0
    for epoch in range(epochs):
        for batch_idx, (x,y) in enumerate(batchdata):
            y = y.view(x.shape[0])
            sigma = net.forward(x)
            loss = criterion(sigma,y)
            loss.backward()
            opt.step()
            opt.zero_grad()#求解准确率
            yhat = torch.max(sigma,1)[1]
            correct + torch. sum Cyhat == y)
            samples + = x. shape [ o]
            if (batch_ idx+ 1) % 125 o or batch_ idx = len (batchdata)-1:
                print( Epocht: [ / (:of] % ) ] tLoss : 6ft Accuracy::.3f].format(
                    epoch+1 
                    , samples 
                    ,len( batchdata. dataset) * epochs
                    ,100* samples/ ( len (batchdata. dataset)epochs)
                    ,loss.data.item()
                    ,float(correct*100)/samples))

5.进行训练与评估

python 复制代码
#实例化神经网络,调用优化算法需要的参数
torch. manualseed(420)
net = Mode ( in_ features= input_ out_features=output_)
fit( net, batchdata, lr= lr, epochs= epochs, gamma=gamma)

我们现在已经完成了一个最基本的、神经网络训练并查看训练结果的代码。

相关推荐
B1nna4 小时前
Redis学习(三)缓存
redis·学习·缓存
人类群星闪耀时5 小时前
深度学习在灾难恢复中的作用:智能运维的新时代
运维·人工智能·深度学习
_im.m.z5 小时前
【设计模式学习笔记】1. 设计模式概述
笔记·学习·设计模式
机器懒得学习5 小时前
从随机生成到深度学习:使用DCGAN和CycleGAN生成图像的实战教程
人工智能·深度学习
落魄君子6 小时前
BP回归-反向传播(Backpropagation)
人工智能·神经网络·回归
烟波人长安吖~6 小时前
【目标跟踪+人流计数+人流热图(Web界面)】基于YOLOV11+Vue+SpringBoot+Flask+MySQL
vue.js·pytorch·spring boot·深度学习·yolo·目标跟踪
最好Tony6 小时前
深度学习blog-Transformer-注意力机制和编码器解码器
人工智能·深度学习·机器学习·计算机视觉·自然语言处理·chatgpt
左漫在成长7 小时前
王佩丰24节Excel学习笔记——第十九讲:Indirect函数
笔记·学习·excel
四口鲸鱼爱吃盐8 小时前
Pytorch | 利用SMI-FGRM针对CIFAR10上的ResNet分类器进行对抗攻击
人工智能·pytorch·python·深度学习·机器学习·计算机视觉
纪伊路上盛名在8 小时前
Max AI prompt1
笔记·学习·学习方法