在前面的内容Pytorch深入浅出Ⅶ之优化器(Optimizer)中,我们已经系统地理解了 Optimizer 的工作机制,但在真实训练中,仅仅"会用 Optimizer"远远不够。
随着模型变深、Batch 变大、训练周期变长,我们会遇到一些并非算法本身、而是工程实现上的挑战。这些问题往往不需要更换 Optimizer,而是要在使用阶段引入一些进阶技巧。
本篇介绍三种工程中非常常见、且稳定有效的进阶技巧:梯度裁剪、梯度防护与梯度累积。
一、梯度裁剪(Gradient Clipping)
1. 为什么需要梯度裁剪?
在深层网络(如 RNN 或 Transformer)中,梯度通过链式法则逐层相乘。如果梯度值较大,结果会呈指数级增长,导致梯度爆炸。
2. 核心思想
保持方向,缩减步长
最常用的是 按梯度范数裁剪 (Norm-based Clipping) :如果梯度的整体范数(L2 Norm)超过了阈值,就按比例缩小所有梯度。
g = g × max _ n o r m max ( grad_norm , max _ n o r m ) g = g \times \frac{\max\_norm}{\max(\text{grad\_norm}, \max\_norm)} g=g×max(grad_norm,max_norm)max_norm
3. PyTorch 中的标准用法
python
loss.backward() # ① 先计算梯度
# # max_norm 通常设为 1.0, 5.0 或 10.0
torch.nn.utils.clip_grad_norm_(model.parameters(),max_norm=1.0) # ② 裁剪梯度
optimizer.step() # ③ 再更新参数
⚠️ 注意 :梯度裁剪必须在 backward() 之后、optimizer.step() 之前执行。
4. 适用场景
专业级代码通常会默认开启梯度裁剪,即使训练看上去很稳定。它可以让优化路径更平滑,显著提升训练的稳健性。
👉 形象比喻:允许你往这个方向走,但不允许你一步迈得太跨格,防止"扯着蛋"。
二、梯度防护( NaN / Inf )
1. 问题来源?
训练过程中,梯度突然变成 NaN(非数)或 Inf(无穷大)会导致参数瞬间被"污染",模型几乎无法救回。这通常源于数值不稳定操作(如 log ( 0 ) \log(0) log(0))、数据异常或混合精度训练中的溢出。
2. 核心思想
在执行更新前,检查梯度是否合法。如果梯度非法,则放弃本次更新,保护参数不被污染。
3. 实现方式
python
def has_invalid_grad(model):
for p in model.parameters():
if p.grad is not None:
if torch.isnan(p.grad).any() or torch.isinf(p.grad).any():
return True
return False
# -------- 训练循环中使用 --------
loss.backward()
if not has_invalid_grad(model):
optimizer.step()
else:
print("Warning: NaN/Inf detected in gradients. Skipping this step.")
optimizer.zero_grad() # 无论是否更新,最后都要清零
4. 适用场景
- 使用 AMP(混合精度训练)
- 调试新设计的复杂网络结构时
- 训练过程中偶发崩溃、且难以通过调整学习率解决时
在使用 torch.cuda.amp.GradScaler 时,PyTorch 已内置部分 NaN/Inf 检测逻辑,但在自定义训练流程或非 AMP 场景下,显式防护仍然非常有价值。
👉 形象比喻 :这相当于给训练过程加了一根保险丝 。这一招不提升上限,但极大提高下限
三、梯度累积(Gradient Accumulation)
1. 为什么需要梯度累积?
理想状态下,大 Batch Size 能带来更稳定的梯度估计。但现实中:
- 显存受限:显卡塞不下大 Batch。
- Batch Size 太小:导致梯度波动剧烈,模型难以收敛。
梯度累积提供了一种"曲线救国"的思路:用时间换空间,通过多个小 Batch 模拟一个大 Batch
2. 梯度累积的核心思想
PyTorch 的梯度默认是累加 的。 当调用 loss.backward() 时,计算出的梯度会加到 param.grad 上。我们可以利用这个特性,连续多次计算梯度而不清零,最后统一更新。
text
.grad ← .grad + 当前 batch 的梯度
我们连续计算 N N N 个小 Batch 的梯度而不清零,最后统一更新,效果上等价于执行了一个 N N N 倍大的 Batch。
3. 标准实现方式
假设我们想用 accumulation_steps = 4,模拟 4 倍 batch size:
python
accumulation_steps = 4
optimizer.zero_grad()
for step, (inputs, labels) in enumerate(train_loader):
outputs = model(inputs)
loss = criterion(outputs, labels)
# 关键:梯度是累加的,所以 Loss 必须除以步数来求平均
loss = loss / accumulation_steps # 求梯度的平均值,而不是求梯度的总和
loss.backward()
# 每累积 4 步更新一次
if (step + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
4. 典型使用场景
- 大模型训练:如 BERT、GPT 等,单卡只能跑 Batch=1 或 2,必须靠累积。
- 显存捉襟见肘:当你想尝试大 Batch 的效果却买不起更多显卡时。
👉 形象比喻:梯度累积是穷苦炼丹师的"空间折叠术"。
| 技巧 | 解决的问题 | 核心作用 |
|---|---|---|
| 梯度裁剪 | 梯度爆炸 | 限制更新步长,平滑优化路径 |
| NaN/Inf 防护 | 数值崩溃 | 充当保险丝,防止参数被污染 |
| 梯度累积 | 显存不足 | 时间换空间,模拟大 Batch 效果 |