之前有个团队在昇腾NPU上跑Llama-2-7B,模型是FP16权重,seq_len=4096。他们算了算显存:模型权重13.5GB + 激活值4GB + KV Cache 4GB = 21.5GB,昇腾910有32GB显存,绰绰有余。
结果一跑就报OOM(Out Of Memory)。他们懵了:明明只用了21.5GB,32GB显存怎么就满了?
我帮他们排查了一下,发现问题出在FlashAttention的内存分配策略------它不是一次性分配所有显存,而是分块分配。每个分块在SRAM里处理完之后,要暂时存一个"中间结果",这个中间结果会占用额外的显存。
今天把这个内存分配的问题讲清楚,帮你排查OOM。
先打个比方:厨房的台面空间
想象你在厨房做饭,台面空间(SRAM)有限。你要做一道复杂的菜,需要很多步骤。
标准Attention的做法:先把所有食材摆在台面上(一次性分配所有显存),一次做完,最后清理台面。
FlashAttention的做法:每次只拿一小块食材到台面上(分块分配),做完这一步,放到一边(临时存储),再拿下一块。这样台面不用一直占满。
问题在哪? FlashAttention的分块策略虽然省了"总显存",但峰值显存可能更高------因为每个分块处理完的"临时结果"还在显存里,要等所有分块处理完才能释放。
FlashAttention的内存分配到底怎么回事?
FlashAttention在昇腾NPU上的内存分配分三层:
第一层:模型权重(静态分配)
模型权重是最稳定的显存占用,在整个推理/训练过程中都存在。
Llama-2-7B FP16的权重分布:
QKV投影:4096 × (3 × 4096) × 2 = 302 MB
输出投影:4096 × 4096 × 2 = 32 MB
FFN门控:4 × 4096 × 11008 × 2 = 361 MB
FFN降维:11008 × 4096 × 2 = 90 MB
词表投影:32000 × 4096 × 2 = 262 MB
位置编码:513 × 4096 × 2 = 4 MB
LayerNorm:4096 × 2 × 2 × 2 = 0.06 MB
总计:~1 GB
32层堆叠 × 1 GB = ~13.5 GB(跟之前算的一致)。
第二层:KV Cache(动态分配)
KV Cache是推理时动态分配的。每个token生成完,它的K和V要存下来,供后续token的Attention用。
KV Cache的计算:
每个token的KV大小 = num_kv_heads × head_dim × 2(K+V)
= 32 × 128 × 2 = 8 KB
seq_len=4096的KV Cache = 4096 × 8 KB = 32 MB(单层)
32层的KV Cache = 32 × 32 MB = **1 GB**
第三层:FlashAttention的中间结果(最容易OOM的地方)
FlashAttention在分块处理的时候,会产生一些中间结果。这些中间结果不会立即释放,要等整个Attention层处理完才能释放。
FlashAttention的中间结果(每层):
Q分块缓冲:block_size × head_dim × 2 × num_heads = 128 × 128 × 2 × 32 = 1 MB
K分块缓冲:block_size × head_dim × 2 × num_kv_heads = 128 × 128 × 2 × 32 = 1 MB
V分块缓冲:同上 = 1 MB
输出缓冲:block_size × head_dim × 2 × num_heads = 1 MB
在线Softmax状态(m和l):block_size × 4 × 2 = 1 KB(可忽略)
每层的中间结果总计:~4 MB
32层的中间结果:32 × 4 = **128 MB**
等等,128MB看起来不多啊,为什么会OOM?
真正的问题:Gradient Checkpointing的中间结果
问题出在Gradient Checkpointing(激活重计算)------训练的时候,为了省显存,会重计算前向传播的中间结果。但FlashAttention的中间结果不在这个策略里!
python
# Gradient Checkpointing的配置
model = GradientCheckpointing(model, checkpoint_ratio=0.5)
# FlashAttention的中间结果不在checkpoint策略里!
# 这些中间结果会一直占用显存,不会被释放
FlashAttention的中间结果不在Gradient Checkpointing的保护范围内,所以会占用额外的显存。
OOM的常见原因
原因1:batch_size太大
batch_size太大会导致KV Cache显存占用暴涨。
batch_size=1:KV Cache = 1 GB
batch_size=4:KV Cache = 4 GB(线性增长)
batch_size=16:KV Cache = 16 GB
batch_size=32:KV Cache = 32 GB(爆炸)
原因2:seq_len太长
seq_len太长会导致KV Cache和Attention矩阵的中间结果都变大。
seq_len=2048:KV Cache = 512 MB,中间结果 = 64 MB
seq_len=4096:KV Cache = 1 GB,中间结果 = 128 MB
seq_len=8192:KV Cache = 2 GB,中间结果 = 256 MB
seq_len=16384:KV Cache = 4 GB,中间结果 = 512 MB
原因3:没有开PagedAttention
PagedAttention能把KV Cache的显存利用率从34%提升到91%。不开PagedAttention,同样的显存,能跑的batch_size更小。
原因4:混合精度配置不对
如果开了FP16训练但BF16推理,或者反过来,昇腾NPU要做额外的精度转换,占用额外显存。
OOM排查清单
你的FlashAttention报OOM,按这个清单查:
python
def diagnose_oom():
"""FlashAttention OOM诊断"""
# 1. 检查模型权重大小
total_params = sum(p.numel() for p in model.parameters())
weight_mem = total_params * 2 / (1024 ** 3) # FP16
print(f"模型权重显存:{weight_mem:.2f} GB")
# 2. 检查KV Cache大小
kv_mem_per_layer = seq_len * num_kv_heads * head_dim * 2 * 2 / (1024 ** 2) # FP16, MB
total_kv_mem = kv_mem_per_layer * num_layers
print(f"KV Cache显存(单层):{kv_mem_per_layer:.2f} MB")
print(f"KV Cache显存(总):{total_kv_mem:.2f} GB")
# 3. 检查中间结果大小
intermediate_mem_per_layer = (
block_size * head_dim * 2 * num_heads * 4 / (1024 ** 2) # QKV输出缓冲
)
total_intermediate = intermediate_mem_per_layer * num_layers
print(f"中间结果显存(单层):{intermediate_mem_per_layer:.2f} MB")
print(f"中间结果显存(总):{total_intermediate:.2f} GB")
# 4. 计算总显存
total_mem = weight_mem + total_kv_mem + total_intermediate
print(f"\n估算总显存:{total_mem:.2f} GB")
print(f"可用显存:{torch.npu.get_device_properties(0).total_memory / (1024**3):.2f} GB")
# 5. 判断
if total_mem > torch.npu.get_device_properties(0).total_memory / (1024**3):
print("\n❌ 显存不足!")
print("建议:")
print(" - 减小batch_size")
print(" - 减小seq_len")
print(" - 开PagedAttention")
print(" - 用INT8 KV Cache量化")
else:
print(f"\n✅ 显存估算足够(剩余 {torch.npu.get_device_properties(0).total_memory / (1024**3) - total_mem:.2f} GB)")
print("实际OOM可能是其他原因(内存碎片、驱动问题等)")
# 运行诊断
diagnose_oom()
解决OOM的方法
方法1:开PagedAttention
PagedAttention能把KV Cache的显存利用率从34%提升到91%。
python
# vLLM配置
python -m vllm.entrypoints.openai.api_server \
--model ./models/Llama-2-7b-chat-hf \
--enable-flash-attn \
--use-paged-attention \ # 开PagedAttention
--max-num-seqs 32
方法2:用INT8 KV Cache量化
INT8量化能把KV Cache的显存减半。
python
# vLLM配置
python -m vllm.entrypoints.openai.api_server \
--model ./models/Llama-2-7b-chat-hf \
--enable-flash-attn \
--kv-cache-dtype int8 \ # INT8 KV Cache
--max-num-seqs 32
方法3:减小batch_size
batch_size减半,KV Cache显存减半。
方法4:用Gradient Checkpointing(训练)
训练时开Gradient Checkpointing,省前向传播的激活值显存。
python
model = GradientCheckpointing(model, checkpoint_ratio=0.5)
总结一下
FlashAttention的OOM原因:
- batch_size太大 → KV Cache显存爆炸
- seq_len太长 → KV Cache和中间结果都变大
- 没开PagedAttention → KV Cache显存利用率低
- 混合精度配置不对 → 额外的精度转换占用显存
估算公式:
总显存 = 模型权重 + batch_size × KV_Cache_per_token × seq_len + 中间结果
解决OOM的方法:
- 开PagedAttention
- 用INT8 KV Cache量化
- 减小batch_size
- 开Gradient Checkpointing(训练)
代码和文档: