【深度学习5】多层感知机

1、概念

  • 将许多全连接层堆叠在一起。 每一层都输出到上面的层,直到生成最后的输出。 我们可以把前L-1层看作表示,把最后一层看作线性预测器。 这种架构通常称为多层感知机。
  • 活性值:使用仿射函数进行放射变换得到的结果,再对这个结果应用激活函数,此时得到的结果是活性值。
  • 激活函数对仿射变换的结果中的元素逐个进行处理,每个元素的输出只依赖自身,和其他元素无关。激活函数(ReLU、Sigmoid、Tanh、GELU 等)是 "按元素独立运算" 的,所以计算每个隐藏单元的活性值时,不需要依赖其他隐藏单元的中间结果。
  • 隐藏层的计算分为两步:仿射变换(z=Wx+b)→ 激活变换(a=σ(z))。
  • 隐藏层神经元会传递活性值,比如一个隐藏层神经元传递一个活性值,代表某一个特征,当隐藏层神经元数量足够多时,代表的特征也足够多,可以覆盖任意复杂函数的所有细节,就可以逼近任意函数。
  • 近似定理核心逻辑:每个神经元对应一个 "局部特征 / 片段",足够多的神经元就能覆盖任意连续函数的所有局部细节,再通过线性组合拼接,就能以任意精度逼近目标函数------ 这就是通用近似定理的核心思想。
  • 隐藏层神经元传递的信息:表层是 "活性值 a"(0~1 或其他范围的数值),深层是 "输入是否包含该神经元关注的特征" 的筛选结果(a 越大,特征越明显)。
  • 激活函数的角色:不是传递的信息,而是神经元 "产生传递信息" 的工具 ------ 对仿射变换的结果做非线性筛选,让神经元能输出 "有意义的特征信号"。
  • 修正线性单元(Rectified linear unit,ReLU):ReLU函数通过将相应的活性值设为0,仅保留正元素并丢弃所有负元素。
  • 挤压函数(squashing function,sigmoid):sigmoid函数将输入变换为区间(0, 1)上的输出。
  • tanh(双曲正切)函数:将其输入压缩转换到区间(-1, 1)上。

2、多层感知机的从零开始实现

python 复制代码
import matplotlib
matplotlib.use('TkAgg')  # 保证 PyCharm 可以显示图像
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from d2l import torch as d2l

# ================== 数据加载 ==================
batch_size = 256
transform = transforms.ToTensor()
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
train_iter = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_iter = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# ================== 模型参数 ==================
num_inputs, num_outputs, num_hiddens = 784, 10, 256

W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens) * 0.01, requires_grad=True)
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(num_hiddens, num_outputs) * 0.01, requires_grad=True)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))
params = [W1, b1, W2, b2]

# ================== 激活函数和模型 ==================
def relu(X):
    return torch.max(X, torch.zeros_like(X))

def net(X):
    X = X.reshape((-1, num_inputs))
    H = relu(X @ W1 + b1)
    return H @ W2 + b2

# ================== 损失函数与优化器 ==================
loss = nn.CrossEntropyLoss(reduction='none')
lr, num_epochs = 0.1, 10
updater = torch.optim.SGD(params, lr=lr)

# ================== 训练 ==================
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

# ================== 预测可视化 ==================
X, y = next(iter(test_iter))
true_labels = [d2l.get_fashion_mnist_labels([int(i)]) for i in y[:8]]
pred_labels = [d2l.get_fashion_mnist_labels([int(i)]) for i in net(X[:8]).argmax(dim=1)]

titles = [t[0] + '\n' + p[0] for t, p in zip(true_labels, pred_labels)]
d2l.show_images(X[:8].reshape(8, 28, 28), 1, 8, titles=titles)
plt.show()  # 显示图像(PyCharm 需要显式调用)

3、多层感知机的简洁实现

python 复制代码
import matplotlib
matplotlib.use('TkAgg')  # ✅ 让 PyCharm 可以显示图像窗口
import matplotlib.pyplot as plt
import torch
from torch import nn
from d2l import torch as d2l

# ================== 模型 ==================
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)
net.apply(init_weights)

# ================== 超参数 ==================
batch_size, lr, num_epochs = 256, 0.1, 10
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=lr)

# ================== 数据加载 ==================
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# ================== 训练 ==================
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

# ================== 预测(带图像显示) ==================
d2l.predict_ch3(net, test_iter)
plt.show()  # ✅ 关键:让图像在 PyCharm 弹出
相关推荐
蕤葳-几秒前
非编程背景学习AI的方法
人工智能
北京耐用通信3 分钟前
不换设备、不重写程序:耐达讯自动化网关如何实现CC-Link IE转Modbus TCP的高效互通?
人工智能·科技·物联网·网络协议·自动化·信息与通信
计算机毕业设计指导4 分钟前
基于机器学习和深度学习的恶意WebURL检测系统实战详解
人工智能·深度学习·机器学习·网络安全
珂朵莉MM5 分钟前
第七届全球校园人工智能算法精英大赛-算法巅峰赛产业命题赛第3赛季优化题--多策略混合算法
人工智能·算法
GlobalInfo7 分钟前
2026-2032全球AI服务器连接器市场洞察:规模、竞争与趋势深度解析
人工智能
Elastic 中国社区官方博客11 分钟前
使用 Jina-VLM 小型多语言视觉语言模型来和图片对话
大数据·人工智能·elasticsearch·语言模型·自然语言处理·jina
罗西的思考12 分钟前
【OpenClaw】通过 Nanobot 源码学习架构---(6)Skills
人工智能·深度学习·算法
uzong12 分钟前
软件人员可以关注的 Skill,亲测确实不错,值得试一下
人工智能·后端
志栋智能13 分钟前
超自动化巡检:实现运维“事前预防”的关键拼图
大数据·运维·网络·人工智能·机器学习·自动化
枫叶林FYL15 分钟前
【自然语言处理 NLP】7.2 红队测试与对抗鲁棒性(Red Teaming & Adversarial Robustness)
人工智能·算法·机器学习