【李沐第三章】3.6、softmax回归的简单实现

1、初始化模型参数

python 复制代码
# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

# 对模型的权重进行初始化操作
net.apply(init_weights);

2、重新定义softmax损失函数

3、优化方法

python 复制代码
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

4、训练

python 复制代码
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
相关推荐
灏瀚星空1 小时前
高频交易技术:订单簿分析与低延迟架构——从Level 2数据挖掘到FPGA硬件加速的全链路解决方案
人工智能·python·算法·信息可视化·fpga开发·架构·数据挖掘
Hello.Reader1 小时前
在多云环境透析连接ngx_stream_proxy_protocol_vendor_module
后端·python·flask
zh_199951 小时前
Spark面试精讲(上)
java·大数据·数据仓库·python·spark·数据库开发·数据库架构
没有钱的钱仔2 小时前
python文件传输 带进度条
服务器·网络·python
Python当打之年2 小时前
【62 Pandas+Pyecharts | 智联招聘大数据岗位数据分析可视化】
大数据·python·数据分析·pandas·数据可视化
好易学·数据结构2 小时前
可视化图解算法51:寻找第K大(数组中的第K个最大的元素)
数据结构·python·算法·leetcode·力扣·牛客网·堆栈
纬领网络3 小时前
Linux环境下安装和使用RAPIDS平台的cudf和cuml - pip 安装方法
开发语言·python·pip
成都犀牛3 小时前
LlamaIndex 学习笔记
人工智能·python·深度学习·神经网络·学习
猛犸MAMMOTH3 小时前
Python打卡第53天
开发语言·python·深度学习
thinking-fish4 小时前
提示词Prompts(2)
python·langchain·提示词·提示词模板