机器学习之多层感知机 MLP简洁实现 《动手深度学习》实例


🎈 作者: Linux猿

**🎈 简介:**CSDN博客专家🏆,华为云享专家🏆,Linux、C/C++、云计算、物联网、面试、刷题、算法尽管咨询我,关注我,有问题私聊!

🎈欢迎小伙伴们点赞👍、收藏⭐、留言💬


本篇文章主要介绍《动手深度学习》实例中的多层感知机 MLP 的简洁实现。

一、代码实现

多层感知机(MLP)的简洁实现如下所示。

python 复制代码
import torch
from torch import nn
from d2l import torch as d2l

'''
1. 设置网络模型
'''
# nn.Sequential() 用于网络模型的各层
# nn.Flatten() 它用来将输入张量展平为 [batch_size, features] 的形式
# nn.Linear() 用于设置网络中的全连接层,需要注意的是全连接层的输入与输出都是二维张量
# https://www.cnblogs.com/douzujun/p/13366939.html
net = nn.Sequential(nn.Flatten(),   # 输入的是 [batch_size, n, m], 经过 nn.Flatten 后变为 [batch_size, nxm], 其中, n和m分别为图像的行和列
                    # [batch_size, in_features], [batch_size, out_features]
                    nn.Linear(784, 256),
                    # 激活函数 ReLU(x) = max(0,x)
                    nn.ReLU(),
                    # [batch_size, in_features], [batch_size, out_features]
                    nn.Linear(256, 10))

'''
2. 设置和应用权重
'''
# 用于初始化权重
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

# 用于初始化权重
net.apply(init_weights)


'''
3. 设置超参数、损失函数、优化函数
'''
batch_size, lr, num_epochs = 256, 0.1, 10
# 交叉熵损失函数,一般用于多分类
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=lr)

'''
4. 训练模型
'''
# 获取训练集和测试集
# load_data_fashion_mnist 函数加载 Fashion-MNIST 数据集(服饰数据集)
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# 训练模型
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)
d2l.plt.show()

二、代码解析

在上述代码中,需要注意的是 nn.Flatten() 函数,默认情况下将输入的多维矩阵除第一维度外的部分扩展为一维,例如:假设有 [n,m,k] 的多维矩阵,调用 nn.Flatten() 后变成 [n, m*k] 的矩阵。上述代码中,大多数都是调用的封装的函数,可以通过调节超参数以及替换对应的函数(例如:激活函数替换)来提升训练精度。


🎈 感觉有帮助记得**「一键三连** 支持下哦!有问题可在评论区留言💬,感谢大家的一路支持!🤞猿哥将持续输出「优质文章 **」**回馈大家!🤞🌹🌹🌹🌹🌹🌹🤞


相关推荐
孙同学要努力32 分钟前
全连接神经网络案例——手写数字识别
人工智能·深度学习·神经网络
Eric.Lee202132 分钟前
yolo v5 开源项目
人工智能·yolo·目标检测·计算机视觉
其实吧32 小时前
基于Matlab的图像融合研究设计
人工智能·计算机视觉·matlab
丕羽2 小时前
【Pytorch】基本语法
人工智能·pytorch·python
ctrey_2 小时前
2024-11-1 学习人工智能的Day20 openCV(2)
人工智能·opencv·学习
SongYuLong的博客2 小时前
Air780E基于LuatOS编程开发
人工智能
Jina AI2 小时前
RAG 系统的分块难题:小型语言模型如何找到最佳断点?
人工智能·语言模型·自然语言处理
-派神-2 小时前
大语言模型(LLM)量化基础知识(一)
人工智能·语言模型·自然语言处理
johnny_hhh2 小时前
AI大模型重塑软件开发流程:定义、应用场景、优势、挑战及未来展望
人工智能
Elastic 中国社区官方博客2 小时前
释放专利力量:Patently 如何利用向量搜索和 NLP 简化协作
大数据·数据库·人工智能·elasticsearch·搜索引擎·自然语言处理