FlashAttention的OOM排查:为什么显存够了还是报内存不足?

之前有个团队在昇腾NPU上跑Llama-2-7B,模型是FP16权重,seq_len=4096。他们算了算显存:模型权重13.5GB + 激活值4GB + KV Cache 4GB = 21.5GB,昇腾910有32GB显存,绰绰有余。

结果一跑就报OOM(Out Of Memory)。他们懵了:明明只用了21.5GB,32GB显存怎么就满了?

我帮他们排查了一下,发现问题出在FlashAttention的内存分配策略------它不是一次性分配所有显存,而是分块分配。每个分块在SRAM里处理完之后,要暂时存一个"中间结果",这个中间结果会占用额外的显存。

今天把这个内存分配的问题讲清楚,帮你排查OOM。

先打个比方:厨房的台面空间

想象你在厨房做饭,台面空间(SRAM)有限。你要做一道复杂的菜,需要很多步骤。

标准Attention的做法:先把所有食材摆在台面上(一次性分配所有显存),一次做完,最后清理台面。

FlashAttention的做法:每次只拿一小块食材到台面上(分块分配),做完这一步,放到一边(临时存储),再拿下一块。这样台面不用一直占满。

问题在哪? FlashAttention的分块策略虽然省了"总显存",但峰值显存可能更高------因为每个分块处理完的"临时结果"还在显存里,要等所有分块处理完才能释放。

FlashAttention的内存分配到底怎么回事?

FlashAttention在昇腾NPU上的内存分配分三层:

第一层:模型权重(静态分配)

模型权重是最稳定的显存占用,在整个推理/训练过程中都存在。

复制代码
Llama-2-7B FP16的权重分布:
  QKV投影:4096 × (3 × 4096) × 2 = 302 MB
  输出投影:4096 × 4096 × 2 = 32 MB
  FFN门控:4 × 4096 × 11008 × 2 = 361 MB
  FFN降维:11008 × 4096 × 2 = 90 MB
  词表投影:32000 × 4096 × 2 = 262 MB
  位置编码:513 × 4096 × 2 = 4 MB
  LayerNorm:4096 × 2 × 2 × 2 = 0.06 MB
  
  总计:~1 GB

32层堆叠 × 1 GB = ~13.5 GB(跟之前算的一致)。

第二层:KV Cache(动态分配)

KV Cache是推理时动态分配的。每个token生成完,它的K和V要存下来,供后续token的Attention用。

复制代码
KV Cache的计算:
  每个token的KV大小 = num_kv_heads × head_dim × 2(K+V)
                    = 32 × 128 × 2 = 8 KB
  seq_len=4096的KV Cache = 4096 × 8 KB = 32 MB(单层)
  32层的KV Cache = 32 × 32 MB = **1 GB**

第三层:FlashAttention的中间结果(最容易OOM的地方)

FlashAttention在分块处理的时候,会产生一些中间结果。这些中间结果不会立即释放,要等整个Attention层处理完才能释放。

复制代码
FlashAttention的中间结果(每层):
  Q分块缓冲:block_size × head_dim × 2 × num_heads = 128 × 128 × 2 × 32 = 1 MB
  K分块缓冲:block_size × head_dim × 2 × num_kv_heads = 128 × 128 × 2 × 32 = 1 MB
  V分块缓冲:同上 = 1 MB
  输出缓冲:block_size × head_dim × 2 × num_heads = 1 MB
  在线Softmax状态(m和l):block_size × 4 × 2 = 1 KB(可忽略)
  
  每层的中间结果总计:~4 MB
  32层的中间结果:32 × 4 = **128 MB**

等等,128MB看起来不多啊,为什么会OOM?

真正的问题:Gradient Checkpointing的中间结果

问题出在Gradient Checkpointing(激活重计算)------训练的时候,为了省显存,会重计算前向传播的中间结果。但FlashAttention的中间结果不在这个策略里!

python 复制代码
# Gradient Checkpointing的配置
model = GradientCheckpointing(model, checkpoint_ratio=0.5)

# FlashAttention的中间结果不在checkpoint策略里!
# 这些中间结果会一直占用显存,不会被释放

FlashAttention的中间结果不在Gradient Checkpointing的保护范围内,所以会占用额外的显存。

OOM的常见原因

原因1:batch_size太大

batch_size太大会导致KV Cache显存占用暴涨。

复制代码
batch_size=1:KV Cache = 1 GB
batch_size=4:KV Cache = 4 GB(线性增长)
batch_size=16:KV Cache = 16 GB
batch_size=32:KV Cache = 32 GB(爆炸)

原因2:seq_len太长

seq_len太长会导致KV Cache和Attention矩阵的中间结果都变大。

复制代码
seq_len=2048:KV Cache = 512 MB,中间结果 = 64 MB
seq_len=4096:KV Cache = 1 GB,中间结果 = 128 MB
seq_len=8192:KV Cache = 2 GB,中间结果 = 256 MB
seq_len=16384:KV Cache = 4 GB,中间结果 = 512 MB

原因3:没有开PagedAttention

PagedAttention能把KV Cache的显存利用率从34%提升到91%。不开PagedAttention,同样的显存,能跑的batch_size更小。

原因4:混合精度配置不对

如果开了FP16训练但BF16推理,或者反过来,昇腾NPU要做额外的精度转换,占用额外显存。

OOM排查清单

你的FlashAttention报OOM,按这个清单查:

python 复制代码
def diagnose_oom():
    """FlashAttention OOM诊断"""
    
    # 1. 检查模型权重大小
    total_params = sum(p.numel() for p in model.parameters())
    weight_mem = total_params * 2 / (1024 ** 3)  # FP16
    print(f"模型权重显存:{weight_mem:.2f} GB")
    
    # 2. 检查KV Cache大小
    kv_mem_per_layer = seq_len * num_kv_heads * head_dim * 2 * 2 / (1024 ** 2)  # FP16, MB
    total_kv_mem = kv_mem_per_layer * num_layers
    print(f"KV Cache显存(单层):{kv_mem_per_layer:.2f} MB")
    print(f"KV Cache显存(总):{total_kv_mem:.2f} GB")
    
    # 3. 检查中间结果大小
    intermediate_mem_per_layer = (
        block_size * head_dim * 2 * num_heads * 4 / (1024 ** 2)  # QKV输出缓冲
    )
    total_intermediate = intermediate_mem_per_layer * num_layers
    print(f"中间结果显存(单层):{intermediate_mem_per_layer:.2f} MB")
    print(f"中间结果显存(总):{total_intermediate:.2f} GB")
    
    # 4. 计算总显存
    total_mem = weight_mem + total_kv_mem + total_intermediate
    print(f"\n估算总显存:{total_mem:.2f} GB")
    print(f"可用显存:{torch.npu.get_device_properties(0).total_memory / (1024**3):.2f} GB")
    
    # 5. 判断
    if total_mem > torch.npu.get_device_properties(0).total_memory / (1024**3):
        print("\n❌ 显存不足!")
        print("建议:")
        print("  - 减小batch_size")
        print("  - 减小seq_len")
        print("  - 开PagedAttention")
        print("  - 用INT8 KV Cache量化")
    else:
        print(f"\n✅ 显存估算足够(剩余 {torch.npu.get_device_properties(0).total_memory / (1024**3) - total_mem:.2f} GB)")
        print("实际OOM可能是其他原因(内存碎片、驱动问题等)")

# 运行诊断
diagnose_oom()

解决OOM的方法

方法1:开PagedAttention

PagedAttention能把KV Cache的显存利用率从34%提升到91%。

python 复制代码
# vLLM配置
python -m vllm.entrypoints.openai.api_server \
  --model ./models/Llama-2-7b-chat-hf \
  --enable-flash-attn \
  --use-paged-attention \  # 开PagedAttention
  --max-num-seqs 32

方法2:用INT8 KV Cache量化

INT8量化能把KV Cache的显存减半。

python 复制代码
# vLLM配置
python -m vllm.entrypoints.openai.api_server \
  --model ./models/Llama-2-7b-chat-hf \
  --enable-flash-attn \
  --kv-cache-dtype int8 \  # INT8 KV Cache
  --max-num-seqs 32

方法3:减小batch_size

batch_size减半,KV Cache显存减半。

方法4:用Gradient Checkpointing(训练)

训练时开Gradient Checkpointing,省前向传播的激活值显存。

python 复制代码
model = GradientCheckpointing(model, checkpoint_ratio=0.5)

总结一下

FlashAttention的OOM原因:

  1. batch_size太大 → KV Cache显存爆炸
  2. seq_len太长 → KV Cache和中间结果都变大
  3. 没开PagedAttention → KV Cache显存利用率低
  4. 混合精度配置不对 → 额外的精度转换占用显存

估算公式

复制代码
总显存 = 模型权重 + batch_size × KV_Cache_per_token × seq_len + 中间结果

解决OOM的方法

  1. 开PagedAttention
  2. 用INT8 KV Cache量化
  3. 减小batch_size
  4. 开Gradient Checkpointing(训练)

代码和文档:

https://atomgit.com/cann/ops-transformer

相关推荐
2601_9571909010 小时前
迷拟极速飞车:多人同台竞速,轻量化高效落地
人工智能
灰灰勇闯IT10 小时前
AI Agent 推理:从单次对话到多轮工具调用
人工智能·microsoft
L、21810 小时前
CANN异构计算实践:CPU+NPU协同工作的最佳模式
网络·人工智能·pytorch·python·安全
nix.gnehc10 小时前
agentic 源码深度拆解:启动流程与会话调用流程全解
人工智能·agent
fa_lsyk10 小时前
安装部署Claude Code及测试
人工智能
2601_9578822410 小时前
一条视频如何自动适配5大平台的技术实现
人工智能·算法·机器学习
AI小百科10 小时前
目前开源AI编辑器面临的主要挑战是什么
人工智能·开源·编辑器
TDK村田muRata10 小时前
CUS200M-12 | TDK医疗电源|直流12V 16.7A |CUS200M-12/A
服务器·人工智能·3d·机器人·无人机
csdn小瓯10 小时前
日志规范化与结构化输出:构建可观测的 AI 后端系统
人工智能