Dropout 和 BatchNorm 在训练和验证中的差异

文章目录

    • [1. Dropout](#1. Dropout)
      • [1.1 作用](#1.1 作用)
      • [1.2 训练和验证的差异](#1.2 训练和验证的差异)
      • [1.3 示例](#1.3 示例)
    • [2. Batch Normalization (BatchNorm)](#2. Batch Normalization (BatchNorm))
      • [2.1 作用](#2.1 作用)
      • [2.2 训练和验证时的差异](#2.2 训练和验证时的差异)
      • [2.3 示例](#2.3 示例)
    • [3. 总结](#3. 总结)
    • [4. 实际使用建议](#4. 实际使用建议)

在神经网络中,Dropout 和 Batch Normalization (BatchNorm) 是常见的层,其行为在 训练阶段 和 验证阶段(推理阶段) 是不同的。这种差异的原因是它们在两个阶段处理数据的方式不同,以适应训练和推理的需求。

1. Dropout

1.1 作用

  • Dropout 是一种正则化方法,用于防止过拟合。
  • 它通过在训练过程中随机"丢弃"一部分神经元(即将它们的输出置为 0)来增加模型的鲁棒性。

1.2 训练和验证的差异

  • 训练阶段: 随机丢弃部分神经元,按照设定的概率 p p p(比如0.5), 使某些神经元的输出置为0。但会通过放神经元的输出(即乘以 1 1 − p \frac {1} {1-p} 1−p1),补偿训练阶段丢弃部分神经元导致的输出缩减,从而确保输出的一致性。
  • 在验证阶段:,不再丢弃神经元,保留所有神经元的输出

1.3 示例

py 复制代码
import torch
import torch.nn as nn

dropout = nn.Dropout(p=0.5)

# Training phase
dropout.train()  # 启用训练模式
x_train = torch.ones(5)  # 输入为全1
output_train = dropout(x_train)  # 部分输出会被置为0

# Validation phase
dropout.eval()  # 启用验证模式
x_val = torch.ones(5)  # 输入为全1
output_val = dropout(x_val)  # 所有输出保持不变,但被缩放
  • 输出
shell 复制代码
output_train tensor([2., 2., 0., 2., 0.])
output_val tensor([1., 1., 1., 1., 1.])

2. Batch Normalization (BatchNorm)

2.1 作用

  • BatchNorm 用于加速训练,解决梯度消失和梯度爆炸的问题。
  • 它通过对每个mini-batch的数据进行归一化(使输出具有零均值和单位方差)来实现稳定的训练过程。

2.2 训练和验证时的差异

原理

  • 在训练过程中,BatchNorm 会计算每个 mini-batch 的均值和方差,同时更新全局的移动平均值(moving mean)和移动方差(moving variance)。
  • 在验证阶段,为了避免小批量数据引入偏差,直接使用训练阶段保存的全局统计信息进行归一化。

2.3 示例

py 复制代码
import torch
import torch.nn as nn

batchnorm = nn.BatchNorm1d(num_features=5)

# Training phase
batchnorm.train()  # 启用训练模式
x_train = torch.rand(10, 5)  # 随机生成输入
output_train = batchnorm(x_train)  # 使用 mini-batch 均值和方差进行归一化

# Validation phase
batchnorm.eval()  # 启用验证模式
x_val = torch.rand(10, 5)  # 随机生成输入
output_val = batchnorm(x_val)  # 使用全局的 moving mean 和 moving variance

3. 总结

注意事项:

  • 在验证或推理阶段,必须调用 model.eval(),否则 Dropout 和 BatchNorm 的行为会与训练阶段一致,导致验证结果或推理结果不正确。
  • 如果模型中没有 Dropout 或 BatchNorm,则 model.eval() 不会改变模型的行为。

4. 实际使用建议

典型推理代码

py 复制代码
model.eval()  # 切换到验证模式
with torch.no_grad():  # 关闭梯度计算
    output = model(input_tensor)  # 推理

训练代码

py 复制代码
model.train()  # 切换到训练模式
output = model(input_tensor)  # 进行前向传播
loss = loss_fn(output, target)  # 计算损失
loss.backward()  # 反向传播
optimizer.step()  # 更新参数
相关推荐
逛逛GitHub1 小时前
飞书多维表“独立”了!功能强大的超出想象。
人工智能·github·产品
机器之心1 小时前
刚刚,DeepSeek-R1论文登上Nature封面,通讯作者梁文锋
人工智能·openai
CoovallyAIHub2 小时前
港大&字节重磅发布DanceGRPO:突破视觉生成RLHF瓶颈,多项任务性能提升超180%!
深度学习·算法·计算机视觉
CoovallyAIHub3 小时前
英伟达ViPE重磅发布!解决3D感知难题,SLAM+深度学习完美融合(附带数据集下载地址)
深度学习·算法·计算机视觉
aneasystone本尊4 小时前
学习 Chat2Graph 的知识库服务
人工智能
IT_陈寒4 小时前
Redis 性能翻倍的 7 个冷门技巧,第 5 个大多数人都不知道!
前端·人工智能·后端
飞哥数智坊14 小时前
GPT-5-Codex 发布,Codex 正在取代 Claude
人工智能·ai编程
倔强青铜三14 小时前
苦练Python第46天:文件写入与上下文管理器
人工智能·python·面试
虫无涯15 小时前
Dify Agent + AntV 实战:从 0 到 1 打造数据可视化解决方案
人工智能
Dm_dotnet17 小时前
公益站Agent Router注册送200刀额度竟然是真的
人工智能