【深度学习5】多层感知机

1、概念

  • 将许多全连接层堆叠在一起。 每一层都输出到上面的层,直到生成最后的输出。 我们可以把前L-1层看作表示,把最后一层看作线性预测器。 这种架构通常称为多层感知机。
  • 活性值:使用仿射函数进行放射变换得到的结果,再对这个结果应用激活函数,此时得到的结果是活性值。
  • 激活函数对仿射变换的结果中的元素逐个进行处理,每个元素的输出只依赖自身,和其他元素无关。激活函数(ReLU、Sigmoid、Tanh、GELU 等)是 "按元素独立运算" 的,所以计算每个隐藏单元的活性值时,不需要依赖其他隐藏单元的中间结果。
  • 隐藏层的计算分为两步:仿射变换(z=Wx+b)→ 激活变换(a=σ(z))。
  • 隐藏层神经元会传递活性值,比如一个隐藏层神经元传递一个活性值,代表某一个特征,当隐藏层神经元数量足够多时,代表的特征也足够多,可以覆盖任意复杂函数的所有细节,就可以逼近任意函数。
  • 近似定理核心逻辑:每个神经元对应一个 "局部特征 / 片段",足够多的神经元就能覆盖任意连续函数的所有局部细节,再通过线性组合拼接,就能以任意精度逼近目标函数------ 这就是通用近似定理的核心思想。
  • 隐藏层神经元传递的信息:表层是 "活性值 a"(0~1 或其他范围的数值),深层是 "输入是否包含该神经元关注的特征" 的筛选结果(a 越大,特征越明显)。
  • 激活函数的角色:不是传递的信息,而是神经元 "产生传递信息" 的工具 ------ 对仿射变换的结果做非线性筛选,让神经元能输出 "有意义的特征信号"。
  • 修正线性单元(Rectified linear unit,ReLU):ReLU函数通过将相应的活性值设为0,仅保留正元素并丢弃所有负元素。
  • 挤压函数(squashing function,sigmoid):sigmoid函数将输入变换为区间(0, 1)上的输出。
  • tanh(双曲正切)函数:将其输入压缩转换到区间(-1, 1)上。

2、多层感知机的从零开始实现

python 复制代码
import matplotlib
matplotlib.use('TkAgg')  # 保证 PyCharm 可以显示图像
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from d2l import torch as d2l

# ================== 数据加载 ==================
batch_size = 256
transform = transforms.ToTensor()
train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)
train_iter = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_iter = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# ================== 模型参数 ==================
num_inputs, num_outputs, num_hiddens = 784, 10, 256

W1 = nn.Parameter(torch.randn(num_inputs, num_hiddens) * 0.01, requires_grad=True)
b1 = nn.Parameter(torch.zeros(num_hiddens, requires_grad=True))
W2 = nn.Parameter(torch.randn(num_hiddens, num_outputs) * 0.01, requires_grad=True)
b2 = nn.Parameter(torch.zeros(num_outputs, requires_grad=True))
params = [W1, b1, W2, b2]

# ================== 激活函数和模型 ==================
def relu(X):
    return torch.max(X, torch.zeros_like(X))

def net(X):
    X = X.reshape((-1, num_inputs))
    H = relu(X @ W1 + b1)
    return H @ W2 + b2

# ================== 损失函数与优化器 ==================
loss = nn.CrossEntropyLoss(reduction='none')
lr, num_epochs = 0.1, 10
updater = torch.optim.SGD(params, lr=lr)

# ================== 训练 ==================
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, updater)

# ================== 预测可视化 ==================
X, y = next(iter(test_iter))
true_labels = [d2l.get_fashion_mnist_labels([int(i)]) for i in y[:8]]
pred_labels = [d2l.get_fashion_mnist_labels([int(i)]) for i in net(X[:8]).argmax(dim=1)]

titles = [t[0] + '\n' + p[0] for t, p in zip(true_labels, pred_labels)]
d2l.show_images(X[:8].reshape(8, 28, 28), 1, 8, titles=titles)
plt.show()  # 显示图像(PyCharm 需要显式调用)

3、多层感知机的简洁实现

python 复制代码
import matplotlib
matplotlib.use('TkAgg')  # ✅ 让 PyCharm 可以显示图像窗口
import matplotlib.pyplot as plt
import torch
from torch import nn
from d2l import torch as d2l

# ================== 模型 ==================
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 10)
)

def init_weights(m):
    if isinstance(m, nn.Linear):
        nn.init.normal_(m.weight, std=0.01)
net.apply(init_weights)

# ================== 超参数 ==================
batch_size, lr, num_epochs = 256, 0.1, 10
loss = nn.CrossEntropyLoss(reduction='none')
trainer = torch.optim.SGD(net.parameters(), lr=lr)

# ================== 数据加载 ==================
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

# ================== 训练 ==================
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

# ================== 预测(带图像显示) ==================
d2l.predict_ch3(net, test_iter)
plt.show()  # ✅ 关键:让图像在 PyCharm 弹出
相关推荐
中國龍在廣州10 分钟前
现在人工智能的研究路径可能走反了
人工智能·算法·搜索引擎·chatgpt·机器人
攻城狮7号20 分钟前
小米具身大模型 MiMo-Embodied 发布并全面开源:统一机器人与自动驾驶
人工智能·机器人·自动驾驶·开源大模型·mimo-embodied·小米具身大模型
搜移IT科技24 分钟前
【无标题】2025ARCE亚洲机器人大会暨展览会将带来哪些新技术与新体验?
人工智能
信也科技布道师FTE44 分钟前
当AMIS遇见AI智能体:如何为低代码开发装上“智慧大脑”?
人工智能·低代码·llm
青瓷程序设计1 小时前
植物识别系统【最新版】Python+TensorFlow+Vue3+Django+人工智能+深度学习+卷积神经网络算法
人工智能·python·深度学习
AI即插即用1 小时前
即插即用系列 | CVPR 2025 WPFormer:用于表面缺陷检测的查询式Transformer
人工智能·深度学习·yolo·目标检测·cnn·视觉检测·transformer
唐兴通个人1 小时前
数字化AI大客户营销TOB营销客户开发专业销售技巧培训讲师培训师唐兴通老师分享AI销冠人工智能销售AI赋能销售医药金融工业品制造业
人工智能·金融
人机与认知实验室2 小时前
国内主流大语言模型之比较
人工智能·语言模型·自然语言处理
T0uken2 小时前
【Python】UV:境内的深度学习环境搭建
人工智能·深度学习·uv
七宝大爷2 小时前
基于人类反馈的强化学习(RLHF):ChatGPT“对齐”人类的秘密武器
人工智能·chatgpt