机器学习之多层感知机 MLP简洁实现 《动手深度学习》实例


🎈 作者: Linux猿

**🎈 简介:**CSDN博客专家🏆,华为云享专家🏆,Linux、C/C++、云计算、物联网、面试、刷题、算法尽管咨询我,关注我,有问题私聊!

🎈欢迎小伙伴们点赞👍、收藏⭐、留言💬


本篇文章主要介绍《动手深度学习》实例中的多层感知机 MLP 的简洁实现。

一、代码实现

多层感知机(MLP)的简洁实现如下所示。

python 复制代码
import torch
from torch import nn
from d2l import torch as d2l

'''
1. 设置网络模型
'''
# nn.Sequential() 用于网络模型的各层
# nn.Flatten() 它用来将输入张量展平为 [batch_size, features] 的形式
# nn.Linear() 用于设置网络中的全连接层,需要注意的是全连接层的输入与输出都是二维张量
# https://www.cnblogs.com/douzujun/p/13366939.html
net = nn.Sequential(nn.Flatten(),   # 输入的是 [batch_size, n, m], 经过 nn.Flatten 后变为 [batch_size, nxm], 其中, n和m分别为图像的行和列
                    # [batch_size, in_features], [batch_size, out_features]
                    nn.Linear(784, 256),
                    # 激活函数 ReLU(x) = max(0,x)
                    nn.ReLU(),
                    # [batch_size, in_features], [batch_size, out_features]
                    nn.Linear(256, 10))

'''
2. 设置和应用权重
'''
# 用于初始化权重
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

# 用于初始化权重
net.apply(init_weights)


'''
3. 设置超参数、损失函数、优化函数
'''
batch_size, lr, num_epochs = 256, 0.1, 10
# 交叉熵损失函数,一般用于多分类
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=lr)

'''
4. 训练模型
'''
# 获取训练集和测试集
# load_data_fashion_mnist 函数加载 Fashion-MNIST 数据集(服饰数据集)
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# 训练模型
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()

二、代码解析

在上述代码中,需要注意的是 nn.Flatten() 函数,默认情况下将输入的多维矩阵除第一维度外的部分扩展为一维,例如:假设有 [n,m,k] 的多维矩阵,调用 nn.Flatten() 后变成 [n, m*k] 的矩阵。上述代码中,大多数都是调用的封装的函数,可以通过调节超参数以及替换对应的函数(例如:激活函数替换)来提升训练精度。


🎈 感觉有帮助记得**「一键三连** 支持下哦!有问题可在评论区留言💬,感谢大家的一路支持!🤞猿哥将持续输出「优质文章 **」**回馈大家!🤞🌹🌹🌹🌹🌹🌹🤞


相关推荐
子午4 分钟前
【2026计算机毕设~AI项目】鸟类识别系统~Python+深度学习+人工智能+图像识别+算法模型
图像处理·人工智能·python·深度学习
发哥来了4 分钟前
《AI视频生成工具选型评测:多维度解析主流产品优劣势》
人工智能
DisonTangor6 分钟前
美团龙猫开源LongCat-Flash-Lite
人工智能·语言模型·自然语言处理·开源·aigc
杨浦老苏7 分钟前
Docker方式安装你的私人AI电脑助手Moltbot
人工智能·docker·ai·群晖
矢志航天的阿洪18 分钟前
IGRF-13 数学细节与公式说明
线性代数·机器学习·矩阵
昨夜见军贴061623 分钟前
功能决定效率:IACheck的AI审核在生产型检测报告中的实践观察
人工智能
传说故事41 分钟前
【论文自动阅读】Goal Force: 教视频模型实现Physics-Conditioned Goals
人工智能·深度学习·视频生成
FPGA小c鸡1 小时前
【FPGA深度学习加速】RNN与LSTM硬件加速完全指南:从算法原理到硬件实现
rnn·深度学习·fpga开发
186******205311 小时前
项目开发基础知识:从概念到落地的全流程指南
大数据·人工智能
说私域1 小时前
AI智能名片商城小程序数据清洗的持续运营策略与实践研究
大数据·人工智能·小程序·流量运营·私域运营