【李沐第三章】3.6、softmax回归的简单实现

1、初始化模型参数

python 复制代码
# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

# 对模型的权重进行初始化操作
net.apply(init_weights);

2、重新定义softmax损失函数

3、优化方法

python 复制代码
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

4、训练

python 复制代码
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
相关推荐
Zoey的笔记本27 分钟前
敏捷与稳定并行:Scrum看板+BPM工具选型指南
大数据·前端·数据库·python·低代码
开开心心就好2 小时前
图片格式转换工具,右键菜单一键转换简化
linux·运维·服务器·python·django·pdf·1024程序员节
骥龙2 小时前
1.2下、工欲善其事:物联网安全研究环境搭建指南
python·物联网·安全
Lxinccode2 小时前
BUG(20) : response.text耗时很久, linux耗时十几秒, Windows耗时零点几秒
python·bug·requests·response.text·response.text慢
智航GIS2 小时前
11.2 Matplotlib 数据可视化教程
python·信息可视化·matplotlib
技术净胜2 小时前
Python 操作 Cookie 完全指南,爬虫与 Web 开发实战
前端·爬虫·python
海棠AI实验室2 小时前
第六章 日志体系:logging 让排错效率翻倍
python·logging
laufing2 小时前
flask_restx 创建restful api
python·flask·restful
毕设源码-郭学长3 小时前
【开题答辩全过程】以 基于python电商商城系统为例,包含答辩的问题和答案
开发语言·python