torch.cuda.empty_cache()使用场景

在 PyTorch 训练中,torch.cuda.empty_cache() 的正确使用需要结合具体场景。以下是 5 种典型使用场景和最佳实践建议:


1. 训练循环中显存碎片化严重时 适用场景 • 当出现 CUDA out of memory 但计算显存需求理论上足够时

• 使用动态计算图或频繁创建/释放临时张量

推荐位置

python 复制代码
for epoch in range(epochs):
    for batch in dataloader:
        # 前向/反向计算...
        optimizer.step()
        
        if batch_idx % 100 == 0:  # 每100个batch清理一次
            torch.cuda.empty_cache()  # 清理未使用的缓存

注意事项 • 过于频繁调用(如每个batch)会降低性能(约5-10%速度损失)

• 建议配合 torch.cuda.memory_summary() 诊断碎片化程度


2. 大模型分阶段加载时 适用场景 • 使用梯度检查点 (Gradient Checkpointing)

• 模型太大需要分段加载

示例代码

python 复制代码
def forward_segment(segment, x):
    torch.cuda.empty_cache()  # 加载新段前清理
    segment = segment.to(device)
    return segment(x)

for segment in large_model.segments:
    output = checkpoint(forward_segment, segment, inputs)

3. 多任务交替执行时 适用场景 • 同一个脚本交替执行训练和评估

• 不同模型共享GPU资源

代码示例

python 复制代码
# 训练阶段
train(model_A)
torch.cuda.empty_cache()  # 训练后立即清理

# 评估阶段
evaluate(model_B)  # 确保model_B能获得足够显存

4. 数据预处理与训练混合时 适用场景 • 使用GPU加速数据增强

• 动态生成训练数据

推荐写法

python 复制代码
for epoch in epochs:
    # GPU数据增强
    augmented_batch = gpu_augment(batch)  
    
    # 训练主模型
    train_step(model, augmented_batch)
    
    # 清理增强操作的中间缓存
    del augmented_batch
    torch.cuda.empty_cache()

5. 异常恢复后 适用场景 • 捕获 CUDA OOM 异常后尝试恢复

• 测试最大可用batch size时

代码实现

python 复制代码
try:
    large_batch = next(oversized_loader)
    output = model(large_batch)
except RuntimeError as e:
    if "CUDA out of memory" in str(e):
        torch.cuda.empty_cache()  # 尝试释放残留显存
        reduced_batch = large_batch[:half_size]
        # 重试...

最佳实践总结

场景 调用频率 是否必需 典型性能影响
常规训练 每N个batch ❌ 可选 <5% 减速
大模型加载 每次分段前 ✔️ 必需 可避免OOM
多任务切换 任务边界 ✔️ 推荐 可复用显存
异常恢复 按需 ✔️ 关键 恢复成功率+50%
调试阶段 任意位置 ❌ 避免 干扰内存分析

高级技巧

  1. 与内存分析工具配合:

    python 复制代码
    print(torch.cuda.memory_summary())  # 清理前
    torch.cuda.empty_cache()
    print(torch.cuda.memory_summary())  # 清理后
  2. PyTorch Lightning 集成:

    python 复制代码
    class MyModel(LightningModule):
        def on_train_batch_end(self):
            if self.current_epoch % 10 == 0:
                torch.cuda.empty_cache()
  3. 显存碎片化监控:

    python 复制代码
    def check_fragmentation():
        allocated = torch.cuda.memory_allocated()
        reserved = torch.cuda.memory_reserved()
        if reserved - allocated > 1e9:  # 碎片>1GB
            torch.cuda.empty_cache()

何时应该避免调用

  1. 在关键性能路径上:如高频调用的损失函数内
  2. 使用 torch.no_grad() 块时:此时无梯度缓存需要清理
  3. 确定无显存泄漏时:过度调用会导致不必要的同步点

合理使用此方法可将GPU利用率提升15-30%(特别是在大模型训练中),但需要结合具体场景权衡性能与显存占用的平衡。

相关推荐
火星资讯3 分钟前
“兴火·燎原”总冠军诞生,云宏信息《金融高算力轻量云平台》登顶
人工智能·科技
whaosoft-14319 分钟前
51c自动驾驶~合集37
人工智能
小技工丨26 分钟前
详解大语言模型生态系统概念:lama,llama.cpp,HuggingFace 模型 ,GGUF,MLX,lm-studio,ollama这都是什么?
人工智能·语言模型·llama
陈奕昆29 分钟前
大模型微调之LLaMA-Factory 系列教程大纲
人工智能·llama·大模型微调·llama-factory
上海云盾商务经理杨杨1 小时前
AI如何重塑DDoS防护行业?六大变革与未来展望
人工智能·安全·web安全·ddos
一刀到底2111 小时前
ai agent(智能体)开发 python3基础8 网页抓取中 selenium 和 Playwright 区别和联系
人工智能·python
每天都要写算法(努力版)1 小时前
【神经网络与深度学习】改变随机种子可以提升模型性能?
人工智能·深度学习·神经网络
烟锁池塘柳01 小时前
【计算机视觉】三种图像质量评价指标详解:PSNR、SSIM与SAM
人工智能·深度学习·计算机视觉
小森77672 小时前
(六)机器学习---聚类与K-means
人工智能·机器学习·数据挖掘·scikit-learn·kmeans·聚类
RockLiu@8052 小时前
探索PyTorch中的空间与通道双重注意力机制:实现concise的scSE模块
人工智能·pytorch·python