练习3-softmax分类(李沐函数简要解析)

环境为:练习1的环境

网址为:https://www.bilibili.com/video/BV1K64y1Q7wu/?spm_id_from=333.1007.top_right_bar_window_history.content.click

代码简要解析

导入模块

导入PyTorch

导入Torch中的nn模块

导入d2l中torch模块 并命名为d2l

import torch
from torch import nn
from d2l import torch as d2l

获取数据

从Fashion-MNIST中获取batch_size个数据 注意此处为28*28的像素图像 d2l.load_data_fashion_mnist(batch_size) 函数加载 Fashion-MNIST 数据集,并返回两个迭代器

batch_size=100
train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)

初始化模型和参数

Flatten()将输入为28*28的像素图像摊开成一组784长的数组 作为特征值 输入

nn.Linear() 为784输入 10输出的层

net.apply(init); 是将其中init函数作为所有可变参数的初始化方式 注意:m是层 既对每层m进行判断 符合条件对m的权重进行初始化

type(m) == nn.Linear 用于检查变量 m 是否属于 PyTorch 中的线性层(nn.Linear

net=nn.Sequential(nn.Flatten(),nn.Linear(784,10))
def init_weights(m):
    if type(m)==nn.Linear:
            nn.init.normal_(m.weight,std=0.01)
        
net.apply(init_weights)

初始化损失函数 这里为交叉熵损失函数

loss=nn.CrossEntropyLoss(reduction='none')

设定梯度下降算法

torch.optim.SGD()

trainer=torch.optim.SGD(net.parameters(),lr=0.1)

训练

这里的d2l是李沐老师自己写的,想要运行成功,理论上需要把d2l下载下来

网址:https://github.com/d2l-ai/d2l-zh

num_epochs=10;
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

我所学习到的

获得Fashion-MNIST的数据

train_iter,test_iter=d2l.load_data_fashion_mnist(batch_size)

对输入进行平铺处理 其本质是把每个像素点都当作特征值

nn.Flatten()

多层的权重初始化

net.apply(init_weights)

交叉熵损失函数

loss=nn.CrossEntropy()
相关推荐
幼儿园老大*4 小时前
【系统架构】如何设计一个秒杀系统?
java·经验分享·后端·微服务·系统架构
Pandaconda6 小时前
【Golang 面试题】每日 3 题(三十九)
开发语言·经验分享·笔记·后端·面试·golang·go
Jason Yan15 小时前
【经验分享】ARM Linux-RT内核实时系统性能评估工具
linux·arm开发·经验分享
结衣结衣.19 小时前
「2024·我的成长之路」:年终反思与展望
经验分享·年终总结·个人成长·博客之星
paradoxjun1 天前
落地级分类模型训练框架搭建(1):resnet18/50和mobilenetv2在CIFAR10上测试结果
人工智能·深度学习·算法·计算机视觉·分类
Pandaconda1 天前
【新人系列】Python 入门(二十八):常用标准库 - 上
开发语言·经验分享·笔记·后端·python·面试·标准库
Makerbase_mks1 天前
MKS SERVO42E&57E 闭环步进电机_系列9 arduino 例程
经验分享·电机·电机控制·闭环步进·闭环步进电机
计软考研大C哥1 天前
【25考研】考清华的软件工程专业的研究生需要准备什么?
经验分享·考研·软件工程
十二测试录1 天前
【大厂面试题】软件测试面试题整理(附答案)
经验分享·面试·职场和发展
s_little_monster1 天前
【Linux】权限
linux·运维·数据库·经验分享·学习·学习方法