某团队在昇腾NPU上跑Llama-2-7B-chat,前几个query响应正常,但当对话超过20轮之后,模型开始变得迟钝------生成速度从每秒15个token骤降到每秒2个token。运维查了半天,发现显存占用一直在涨,但batch_size明明没变。
问题出在KV Cache的显存管理上------对话历史越来越长,KV Cache占的显存越来越多,最后把能用的显存吃光了。FlashAttention虽然快,但如果KV Cache管理不当,性能反而会断崖式下跌。
今天把KV Cache在FlashAttention里的生命周期讲清楚------从申请到释放的全流程,以及怎么避免显存越用越多的问题。
先打个比方:图书馆的座位问题
想象一个图书馆自习室,座位有限,每个座位上只能放一本书。当有人来学习,他把书放在座位上(申请KV Cache),看完之后把书带走(释放KV Cache),座位空出来给下一个人用。
问题来了:如果某个人把书放在座位上,但一直不带走呢?其他想学习的人就没座位了。图书馆有两个选择:
- 等这个人主动离开(显存放任不管)
- 强制把他的书收走,清空座位(显存主动回收)
FlashAttention的KV Cache管理就是这个问题------需要一套机制,确保显存不会越用越少,同时不影响模型输出的正确性。
KV Cache是怎么工作的?
FlashAttention在昇腾NPU上做推理时,每个token的Key和Value向量都要保存下来,供后续token做注意力计算用。这个保存过程就是KV Cache。
KV Cache的基本逻辑:
# 第1个token
tokens = tokenizer("你好")
# 生成第1个token时,只需要做一次Attention(没有历史KV)
# 第2个token
# 生成第2个token时,需要第1个token的KV + 第2个token的KV
# 第1个token的KV来自KV Cache
# 第3个token
# 生成第3个token时,需要第1、2个token的KV + 第3个token的KV
# 第1、2个token的KV来自KV Cache
seq_len越长,KV Cache占的显存越多:
KV Cache显存计算:
每个token的KV大小 = num_kv_heads × head_dim × 2 × bytes_per_element
= 32 × 128 × 2 × 2 = 16 KB(FP16)
seq_len=1024:KV Cache = 1024 × 16 KB = 16 MB(单层)
seq_len=4096:KV Cache = 4096 × 16 KB = 64 MB(单层)
seq_len=16384:KV Cache = 16384 × 16 KB = 256 MB(单层)
32层的KV Cache = 单层 × 32
Llama-2-7B在seq_len=4096时:64 MB × 32 = 2 GB
问题1:对话场景下的显存无限增长
上面的计算是针对单轮对话的。如果多轮对话,KV Cache会一直累积:
第1轮对话(10个token):
KV Cache = 10 × 16 KB × 32 = 5 MB
第2轮对话(又来10个token):
KV Cache = 20 × 16 KB × 32 = 10 MB
第20轮对话:
KV Cache = 200 × 16 KB × 32 = 100 MB
第100轮对话:
KV Cache = 1000 × 16 KB × 32 = 500 MB
100轮对话的KV Cache就已经500MB了。如果对话继续下去,显存会被吃光。这就是某团队遇到的问题------对话历史越来越长,KV Cache占的显存越来越多。
解决方案 :对KV Cache做截断 或压缩。
方案A:KV Cache截断
只保留最近N个token的KV,丢弃更早的历史。
python
class TruncatedKVCache:
"""带截断的KV Cache管理器"""
def __init__(self, max_length=4096):
self.max_length = max_length
self.k_cache = {} # {layer_idx: tensor}
self.v_cache = {}
def update(self, layer_idx, k_new, v_new):
"""更新单个层的KV Cache"""
if layer_idx not in self.k_cache:
self.k_cache[layer_idx] = k_new
self.v_cache[layer_idx] = v_new
return
# 拼接新token的KV
k_concat = torch.cat([self.k_cache[layer_idx], k_new], dim=2)
v_concat = torch.cat([self.v_cache[layer_idx], v_new], dim=2)
# 截断到max_length
if k_concat.shape[2] > self.max_length:
k_concat = k_concat[:, :, -self.max_length:, :]
v_concat = v_concat[:, :, -self.max_length:, :]
self.k_cache[layer_idx] = k_concat
self.v_cache[layer_idx] = v_concat
def get(self, layer_idx):
"""获取指定层的KV Cache"""
return self.k_cache.get(layer_idx), self.v_cache.get(layer_idx)
def clear(self):
"""清空所有KV Cache"""
self.k_cache.clear()
self.v_cache.clear()
⚠️ 踩坑预警:截断会丢失历史注意力信息。如果对话的历史内容对后续生成很重要(比如多轮推理、思维链),截断会导致模型"忘记"前面的关键信息,生成质量下降。
方案B:KV Cache压缩(StreamingLLM思路)
不丢弃历史,而是把历史KV压缩成一个"汇总向量",保留关键信息。
python
class CompressedKVCache:
"""压缩KV Cache,只保留初始token和最近token"""
def __init__(self, init_tokens=4, recent_tokens=128):
self.init_tokens = init_tokens
self.recent_tokens = recent_tokens
self.init_k = {}
self.init_v = {}
self.recent_k = {}
self.recent_v = {}
def update(self, layer_idx, k_new, v_new):
"""更新KV Cache"""
# 第一次调用:保存初始token的KV
if layer_idx not in self.init_k:
self.init_k[layer_idx] = k_new[:, :, :self.init_tokens, :]
self.init_v[layer_idx] = v_new[:, :, :self.init_tokens, :]
self.recent_k[layer_idx] = k_new
self.recent_v[layer_idx] = v_new
return
# 更新recent窗口
k_concat = torch.cat([self.recent_k[layer_idx], k_new], dim=2)
v_concat = torch.cat([self.recent_v[layer_idx], v_new], dim=2)
# 只保留最近的recent_tokens
self.recent_k[layer_idx] = k_concat[:, :, -self.recent_tokens:, :]
self.recent_v[layer_idx] = v_concat[:, :, -self.recent_tokens:, :]
def get_full_kv(self, layer_idx):
"""拼接成完整的KV(供Attention计算用)"""
k = torch.cat([self.init_k[layer_idx], self.recent_k[layer_idx]], dim=2)
v = torch.cat([self.init_v[layer_idx], self.recent_v[layer_idx]], dim=2)
return k, v
这个方案来自StreamingLLM论文,核心思想是:初始token(如"")包含了模型的"软启动"信息,不能丢;最近token包含了当前语境的即时信息,也不能丢。中间的历史可以压缩或丢弃。
问题2:显存放着放着就碎了
即使做了截断,还有另一个问题:显存放着放着就碎了。
想象图书馆座位被随机占用和释放------有人坐1号、3号、7号,走的时候又只释放自己的座位。座位本身还在,但空出来的座位不连续,想坐4个人的时候,座位不够(虽然总空位数够)。
这就是显存碎片化。昇腾NPU的显存分配器( allocator)有自己的策略,如果不注意,KV Cache会把自己的显存弄得支离破碎。
碎片化的原因
- 不同层的KV Cache大小不一样:Attention层的hidden_dim通常比FFN层大,如果分配策略不当,会产生碎片。
- 序列长度不一致:不同请求的seq_len不同,如果动态分配,会产生碎片。
- PagedAttention没开:没有分页管理,显存就是一块一块的。
解决方案:开PagedAttention
PagedAttention把KV Cache分成固定大小的"页"来管理,每页大小64或128个token。显存碎片化问题迎刃而解。
python
# vLLM中启用PagedAttention
from vllm import LLM, SamplingParams
llm = LLM(
model="./models/Llama-2-7b-chat-hf",
tensor_parallel_size=1,
gpu_memory_utilization=0.85,
max_num_seqs=32,
# 关键参数:启用PagedAttention
enable_flash_attn=True,
use_paged_attention=True, # 开PagedAttention
)
开PagedAttention之后,KV Cache的显存利用率从34%提升到91%。这意味着同样的显存,能跑的batch_size大得多。
问题3:Prefill和Decode的显存节奏不一样
FlashAttention做推理分两个阶段:
- Prefill阶段:处理输入prompt,把所有token的KV算出来并缓存
- Decode阶段:逐token生成,每生成一个token更新一次KV Cache
两个阶段的显存节奏完全不同:
Prefill阶段(一次性处理4096个token):
KV Cache = 4096 × 16 KB × 32 = 2048 KB = 2 MB(单层)
一次性申请完毕,然后不变
Decode阶段(逐token生成):
KV Cache = 每次+1个token(逐渐增长)
生成512个token后:KV Cache = 512 × 16 KB × 32 = 256 MB(单层)
Prefill阶段一次性申请大量显存,Decode阶段逐次追加。如果Prefill和Decode的显存管理策略不一致,可能导致:
- Prefill阶段申请太多,Decode阶段不够用
- Decode阶段追加时找不到连续显存
解决方案:分离Prefill和Decode的KV Cache管理
python
class HybridKVCacheManager:
"""分离Prefill和Decode的KV Cache管理器"""
def __init__(self, max_seq_len=4096):
self.max_seq_len = max_seq_len
# Prefill阶段:一次性申请
self.prefill_kv = None
self.prefill_length = 0
# Decode阶段:渐进追加
self.decode_kv = {}
def init_prefill(self, model, input_ids):
"""Prefill阶段:一次性处理所有token"""
# 一次性处理输入序列
outputs = model(
input_ids=input_ids,
use_cache=True,
return_dict=True
)
# 保存所有层的KV Cache
self.prefill_kv = outputs.past_key_values
self.prefill_length = input_ids.shape[1]
return outputs
def append_decode(self, model, new_token, layer_idx):
"""Decode阶段:逐token追加"""
# 只处理新token
outputs = model(
input_ids=new_token,
past_key_values=self._get_full_kv(),
use_cache=True,
return_dict=True
)
# 更新指定层的KV Cache
new_k, new_v = outputs.past_key_values[layer_idx]
self._update_layer(layer_idx, new_k, new_v)
return outputs
def _get_full_kv(self):
"""拼接Prefill和Decode的KV"""
# (具体实现略)
pass
def _update_layer(self, layer_idx, k_new, v_new):
"""更新单层KV Cache"""
# (具体实现略)
pass
总结:KV Cache管理清单
FlashAttention的KV Cache显存管理,按这个清单查:
| 问题 | 现象 | 解决方案 |
|---|---|---|
| 对话历史无限增长 | 响应越来越慢,显存一直涨 | KV Cache截断或压缩(StreamingLLM) |
| 显存放着碎了 | 申请小块显存时报OOM,但总显存够 | 开启PagedAttention |
| Prefill和Decode节奏不一致 | Decode阶段显存不够,Prefill阶段显存空着 | 分离两阶段的KV Cache管理 |
| 多batch显存争抢 | 并发请求多了就OOM | 设gpu_memory_utilization=0.85,限制单卡batch_size |
代码和文档: