torch.cuda.empty_cache()使用场景

在 PyTorch 训练中,torch.cuda.empty_cache() 的正确使用需要结合具体场景。以下是 5 种典型使用场景和最佳实践建议:


1. 训练循环中显存碎片化严重时 适用场景 • 当出现 CUDA out of memory 但计算显存需求理论上足够时

• 使用动态计算图或频繁创建/释放临时张量

推荐位置

python 复制代码
for epoch in range(epochs):
    for batch in dataloader:
        # 前向/反向计算...
        optimizer.step()
        
        if batch_idx % 100 == 0:  # 每100个batch清理一次
            torch.cuda.empty_cache()  # 清理未使用的缓存

注意事项 • 过于频繁调用(如每个batch)会降低性能(约5-10%速度损失)

• 建议配合 torch.cuda.memory_summary() 诊断碎片化程度


2. 大模型分阶段加载时 适用场景 • 使用梯度检查点 (Gradient Checkpointing)

• 模型太大需要分段加载

示例代码

python 复制代码
def forward_segment(segment, x):
    torch.cuda.empty_cache()  # 加载新段前清理
    segment = segment.to(device)
    return segment(x)

for segment in large_model.segments:
    output = checkpoint(forward_segment, segment, inputs)

3. 多任务交替执行时 适用场景 • 同一个脚本交替执行训练和评估

• 不同模型共享GPU资源

代码示例

python 复制代码
# 训练阶段
train(model_A)
torch.cuda.empty_cache()  # 训练后立即清理

# 评估阶段
evaluate(model_B)  # 确保model_B能获得足够显存

4. 数据预处理与训练混合时 适用场景 • 使用GPU加速数据增强

• 动态生成训练数据

推荐写法

python 复制代码
for epoch in epochs:
    # GPU数据增强
    augmented_batch = gpu_augment(batch)  
    
    # 训练主模型
    train_step(model, augmented_batch)
    
    # 清理增强操作的中间缓存
    del augmented_batch
    torch.cuda.empty_cache()

5. 异常恢复后 适用场景 • 捕获 CUDA OOM 异常后尝试恢复

• 测试最大可用batch size时

代码实现

python 复制代码
try:
    large_batch = next(oversized_loader)
    output = model(large_batch)
except RuntimeError as e:
    if "CUDA out of memory" in str(e):
        torch.cuda.empty_cache()  # 尝试释放残留显存
        reduced_batch = large_batch[:half_size]
        # 重试...

最佳实践总结

场景 调用频率 是否必需 典型性能影响
常规训练 每N个batch ❌ 可选 <5% 减速
大模型加载 每次分段前 ✔️ 必需 可避免OOM
多任务切换 任务边界 ✔️ 推荐 可复用显存
异常恢复 按需 ✔️ 关键 恢复成功率+50%
调试阶段 任意位置 ❌ 避免 干扰内存分析

高级技巧

  1. 与内存分析工具配合:

    python 复制代码
    print(torch.cuda.memory_summary())  # 清理前
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary())  # 清理后
  2. PyTorch Lightning 集成:

    python 复制代码
    class MyModel(LightningModule):
        def on_train_batch_end(self):
            if self.current_epoch % 10 == 0:
                torch.cuda.empty_cache()
  3. 显存碎片化监控:

    python 复制代码
    def check_fragmentation():
        allocated = torch.cuda.memory_allocated()
        reserved = torch.cuda.memory_reserved()
        if reserved - allocated > 1e9:  # 碎片>1GB
            torch.cuda.empty_cache()

何时应该避免调用

  1. 在关键性能路径上:如高频调用的损失函数内
  2. 使用 torch.no_grad() 块时:此时无梯度缓存需要清理
  3. 确定无显存泄漏时:过度调用会导致不必要的同步点

合理使用此方法可将GPU利用率提升15-30%(特别是在大模型训练中),但需要结合具体场景权衡性能与显存占用的平衡。

相关推荐
ACCELERATOR_LLC6 分钟前
【DataWhale组队学习】DIY-LLM Task2 PyTorch 与资源核算
人工智能·pytorch·深度学习·大模型
Elastic 中国社区官方博客29 分钟前
Elastic Security、Observability 和 Search 现在在你的 AI 工具中提供交互式 UI
大数据·运维·人工智能·elasticsearch·搜索引擎·安全威胁分析·可用性测试
一碗白开水一35 分钟前
【目标跟踪综述】目标跟踪近3年技术研究,全面了解目标跟踪发展
人工智能·计算机视觉·目标跟踪
Promise微笑1 小时前
AI搜索时代的流量重构:GEO优化深度执行细节与把控体系
人工智能·重构
言萧凡_CookieBoty1 小时前
比 Vibe Coding 更可怕的,是 Vibe Design 吧
人工智能·ai编程
Rick19931 小时前
Spring AI 如何进行权限控制
人工智能·python·spring
Theodore_10221 小时前
深度学习(15):倾斜数据集 & 精确率-召回率权衡
人工智能·笔记·深度学习·机器学习·知识图谱
IT_陈寒1 小时前
SpringBoot自动配置这破玩意儿又坑我一次
前端·人工智能·后端
TechubNews2 小时前
Base 发布首个独立 OP Stack 框架的网络升级 Azul,将是 L2 自主迭代的开端?
大数据·网络·人工智能·区块链·能源
啦啦啦_99992 小时前
1.机器学习概述
人工智能·机器学习