机器学习之多层感知机 MLP简洁实现 《动手深度学习》实例


🎈 作者: Linux猿

**🎈 简介:**CSDN博客专家🏆,华为云享专家🏆,Linux、C/C++、云计算、物联网、面试、刷题、算法尽管咨询我,关注我,有问题私聊!

**🎈**欢迎小伙伴们点赞👍、收藏⭐、留言💬


本篇文章主要介绍《动手深度学习》实例中的多层感知机 MLP 的简洁实现。

一、代码实现

多层感知机(MLP)的简洁实现如下所示。

python 复制代码
import torch
from torch import nn
from d2l import torch as d2l

'''
1. 设置网络模型
'''
# nn.Sequential() 用于网络模型的各层
# nn.Flatten() 它用来将输入张量展平为 [batch_size, features] 的形式
# nn.Linear() 用于设置网络中的全连接层,需要注意的是全连接层的输入与输出都是二维张量
# https://www.cnblogs.com/douzujun/p/13366939.html
net = nn.Sequential(nn.Flatten(),   # 输入的是 [batch_size, n, m], 经过 nn.Flatten 后变为 [batch_size, nxm], 其中, n和m分别为图像的行和列
                    # [batch_size, in_features], [batch_size, out_features]
                    nn.Linear(784, 256),
                    # 激活函数 ReLU(x) = max(0,x)
                    nn.ReLU(),
                    # [batch_size, in_features], [batch_size, out_features]
                    nn.Linear(256, 10))

'''
2. 设置和应用权重
'''
# 用于初始化权重
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

# 用于初始化权重
net.apply(init_weights)


'''
3. 设置超参数、损失函数、优化函数
'''
batch_size, lr, num_epochs = 256, 0.1, 10
# 交叉熵损失函数,一般用于多分类
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=lr)

'''
4. 训练模型
'''
# 获取训练集和测试集
# load_data_fashion_mnist 函数加载 Fashion-MNIST 数据集(服饰数据集)
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# 训练模型
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()

二、代码解析

在上述代码中,需要注意的是 nn.Flatten() 函数,默认情况下将输入的多维矩阵除第一维度外的部分扩展为一维,例如:假设有 n,m,k 的多维矩阵,调用 nn.Flatten() 后变成 n, m\*k 的矩阵。上述代码中,大多数都是调用的封装的函数,可以通过调节超参数以及替换对应的函数(例如:激活函数替换)来提升训练精度。


🎈 感觉有帮助记得**「一键三连** 支持下哦!有问题可在评论区留言💬,感谢大家的一路支持!🤞猿哥将持续输出「优质文章 **」**回馈大家!🤞🌹🌹🌹🌹🌹🌹🤞


相关推荐
kepppt几秒前
LikeShop 开源商城系统新增 AI 协同开发能力
人工智能·开源商城
namexingyun几秒前
GPT-5.6 前端生成能力深度解析:kindle/kepler/Levi三版本UI实测与技术推演
java·前端·人工智能·gpt·机器学习·ui
2301_818527782 分钟前
瑜伽服品牌出海——AI助力中国瑜伽服走向世界
人工智能
掘金酱3 分钟前
📱 TRAE SOLO 移动端上线征文——“我的第一次移动端AI办公” 评测 | 获奖名单公示
前端·人工智能·trae
sou_time3 分钟前
从 0 到 商用:AI Agent x SKILL x MCP 全栈实战教程:L3 商用篇:性能 / 成本 / 可观测性 / 安全 / 部署
人工智能·安全
m0_718677493 分钟前
关于用AI做游戏的分析
人工智能·游戏
逐米时代5 分钟前
为什么制造型企业需要企业知识库建设?
大数据·人工智能
百度Geek说6 分钟前
如何利用 Harness “一句话交付产品功能”?
人工智能
凯丨7 分钟前
Claude Fable 5 与 Mythos 5:Anthropic 新一代模型系列的架构猜想与定位分析
人工智能·gpt
jonyleek7 分钟前
AI与现有系统“两张皮”:如何无缝集成、快速落地?
人工智能·ai·agent·jvs·ai套件·jvs-ai套件