机器学习之多层感知机 MLP简洁实现 《动手深度学习》实例


🎈 作者: Linux猿

**🎈 简介:**CSDN博客专家🏆,华为云享专家🏆,Linux、C/C++、云计算、物联网、面试、刷题、算法尽管咨询我,关注我,有问题私聊!

🎈欢迎小伙伴们点赞👍、收藏⭐、留言💬


本篇文章主要介绍《动手深度学习》实例中的多层感知机 MLP 的简洁实现。

一、代码实现

多层感知机(MLP)的简洁实现如下所示。

python 复制代码
import torch
from torch import nn
from d2l import torch as d2l

'''
1. 设置网络模型
'''
# nn.Sequential() 用于网络模型的各层
# nn.Flatten() 它用来将输入张量展平为 [batch_size, features] 的形式
# nn.Linear() 用于设置网络中的全连接层,需要注意的是全连接层的输入与输出都是二维张量
# https://www.cnblogs.com/douzujun/p/13366939.html
net = nn.Sequential(nn.Flatten(),   # 输入的是 [batch_size, n, m], 经过 nn.Flatten 后变为 [batch_size, nxm], 其中, n和m分别为图像的行和列
                    # [batch_size, in_features], [batch_size, out_features]
                    nn.Linear(784, 256),
                    # 激活函数 ReLU(x) = max(0,x)
                    nn.ReLU(),
                    # [batch_size, in_features], [batch_size, out_features]
                    nn.Linear(256, 10))

'''
2. 设置和应用权重
'''
# 用于初始化权重
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

# 用于初始化权重
net.apply(init_weights)


'''
3. 设置超参数、损失函数、优化函数
'''
batch_size, lr, num_epochs = 256, 0.1, 10
# 交叉熵损失函数,一般用于多分类
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=lr)

'''
4. 训练模型
'''
# 获取训练集和测试集
# load_data_fashion_mnist 函数加载 Fashion-MNIST 数据集(服饰数据集)
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# 训练模型
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()

二、代码解析

在上述代码中,需要注意的是 nn.Flatten() 函数,默认情况下将输入的多维矩阵除第一维度外的部分扩展为一维,例如:假设有 [n,m,k] 的多维矩阵,调用 nn.Flatten() 后变成 [n, m*k] 的矩阵。上述代码中,大多数都是调用的封装的函数,可以通过调节超参数以及替换对应的函数(例如:激活函数替换)来提升训练精度。


🎈 感觉有帮助记得**「一键三连** 支持下哦!有问题可在评论区留言💬,感谢大家的一路支持!🤞猿哥将持续输出「优质文章 **」**回馈大家!🤞🌹🌹🌹🌹🌹🌹🤞


相关推荐
Dm_dotnet1 小时前
公益站Agent Router注册送200刀额度竟然是真的
人工智能
算家计算2 小时前
7B参数拿下30个世界第一!Hunyuan-MT-7B本地部署教程:腾讯混元开源业界首个翻译集成模型
人工智能·开源
机器之心2 小时前
LLM开源2.0大洗牌:60个出局,39个上桌,AI Coding疯魔,TensorFlow已死
人工智能·openai
Juchecar3 小时前
交叉熵:深度学习中最常用的损失函数
人工智能
林木森ai3 小时前
爆款AI动物运动会视频,用Coze(扣子)一键搞定全流程(附保姆级拆解)
人工智能·aigc
聚客AI4 小时前
🙋‍♀️Transformer训练与推理全流程:从输入处理到输出生成
人工智能·算法·llm
BeerBear5 小时前
【保姆级教程-从0开始开发MCP服务器】一、MCP学习压根没有你想象得那么难!.md
人工智能·mcp
小气小憩5 小时前
“暗战”百度搜索页:Monica悬浮球被“围剿”,一场AI Agent与传统巨头的流量攻防战
前端·人工智能
神经星星5 小时前
准确度提升400%!印度季风预测模型基于36个气象站点,实现城区尺度精细预报
人工智能
IT_陈寒8 小时前
JavaScript 性能优化:5 个被低估的 V8 引擎技巧让你的代码快 200%
前端·人工智能·后端