深度学习-2.6在MINST-FASHION上实现神经网络的学习流程

文章目录

在MINST-FASHION上实现神经网络的学习流程

现在我们要整合本节课中所有的代码实现一个完整的训练流程。

首先要梳理一下整个流程:

  • 1)设置步长lr,动量值 g a m m a gamma gamma ,迭代次数 e p o c h s epochs epochs , b a t c h _ s i z e batch\_size batch_size等信息,(如果需要)设置初始权重 w 0 w_0 w0

  • 2)导入数据,将数据切分成 b a t c h _ s i z e batch\_size batch_size

  • 3)定义神经网络架构

  • 4)定义损失函数 L ( w ) L(w) L(w),如果需要的话,将损失函数调整成凸函数,以便求解最小值

  • 5)定义所使用的优化算法

  • 6)开始在 e p o c h e s epoches epoches 和 b a t c h batch batch上循环,执行优化算法:

    • 6.1)调整数据结构,确定数据能够在神经网络、损失函数和优化算法中顺利运行;
    • 6.2)完成向前传播,计算初始损失
    • 6.3)利用反向传播,在损失函数 L ( w ) L(w) L(w)上对每一个 w w w求偏导数
    • 6.4)迭代当前权重
    • 6.5)清空本轮梯度
    • 6.6)完成模型进度与效果监控
  • 7)输出结果

1. 导库

这次我们要使用PyTorch中自带的数据,MINST-FATION。

python 复制代码
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
#确定数据、确定优先需要设置的值
lr = 0.15
gamma = 0
epochs = 10
bs = 128

2. 导入数据,分割小批量

python 复制代码
import torchvision
import torchvision.transforms as transforms#初次运行时会下载,需要等待较长时间
mnist = torchvision.datasets.FashionMNIST(root='C:\Pythonwork\DEEP LEARNING\Datasets\FashionMNIST',
										train=True, 
										download=True, 
										transform=transforms.ToTensor())

len(mnist)#查看特征张量

mnist.data#这个张量结构看起来非常常规,可惜的是它与我们要输入到模型的数据结构有差异

#查看标签
mnist.targets

#查看标签的类别
mnist.classes
#查看图像的模样
import matplotlib.pyplot as plt
plt.imshow(mnist[0][0].view((28, 28)).numpy());

plt.imshow(mnist[1][0].view((28, 28)).numpy());

#分割batch
batchdata = DataLoader(mnist,batch_size=bs, shuffle = True)
#总共多少个batch?
len(batchdata)
#查看会放入进行迭代的数据结构
for x,y in batchdata:
    print(x.shape)
    print(y.shape)
    break

input_ = mnist.data[0].numel() #特征的数目,一般是第一维之外的所有维度相乘的数
output_ = len(mnist.targets.unique()) #分类的数目#最好确认一下没有错误

input_output_


#========================
import torchvision
import torchvision.transforms as transforms
mnist = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=False, transform=transforms.ToTensor())
batchdata = DataLoader(mnist,batch_size=bs, shuffle = True)
input_ = mnist.data[0].numel()
output_ = len(mnist.targets.unique()

3. 定义神经网络

python 复制代码
class Model(nn.Module):
    def __init__(self,in_features=10,out_features=2):
        super().__init__()
        #self.normalize = nn.BatchNorm2d(num_features=1)
        self.linear1 = nn.Linear(in_features,128,bias=False)
        self.output = nn.Linear(128,out_features,bias=False)
    def forward(self, x):
        #x = self.normalize(x)
        x = x.view(-1, 28*28)
        #需要对数据的结构进行一个改变,这里的"-1"代表,我不想算,请pytorch帮我计算
        sigma1 = torch.relu(self.linear1(x))
        z2 = self.output(sigma1)
        sigma2 = F.log_softmax(z2,dim=1)
        return sigma2

4.定义训练函数

python 复制代码
def fit(net,batchdata,lr=0.01,epochs=5,gamma=0):
    criterion = nn.NLLLoss() #定义损失函数
    opt = optim.SGD(net.parameters(), lr=lr,momentum=gamma) #定义优化算法
    correct = 0
    samples = 0
    for epoch in range(epochs):
        for batch_idx, (x,y) in enumerate(batchdata):
            y = y.view(x.shape[0])
            sigma = net.forward(x)
            loss = criterion(sigma,y)
            loss.backward()
            opt.step()
            opt.zero_grad()#求解准确率
            yhat = torch.max(sigma,1)[1]
            correct + torch. sum Cyhat == y)
            samples + = x. shape [ o]
            if (batch_ idx+ 1) % 125 o or batch_ idx = len (batchdata)-1:
                print( Epocht: [ / (:of] % ) ] tLoss : 6ft Accuracy::.3f].format(
                    epoch+1 
                    , samples 
                    ,len( batchdata. dataset) * epochs
                    ,100* samples/ ( len (batchdata. dataset)epochs)
                    ,loss.data.item()
                    ,float(correct*100)/samples))

5.进行训练与评估

python 复制代码
#实例化神经网络,调用优化算法需要的参数
torch. manualseed(420)
net = Mode ( in_ features= input_ out_features=output_)
fit( net, batchdata, lr= lr, epochs= epochs, gamma=gamma)

我们现在已经完成了一个最基本的、神经网络训练并查看训练结果的代码。

相关推荐
乖乖是干饭王30 分钟前
Linux系统编程中的_GNU_SOURCE宏
linux·运维·c语言·学习·gnu
Best_Me071 小时前
深度学习模块缝合
人工智能·深度学习
待什么青丝1 小时前
【TMS570LC4357】之相关驱动开发学习记录2
c语言·arm开发·驱动开发·单片机·学习
行云流水剑2 小时前
【学习记录】如何使用 Python 提取 PDF 文件中的内容
python·学习·pdf
虾球xz3 小时前
CppCon 2015 学习:CLANG/C2 for Windows
开发语言·c++·windows·学习
狂小虎3 小时前
亲测解决self.transform is not exist
python·深度学习
Fxrain3 小时前
[深度学习]搭建开发平台及Tensor基础
人工智能·深度学习
蓝婷儿3 小时前
6个月Python学习计划 Day 17 - 继承、多态与魔术方法
开发语言·python·学习
持续前进的奋斗鸭4 小时前
Postman测试学习(1)
学习·postman
hello kitty w4 小时前
Python学习(7) ----- Python起源
linux·python·学习