【李沐第三章】3.6、softmax回归的简单实现

1、初始化模型参数

python 复制代码
# PyTorch不会隐式地调整输入的形状。因此,
# 我们在线性层前定义了展平层(flatten),来调整网络输入的形状
net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))

def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

# 对模型的权重进行初始化操作
net.apply(init_weights);

2、重新定义softmax损失函数

3、优化方法

python 复制代码
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

4、训练

python 复制代码
num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
相关推荐
UrbanJazzerati几秒前
Python列表操作小练习
python
UrbanJazzerati4 分钟前
Python条件语句与循环结构详解
python
AI妈妈手把手1 小时前
【深度学习框架终极PK】TensorFlow/PyTorch/MindSpore深度解析!选对框架效率翻倍
人工智能·pytorch·python·深度学习·tensorflow·mindspore·ai选型指南
苇柠2 小时前
Java数组补充v2
java·python·排序算法
玲娜贝儿--努力学习买大鸡腿版2 小时前
推荐系统---AUC计算
人工智能·python·机器学习
蓝倾9762 小时前
小红书获取关键词列表API接口详解
开发语言·数据库·python
是小崔啊2 小时前
【爬虫】03 - 爬虫的基本数据存储
网络·爬虫·python·beautifulsoup
java1234_小锋3 小时前
【NLP舆情分析】基于python微博舆情分析可视化系统(flask+pandas+echarts) 视频教程 - 基于jieba实现词频统计
python·自然语言处理·flask
星期天要睡觉3 小时前
python网络爬虫(第一章/共三章:网络爬虫库、robots.txt规则(防止犯法)、查看获取网页源代码)
开发语言·爬虫·python
Gyoku Mint4 小时前
深度学习×第10卷:她用一块小滤镜,在图像中找到你
人工智能·python·深度学习·神经网络·opencv·算法·cnn