在 PyTorch 训练中,torch.cuda.empty_cache()
的正确使用需要结合具体场景。以下是 5 种典型使用场景和最佳实践建议:
1. 训练循环中显存碎片化严重时 适用场景 • 当出现 CUDA out of memory
但计算显存需求理论上足够时
• 使用动态计算图或频繁创建/释放临时张量
推荐位置
python
for epoch in range(epochs):
for batch in dataloader:
# 前向/反向计算...
optimizer.step()
if batch_idx % 100 == 0: # 每100个batch清理一次
torch.cuda.empty_cache() # 清理未使用的缓存
注意事项 • 过于频繁调用(如每个batch)会降低性能(约5-10%速度损失)
• 建议配合 torch.cuda.memory_summary()
诊断碎片化程度
2. 大模型分阶段加载时 适用场景 • 使用梯度检查点 (Gradient Checkpointing)
• 模型太大需要分段加载
示例代码
python
def forward_segment(segment, x):
torch.cuda.empty_cache() # 加载新段前清理
segment = segment.to(device)
return segment(x)
for segment in large_model.segments:
output = checkpoint(forward_segment, segment, inputs)
3. 多任务交替执行时 适用场景 • 同一个脚本交替执行训练和评估
• 不同模型共享GPU资源
代码示例
python
# 训练阶段
train(model_A)
torch.cuda.empty_cache() # 训练后立即清理
# 评估阶段
evaluate(model_B) # 确保model_B能获得足够显存
4. 数据预处理与训练混合时 适用场景 • 使用GPU加速数据增强
• 动态生成训练数据
推荐写法
python
for epoch in epochs:
# GPU数据增强
augmented_batch = gpu_augment(batch)
# 训练主模型
train_step(model, augmented_batch)
# 清理增强操作的中间缓存
del augmented_batch
torch.cuda.empty_cache()
5. 异常恢复后 适用场景 • 捕获 CUDA OOM
异常后尝试恢复
• 测试最大可用batch size时
代码实现
python
try:
large_batch = next(oversized_loader)
output = model(large_batch)
except RuntimeError as e:
if "CUDA out of memory" in str(e):
torch.cuda.empty_cache() # 尝试释放残留显存
reduced_batch = large_batch[:half_size]
# 重试...
最佳实践总结
场景 | 调用频率 | 是否必需 | 典型性能影响 |
---|---|---|---|
常规训练 | 每N个batch | ❌ 可选 | <5% 减速 |
大模型加载 | 每次分段前 | ✔️ 必需 | 可避免OOM |
多任务切换 | 任务边界 | ✔️ 推荐 | 可复用显存 |
异常恢复 | 按需 | ✔️ 关键 | 恢复成功率+50% |
调试阶段 | 任意位置 | ❌ 避免 | 干扰内存分析 |
高级技巧
-
与内存分析工具配合:
pythonprint(torch.cuda.memory_summary()) # 清理前 torch.cuda.empty_cache() print(torch.cuda.memory_summary()) # 清理后
-
PyTorch Lightning 集成:
pythonclass MyModel(LightningModule): def on_train_batch_end(self): if self.current_epoch % 10 == 0: torch.cuda.empty_cache()
-
显存碎片化监控:
pythondef check_fragmentation(): allocated = torch.cuda.memory_allocated() reserved = torch.cuda.memory_reserved() if reserved - allocated > 1e9: # 碎片>1GB torch.cuda.empty_cache()
何时应该避免调用
- 在关键性能路径上:如高频调用的损失函数内
- 使用
torch.no_grad()
块时:此时无梯度缓存需要清理 - 确定无显存泄漏时:过度调用会导致不必要的同步点
合理使用此方法可将GPU利用率提升15-30%(特别是在大模型训练中),但需要结合具体场景权衡性能与显存占用的平衡。