Dropout 和 BatchNorm 在训练和验证中的差异

文章目录

    • [1. Dropout](#1. Dropout)
      • [1.1 作用](#1.1 作用)
      • [1.2 训练和验证的差异](#1.2 训练和验证的差异)
      • [1.3 示例](#1.3 示例)
    • [2. Batch Normalization (BatchNorm)](#2. Batch Normalization (BatchNorm))
      • [2.1 作用](#2.1 作用)
      • [2.2 训练和验证时的差异](#2.2 训练和验证时的差异)
      • [2.3 示例](#2.3 示例)
    • [3. 总结](#3. 总结)
    • [4. 实际使用建议](#4. 实际使用建议)

在神经网络中,Dropout 和 Batch Normalization (BatchNorm) 是常见的层,其行为在 训练阶段 和 验证阶段(推理阶段) 是不同的。这种差异的原因是它们在两个阶段处理数据的方式不同,以适应训练和推理的需求。

1. Dropout

1.1 作用

  • Dropout 是一种正则化方法,用于防止过拟合。
  • 它通过在训练过程中随机"丢弃"一部分神经元(即将它们的输出置为 0)来增加模型的鲁棒性。

1.2 训练和验证的差异

  • 训练阶段: 随机丢弃部分神经元,按照设定的概率 p p p(比如0.5), 使某些神经元的输出置为0。但会通过放神经元的输出(即乘以 1 1 − p \frac {1} {1-p} 1−p1),补偿训练阶段丢弃部分神经元导致的输出缩减,从而确保输出的一致性。
  • 在验证阶段:,不再丢弃神经元,保留所有神经元的输出

1.3 示例

py 复制代码
import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.5)

# Training phase
dropout.train()  # 启用训练模式
x_train = torch.ones(5)  # 输入为全1
output_train = dropout(x_train)  # 部分输出会被置为0

# Validation phase
dropout.eval()  # 启用验证模式
x_val = torch.ones(5)  # 输入为全1
output_val = dropout(x_val)  # 所有输出保持不变,但被缩放
  • 输出
shell 复制代码
output_train tensor([2., 2., 0., 2., 0.])
output_val tensor([1., 1., 1., 1., 1.])

2. Batch Normalization (BatchNorm)

2.1 作用

  • BatchNorm 用于加速训练,解决梯度消失和梯度爆炸的问题。
  • 它通过对每个mini-batch的数据进行归一化(使输出具有零均值和单位方差)来实现稳定的训练过程。

2.2 训练和验证时的差异

原理

  • 在训练过程中,BatchNorm 会计算每个 mini-batch 的均值和方差,同时更新全局的移动平均值(moving mean)和移动方差(moving variance)。
  • 在验证阶段,为了避免小批量数据引入偏差,直接使用训练阶段保存的全局统计信息进行归一化。

2.3 示例

py 复制代码
import torch
import torch.nn as nn

batchnorm = nn.BatchNorm1d(num_features=5)

# Training phase
batchnorm.train()  # 启用训练模式
x_train = torch.rand(10, 5)  # 随机生成输入
output_train = batchnorm(x_train)  # 使用 mini-batch 均值和方差进行归一化

# Validation phase
batchnorm.eval()  # 启用验证模式
x_val = torch.rand(10, 5)  # 随机生成输入
output_val = batchnorm(x_val)  # 使用全局的 moving mean 和 moving variance

3. 总结

注意事项:

  • 在验证或推理阶段,必须调用 model.eval(),否则 Dropout 和 BatchNorm 的行为会与训练阶段一致,导致验证结果或推理结果不正确。
  • 如果模型中没有 Dropout 或 BatchNorm,则 model.eval() 不会改变模型的行为。

4. 实际使用建议

典型推理代码

py 复制代码
model.eval()  # 切换到验证模式
with torch.no_grad():  # 关闭梯度计算
    output = model(input_tensor)  # 推理

训练代码

py 复制代码
model.train()  # 切换到训练模式
output = model(input_tensor)  # 进行前向传播
loss = loss_fn(output, target)  # 计算损失
loss.backward()  # 反向传播
optimizer.step()  # 更新参数
相关推荐
英码科技1 小时前
昇腾系列双处理边缘计算盒子DA500I,打造高效低延迟的视觉推理解决方案
人工智能·边缘计算
SEVEN-YEARS1 小时前
深入理解BERT模型:BertModel类详解
人工智能·深度学习·自然语言处理·bert
weixin_543662861 小时前
BERT的中文问答系统34
python·深度学习·bert
Matrix_111 小时前
论文阅读:Uni-ISP Unifying the Learning of ISPs from Multiple Cameras
人工智能·计算摄影
马甲是掉不了一点的<.<1 小时前
2021TCSVT,VDM-DA:面向无源数据域自适应的虚拟域建模
深度学习·计算机视觉·无源域风格图像生成·无源域适应
DisonTangor1 小时前
TableGPT2-7B:用于表格数据分析的大规模解码器模型
人工智能·数据挖掘·数据分析
Altair澳汰尔2 小时前
数据分析丨世界杯冠军猜想:EA 体育游戏模拟能成功预测吗?
人工智能·游戏·数据挖掘·数据分析
噜噜噜噜鲁先森2 小时前
零基础利用实战项目学会Pytorch
人工智能·pytorch·python·深度学习·神经网络·算法·回归
程序员陆通2 小时前
Streamlit + AI大模型API实现视频字幕提取
人工智能·音视频