208_深度学习的鲁棒性之美:暂退法(Dropout)原理与实战

在训练大型神经网络时,神经元之间容易形成"共适应性"(Co-adaptation),即某些神经元过度依赖其他神经元,导致模型过度拟合训练数据的特定细节。暂退法通过在训练期间随机"关掉"一部分神经元,强迫每个神经元都能独立地学习有用的特征。

1. 什么是暂退法?

暂退法(Dropout)的核心操作是在层之间注入噪声:

  • 训练期间 :以概率 将隐藏单元置为 0。
  • 数学期望不变 :为了保证输出的期望值在训练和测试时一致,我们会对未被丢弃的单元进行缩放,除以
  • 测试期间:关闭 Dropout,使用完整的网络进行预测,以确保结果的确定性。

2. 为什么 Dropout 有效?

  • 集成学习视角:Dropout 每次训练都在训练一个不同的子网络。最终的模型可以看作是无数个共享参数的小模型的"平均组合"。
  • 防止过拟合:因为神经元不能依赖特定的邻居(因为邻居随时可能被丢弃),它们必须学到更加"稳健"的特征,从而提升了泛化能力。

3. 代码实战:在多层感知机中应用 Dropout

文件展示了如何在 PyTorch 中通过 nn.Dropout 层轻松实现这一功能。

Python

复制代码
import torch
from torch import nn
from d2l import torch as d2l

# 定义丢弃概率
dropout1, dropout2 = 0.2, 0.5

# 搭建网络结构
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 256),
    nn.ReLU(),
    # 在第一个全连接层后添加 Dropout
    nn.Dropout(dropout1),
    nn.Linear(256, 256),
    nn.ReLU(),
    # 在第二个全连接层后添加 Dropout
    nn.Dropout(dropout2),
    nn.Linear(256, 10)
)

# 初始化权重
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)

# 设置超参数与训练
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss()
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)

# 开始训练
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

4. 关键细节与最佳实践

① Dropout 放哪里?

通常将 Dropout 放在激活函数之后 ,下一层全连接层之前。在上面的代码中,它被放置在 nn.ReLU() 之后。

② 丢弃概率

如何设置?

  • 靠近输入层的层,通常设置较小的丢弃率(如 0.2),因为输入层包含原始信息。
  • 靠近输出层的隐藏层,可以设置较大的丢弃率(如 0.5)。

train()eval() 模式切换

这是使用 Dropout 时最容易犯错的地方:

  • 在训练时必须调用 net.train(),此时 Dropout 生效。
  • 在测试/验证时必须调用 net.eval(),此时 Dropout 自动失效,所有神经元共同参与计算。

5. 总结:正则化的"双剑合璧"

在实际工程中,我们往往同时使用 权重衰退(L2 正则化)暂退法(Dropout)

  1. 权重衰退:让权重保持细小且平滑。
  2. 暂退法:让网络结构保持稀疏且健壮。
相关推荐
基德爆肝c语言17 小时前
Qt—信号和槽
开发语言·qt
geovindu17 小时前
go:Decorator Pattern
开发语言·设计模式·golang·装饰器模式
故事和你9118 小时前
洛谷-算法2-1-前缀和、差分与离散化1
开发语言·数据结构·c++·算法·深度优先·动态规划·图论
励志的小陈1 天前
贪吃蛇(C语言实现,API)
c语言·开发语言
思绪无限1 天前
YOLOv5至YOLOv12升级:木材表面缺陷检测系统的设计与实现(完整代码+界面+数据集项目)
人工智能·深度学习·目标检测·计算机视觉·木材表面缺陷检测
kishu_iOS&AI1 天前
深度学习 —— 损失函数
人工智能·pytorch·python·深度学习·线性回归
Makoto_Kimur1 天前
java开发面试-AI Coding速成
java·开发语言
好运的阿财1 天前
OpenClaw工具拆解之canvas+message
人工智能·python·ai编程·openclaw·openclaw工具
laowangpython1 天前
Gurobi求解器Matlab安装配置教程
开发语言·其他·matlab
wengqidaifeng1 天前
python启航:1.基础语法知识
开发语言·python