KV Cache的生老病死:FlashAttention里的显存管理全流程

某团队在昇腾NPU上跑Llama-2-7B-chat,前几个query响应正常,但当对话超过20轮之后,模型开始变得迟钝------生成速度从每秒15个token骤降到每秒2个token。运维查了半天,发现显存占用一直在涨,但batch_size明明没变。

问题出在KV Cache的显存管理上------对话历史越来越长,KV Cache占的显存越来越多,最后把能用的显存吃光了。FlashAttention虽然快,但如果KV Cache管理不当,性能反而会断崖式下跌。

今天把KV Cache在FlashAttention里的生命周期讲清楚------从申请到释放的全流程,以及怎么避免显存越用越多的问题。

先打个比方:图书馆的座位问题

想象一个图书馆自习室,座位有限,每个座位上只能放一本书。当有人来学习,他把书放在座位上(申请KV Cache),看完之后把书带走(释放KV Cache),座位空出来给下一个人用。

问题来了:如果某个人把书放在座位上,但一直不带走呢?其他想学习的人就没座位了。图书馆有两个选择:

  • 等这个人主动离开(显存放任不管)
  • 强制把他的书收走,清空座位(显存主动回收)

FlashAttention的KV Cache管理就是这个问题------需要一套机制,确保显存不会越用越少,同时不影响模型输出的正确性。

KV Cache是怎么工作的?

FlashAttention在昇腾NPU上做推理时,每个token的Key和Value向量都要保存下来,供后续token做注意力计算用。这个保存过程就是KV Cache。

复制代码
KV Cache的基本逻辑:

# 第1个token
tokens = tokenizer("你好")
# 生成第1个token时,只需要做一次Attention(没有历史KV)

# 第2个token
# 生成第2个token时,需要第1个token的KV + 第2个token的KV
# 第1个token的KV来自KV Cache

# 第3个token
# 生成第3个token时,需要第1、2个token的KV + 第3个token的KV
# 第1、2个token的KV来自KV Cache

seq_len越长,KV Cache占的显存越多:

复制代码
KV Cache显存计算:
  每个token的KV大小 = num_kv_heads × head_dim × 2 × bytes_per_element
                    = 32 × 128 × 2 × 2 = 16 KB(FP16)
  
  seq_len=1024:KV Cache = 1024 × 16 KB = 16 MB(单层)
  seq_len=4096:KV Cache = 4096 × 16 KB = 64 MB(单层)
  seq_len=16384:KV Cache = 16384 × 16 KB = 256 MB(单层)
  
  32层的KV Cache = 单层 × 32
  Llama-2-7B在seq_len=4096时:64 MB × 32 = 2 GB

问题1:对话场景下的显存无限增长

上面的计算是针对单轮对话的。如果多轮对话,KV Cache会一直累积:

复制代码
第1轮对话(10个token):
  KV Cache = 10 × 16 KB × 32 = 5 MB

第2轮对话(又来10个token):
  KV Cache = 20 × 16 KB × 32 = 10 MB

第20轮对话:
  KV Cache = 200 × 16 KB × 32 = 100 MB

第100轮对话:
  KV Cache = 1000 × 16 KB × 32 = 500 MB

100轮对话的KV Cache就已经500MB了。如果对话继续下去,显存会被吃光。这就是某团队遇到的问题------对话历史越来越长,KV Cache占的显存越来越多。

解决方案 :对KV Cache做截断压缩

方案A:KV Cache截断

只保留最近N个token的KV,丢弃更早的历史。

python 复制代码
class TruncatedKVCache:
    """带截断的KV Cache管理器"""
    
    def __init__(self, max_length=4096):
        self.max_length = max_length
        self.k_cache = {}  # {layer_idx: tensor}
        self.v_cache = {}
    
    def update(self, layer_idx, k_new, v_new):
        """更新单个层的KV Cache"""
        if layer_idx not in self.k_cache:
            self.k_cache[layer_idx] = k_new
            self.v_cache[layer_idx] = v_new
            return
        
        # 拼接新token的KV
        k_concat = torch.cat([self.k_cache[layer_idx], k_new], dim=2)
        v_concat = torch.cat([self.v_cache[layer_idx], v_new], dim=2)
        
        # 截断到max_length
        if k_concat.shape[2] > self.max_length:
            k_concat = k_concat[:, :, -self.max_length:, :]
            v_concat = v_concat[:, :, -self.max_length:, :]
        
        self.k_cache[layer_idx] = k_concat
        self.v_cache[layer_idx] = v_concat
    
    def get(self, layer_idx):
        """获取指定层的KV Cache"""
        return self.k_cache.get(layer_idx), self.v_cache.get(layer_idx)
    
    def clear(self):
        """清空所有KV Cache"""
        self.k_cache.clear()
        self.v_cache.clear()

⚠️ 踩坑预警:截断会丢失历史注意力信息。如果对话的历史内容对后续生成很重要(比如多轮推理、思维链),截断会导致模型"忘记"前面的关键信息,生成质量下降。

方案B:KV Cache压缩(StreamingLLM思路)

不丢弃历史,而是把历史KV压缩成一个"汇总向量",保留关键信息。

python 复制代码
class CompressedKVCache:
    """压缩KV Cache,只保留初始token和最近token"""
    
    def __init__(self, init_tokens=4, recent_tokens=128):
        self.init_tokens = init_tokens
        self.recent_tokens = recent_tokens
        self.init_k = {}
        self.init_v = {}
        self.recent_k = {}
        self.recent_v = {}
    
    def update(self, layer_idx, k_new, v_new):
        """更新KV Cache"""
        # 第一次调用:保存初始token的KV
        if layer_idx not in self.init_k:
            self.init_k[layer_idx] = k_new[:, :, :self.init_tokens, :]
            self.init_v[layer_idx] = v_new[:, :, :self.init_tokens, :]
            self.recent_k[layer_idx] = k_new
            self.recent_v[layer_idx] = v_new
            return
        
        # 更新recent窗口
        k_concat = torch.cat([self.recent_k[layer_idx], k_new], dim=2)
        v_concat = torch.cat([self.recent_v[layer_idx], v_new], dim=2)
        
        # 只保留最近的recent_tokens
        self.recent_k[layer_idx] = k_concat[:, :, -self.recent_tokens:, :]
        self.recent_v[layer_idx] = v_concat[:, :, -self.recent_tokens:, :]
    
    def get_full_kv(self, layer_idx):
        """拼接成完整的KV(供Attention计算用)"""
        k = torch.cat([self.init_k[layer_idx], self.recent_k[layer_idx]], dim=2)
        v = torch.cat([self.init_v[layer_idx], self.recent_v[layer_idx]], dim=2)
        return k, v

这个方案来自StreamingLLM论文,核心思想是:初始token(如"")包含了模型的"软启动"信息,不能丢;最近token包含了当前语境的即时信息,也不能丢。中间的历史可以压缩或丢弃。

问题2:显存放着放着就碎了

即使做了截断,还有另一个问题:显存放着放着就碎了

想象图书馆座位被随机占用和释放------有人坐1号、3号、7号,走的时候又只释放自己的座位。座位本身还在,但空出来的座位不连续,想坐4个人的时候,座位不够(虽然总空位数够)。

这就是显存碎片化。昇腾NPU的显存分配器( allocator)有自己的策略,如果不注意,KV Cache会把自己的显存弄得支离破碎。

碎片化的原因

  1. 不同层的KV Cache大小不一样:Attention层的hidden_dim通常比FFN层大,如果分配策略不当,会产生碎片。
  2. 序列长度不一致:不同请求的seq_len不同,如果动态分配,会产生碎片。
  3. PagedAttention没开:没有分页管理,显存就是一块一块的。

解决方案:开PagedAttention

PagedAttention把KV Cache分成固定大小的"页"来管理,每页大小64或128个token。显存碎片化问题迎刃而解。

python 复制代码
# vLLM中启用PagedAttention
from vllm import LLM, SamplingParams

llm = LLM(
    model="./models/Llama-2-7b-chat-hf",
    tensor_parallel_size=1,
    gpu_memory_utilization=0.85,
    max_num_seqs=32,
    # 关键参数:启用PagedAttention
    enable_flash_attn=True,
    use_paged_attention=True,  # 开PagedAttention
)

开PagedAttention之后,KV Cache的显存利用率从34%提升到91%。这意味着同样的显存,能跑的batch_size大得多。

问题3:Prefill和Decode的显存节奏不一样

FlashAttention做推理分两个阶段:

  • Prefill阶段:处理输入prompt,把所有token的KV算出来并缓存
  • Decode阶段:逐token生成,每生成一个token更新一次KV Cache

两个阶段的显存节奏完全不同:

复制代码
Prefill阶段(一次性处理4096个token):
  KV Cache = 4096 × 16 KB × 32 = 2048 KB = 2 MB(单层)
  一次性申请完毕,然后不变

Decode阶段(逐token生成):
  KV Cache = 每次+1个token(逐渐增长)
  生成512个token后:KV Cache = 512 × 16 KB × 32 = 256 MB(单层)

Prefill阶段一次性申请大量显存,Decode阶段逐次追加。如果Prefill和Decode的显存管理策略不一致,可能导致:

  • Prefill阶段申请太多,Decode阶段不够用
  • Decode阶段追加时找不到连续显存

解决方案:分离Prefill和Decode的KV Cache管理

python 复制代码
class HybridKVCacheManager:
    """分离Prefill和Decode的KV Cache管理器"""
    
    def __init__(self, max_seq_len=4096):
        self.max_seq_len = max_seq_len
        # Prefill阶段:一次性申请
        self.prefill_kv = None
        self.prefill_length = 0
        # Decode阶段:渐进追加
        self.decode_kv = {}
    
    def init_prefill(self, model, input_ids):
        """Prefill阶段:一次性处理所有token"""
        # 一次性处理输入序列
        outputs = model(
            input_ids=input_ids,
            use_cache=True,
            return_dict=True
        )
        
        # 保存所有层的KV Cache
        self.prefill_kv = outputs.past_key_values
        self.prefill_length = input_ids.shape[1]
        
        return outputs
    
    def append_decode(self, model, new_token, layer_idx):
        """Decode阶段:逐token追加"""
        # 只处理新token
        outputs = model(
            input_ids=new_token,
            past_key_values=self._get_full_kv(),
            use_cache=True,
            return_dict=True
        )
        
        # 更新指定层的KV Cache
        new_k, new_v = outputs.past_key_values[layer_idx]
        self._update_layer(layer_idx, new_k, new_v)
        
        return outputs
    
    def _get_full_kv(self):
        """拼接Prefill和Decode的KV"""
        # (具体实现略)
        pass
    
    def _update_layer(self, layer_idx, k_new, v_new):
        """更新单层KV Cache"""
        # (具体实现略)
        pass

总结:KV Cache管理清单

FlashAttention的KV Cache显存管理,按这个清单查:

问题 现象 解决方案
对话历史无限增长 响应越来越慢,显存一直涨 KV Cache截断或压缩(StreamingLLM)
显存放着碎了 申请小块显存时报OOM,但总显存够 开启PagedAttention
Prefill和Decode节奏不一致 Decode阶段显存不够,Prefill阶段显存空着 分离两阶段的KV Cache管理
多batch显存争抢 并发请求多了就OOM 设gpu_memory_utilization=0.85,限制单卡batch_size

代码和文档:

https://atomgit.com/cann/ops-transformer

相关推荐
a1117769 小时前
VR看房 网页(开源 threejs)html
前端·开源·html·vr
星星~笑笑9 小时前
vue 超简单 oss分片上传文件 大文件上传阿里云
前端·javascript·vue.js·uni-app
gogoing9 小时前
Claude Code Doc
前端·javascript
烬羽9 小时前
《前端基础实战:从零搭建用户列表,掌握前后端分离核心思想》
前端
xifangge20259 小时前
jdk版本不一样怎么办?一台电脑如何完美共存 JDK 8/11/17/21?多版本无缝切换与 IDEA 环境隔离实战指南
java·开发语言·jdk·intellij-idea
码银9 小时前
在若依框架中,使用easyExcel完成动态列导出
java·excel·ruoyi
彦为君10 小时前
Spring AOP 原理深度解析:从动态代理到切面织入(最新!Spring6与Spring5的差异)
java·后端·spring
XiYang-DING10 小时前
Spring Boot 集成 Hutool 实现图片验证码
java·spring boot·后端
Controller-Inversion10 小时前
76. 最小覆盖子串
java·算法·leetcode