【李沐第三章】3.6、softmax回归的简单实现

1、初始化模型参数

python 复制代码
# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

# 对模型的权重进行初始化操作
net.apply(init_weights);

2、重新定义softmax损失函数

3、优化方法

python 复制代码
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

4、训练

python 复制代码
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
相关推荐
rabbit_pro1 分钟前
Java 文件上传到服务器本地存储
java·服务器·python
serve the people20 分钟前
PQ+IVF组合解决海量向量内存占用高和检索慢的问题
人工智能·python
on_pluto_22 分钟前
【debug】解决 5070ti 与 pytorch 版本不兼容的问题
人工智能·pytorch·python
嫂子的姐夫22 分钟前
02-多线程
爬虫·python·多线程·并发爬虫·基础爬虫
【建模先锋】1 小时前
基于Python的智能故障诊断系统 | SmartDiag AI (基础版)V1.0 正式发布!
开发语言·人工智能·python·故障诊断·智能分析平台·大数据分析平台·智能故障诊断系统
AIsdhuang1 小时前
2025 年企业 AI 培训精选指南:聚焦企业培训场景
人工智能·python
今天没有盐1 小时前
Python 数据分析实战:多场景数据处理与可视化全解析
python·pycharm·编程语言
程序员三藏1 小时前
如何用Postman做接口自动化测试?
自动化测试·软件测试·python·测试工具·测试用例·接口测试·postman
n***27192 小时前
JAVA (Springboot) i18n国际化语言配置
java·spring boot·python
心无旁骛~2 小时前
python多进程multiprocessing——spawn启动方式解析
开发语言·python