208_深度学习的鲁棒性之美:暂退法(Dropout)原理与实战

在训练大型神经网络时,神经元之间容易形成"共适应性"(Co-adaptation),即某些神经元过度依赖其他神经元,导致模型过度拟合训练数据的特定细节。暂退法通过在训练期间随机"关掉"一部分神经元,强迫每个神经元都能独立地学习有用的特征。

1. 什么是暂退法?

暂退法(Dropout)的核心操作是在层之间注入噪声:

  • 训练期间 :以概率 将隐藏单元置为 0。
  • 数学期望不变 :为了保证输出的期望值在训练和测试时一致,我们会对未被丢弃的单元进行缩放,除以
  • 测试期间:关闭 Dropout,使用完整的网络进行预测,以确保结果的确定性。

2. 为什么 Dropout 有效?

  • 集成学习视角:Dropout 每次训练都在训练一个不同的子网络。最终的模型可以看作是无数个共享参数的小模型的"平均组合"。
  • 防止过拟合:因为神经元不能依赖特定的邻居(因为邻居随时可能被丢弃),它们必须学到更加"稳健"的特征,从而提升了泛化能力。

3. 代码实战:在多层感知机中应用 Dropout

文件展示了如何在 PyTorch 中通过 nn.Dropout 层轻松实现这一功能。

Python

复制代码
import torch
from torch import nn
from d2l import torch as d2l

# 定义丢弃概率
dropout1, dropout2 = 0.2, 0.5

# 搭建网络结构
net = nn.Sequential(
    nn.Flatten(),
    nn.Linear(784, 256),
    nn.ReLU(),
    # 在第一个全连接层后添加 Dropout
    nn.Dropout(dropout1),
    nn.Linear(256, 256),
    nn.ReLU(),
    # 在第二个全连接层后添加 Dropout
    nn.Dropout(dropout2),
    nn.Linear(256, 10)
)

# 初始化权重
def init_weights(m):
    if type(m) == nn.Linear:
        nn.init.normal_(m.weight, std=0.01)

net.apply(init_weights)

# 设置超参数与训练
num_epochs, lr, batch_size = 10, 0.5, 256
loss = nn.CrossEntropyLoss()
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)
trainer = torch.optim.SGD(net.parameters(), lr=lr)

# 开始训练
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

4. 关键细节与最佳实践

① Dropout 放哪里?

通常将 Dropout 放在激活函数之后 ,下一层全连接层之前。在上面的代码中,它被放置在 nn.ReLU() 之后。

② 丢弃概率

如何设置?

  • 靠近输入层的层,通常设置较小的丢弃率(如 0.2),因为输入层包含原始信息。
  • 靠近输出层的隐藏层,可以设置较大的丢弃率(如 0.5)。

train()eval() 模式切换

这是使用 Dropout 时最容易犯错的地方:

  • 在训练时必须调用 net.train(),此时 Dropout 生效。
  • 在测试/验证时必须调用 net.eval(),此时 Dropout 自动失效,所有神经元共同参与计算。

5. 总结:正则化的"双剑合璧"

在实际工程中,我们往往同时使用 权重衰退(L2 正则化)暂退法(Dropout)

  1. 权重衰退:让权重保持细小且平滑。
  2. 暂退法:让网络结构保持稀疏且健壮。
相关推荐
SeatuneWrite2 小时前
AI仿真人剧供应商2025推荐,高效内容创作与分发解决方案
人工智能·python
小草cys2 小时前
review202604032342
开发语言·php
数智工坊2 小时前
【深度学习基础】Focal Loss、Dice Loss、组合损失函数
人工智能·深度学习
一只小阿乐2 小时前
js流式模式输出 函数模式使用
开发语言·javascript·ai·vue·agent·流式数据·node 服务
伯远医学2 小时前
如何判断提取的RNA是否可用?
java·开发语言·前端·javascript·人工智能·eclipse·创业创新
ATMQuant2 小时前
以AI量化为生:20.实时图表交易系统开发
python·量化交易·实盘交易·vnpy·k线图表
搜狐技术产品小编20232 小时前
端侧Python动态算法策略的部署与运行
开发语言·python
cch89182 小时前
C++与PHP:7大核心差异全解析
java·开发语言
时光书签2 小时前
了解脚本语言
python·bash·batch命令