大模型推理加速 KV cache

目录

[一、Attention 计算](#一、Attention 计算)

[Attention 计算详解](#Attention 计算详解)

备注:

[1. Mask 的核心作用](#1. Mask 的核心作用)

[2. 两种主要的 Mask 类型](#2. 两种主要的 Mask 类型)

[1. Padding Mask](#1. Padding Mask)

[2.Causal Mask(因果掩码 / Look-Ahead Mask)](#2.Causal Mask(因果掩码 / Look-Ahead Mask))

[二、KV cache 步骤](#二、KV cache 步骤)

[KV cache 显存计算](#KV cache 显存计算)

[三、减少KV cache](#三、减少KV cache)

[1. 为什么减少KV cache?](#1. 为什么减少KV cache?)

[1.1 更长的上下文](#1.1 更长的上下文)

[1.2 提升推理效率](#1.2 提升推理效率)

[2. MHA(multi-Head Attention)](#2. MHA(multi-Head Attention))

[3. MQA(Multi-Query Attention)](#3. MQA(Multi-Query Attention))

[4. GQA(Grouped-Query Attention)](#4. GQA(Grouped-Query Attention))

四、Q&A

[1. Q:为什么GQA在70B大模型上的优势比7B更明显?](#1. Q:为什么GQA在70B大模型上的优势比7B更明显?)

[2. Q:如何验证某层是否适合改用GQA?](#2. Q:如何验证某层是否适合改用GQA?)

[2.1 理论依据:为什么注意力熵与GQA相关?](#2.1 理论依据:为什么注意力熵与GQA相关?)

[2.2 如何计算注意力熵?](#2.2 如何计算注意力熵?)

[2.3 阈值设定(如 < 3.5)的合理性](#2.3 阈值设定(如 < 3.5)的合理性)


一、Attention 计算

Attention 计算详解

对于 decode only 的 Transformer 架构,我们给一个输入文本,模型会输出一个回答(长度为 N),每一步根据之前 token 生成下一个 token,也就是每次推理只输出一个 token,输出的 token 会与输入 token 拼接在一起,然后作为下一次推理的输入,这样不断反复直到遇到终止符。

在没有KV Cache的情况下,其计算过程如下所示:

正常情况下,Attention计算公式:

为了看上去方便,我们暂时忽略scale项,因此,Attention的计算公式如下所示(softmaxed 表示已经按行进行了softmax):

变为矩阵时,softmax 会针对行进行计算,详细如下(softmaxed 表示已经按行进行了softmax):

其中, 表示 Attention 的第一行, 表示 Attention 的第二行。

对于 ,由于这个值会mask掉,你会发现,在第二步参与的计算与第一步是完全一样的,并且 参与计算Attention时也仅仅依赖于,与 毫无关系。

对于参与计算Attention时也仅仅依赖于 ,与 毫无关系。

其计算方式如 Step2 所示。

对于 参与计算Attention时也仅仅依赖于

看上面图和公式,我们可以得出以下结论:

  1. 当前计算方式存在大量冗余计算,每一次生成新的Token都需要计算之前的KV。

  2. 的计算过程中,主要与有关。 参与计算Attention时也仅仅依赖于

  3. 每一步中,其实只需要根据 计算 就可以,之前已经计算的Attention完全不需要重新计算。但是 K 和 V 是全程参与计算的,所以这里我们需要把每一步的 K 、 V 缓存起来。

备注(Mask):

1. Mask 的核心作用

Mask 的主要目的是:在计算注意力时,忽略无效或未来的位置信息。

具体来说,Mask 用于两种常见场景:

  1. Padding Mask:忽略填充(padding)部分

  2. Causal Mask(或 Look-Ahead Mask):防止模型在生成当前 token 时"偷看"未来的 token

2. 两种主要的 Mask 类型
1. Padding Mask
  • 场景:处理变长序列时,将短序列填充到统一长度

  • 问题:填充的 token(如 [PAD])不应参与注意力计算

  • 解决方法:使用 Padding Mask 将这些位置的注意力分数设为 -inf,softmax 后权重为 0

    示例:假设 padding 位置为 0

    mask = (input_ids != tokenizer.pad_token_id) # True 表示有效 token
    mask = mask.unsqueeze(1).unsqueeze(2) # 扩展维度以匹配注意力矩阵

2.Causal Mask(因果掩码 / Look-Ahead Mask)
  • 场景:自回归生成(如 GPT),模型逐个生成 token

  • 问题:在预测第 t 个 token 时,不能依赖第 t+1 及之后的信息

  • 解决方法:使用上三角矩阵作为 mask,屏蔽未来位置

    import torch

    def create_causal_mask(seq_len):
    # 创建一个上三角矩阵,对角线以下为 1,以上为 0
    mask = torch.tril(torch.ones(seq_len, seq_len))
    return mask.unsqueeze(0).unsqueeze(1) # (1, 1, seq_len, seq_len)

    使用

    scores = scores.masked_fill(mask == 0, -1e9)

这确保了每个位置只能关注它自身及之前的位置,符合自回归假设。

二、KV cache 步骤

一个典型的带有 KV Cache 的推理过程包含以下两个阶段:

  • 预填充阶段:输入一个 prompt 序列,为每个 transformer 层生成 Key Cache 和 Value Cache(KV cache)。

  • 解码阶段:使用并更新 KV Cache,一个接一个地生成 token,当前生成的 token 依赖于之前已经生成的token。

KV cache 显存计算

|----|----------------------------------|----------------------------------------|
| 符号 | 含义 | 说明 |
| 2 | Key 和 Value 两部分 | 每个 token 都会生成 K 和 V 两个向量,需要分别缓存 |
| L | 层数(Number of Layers) | Transformer 模型的层数(如 Llama-3-8B 有 32 层) |
| S | 序列长度(Sequence Length) | 当前已处理的总 token 数(prompt + 已生成的 token) |
| H | 注意力头数(Number of Attention Heads) | 多头注意力的头数(如 32) |
| D | 头的维度(Head dimension) | 每个注意力头的向量维度 |
| T | 数据类型大小(Type Size) | 每个数值占用的字节数(fp16或bf16为 2 字节,fp32为 4 字节) |

  • 单个 token 的 K/V 向量大小:H × D (所有头拼接起来)

  • 单层的缓存大小:S × H × D (S 个 token)

  • 所有层的缓存大小:L × S × H × D

  • Key 和 Value 两部分:2 × L × S × H × D

  • 乘以数据类型大小:2 × L × S × H × D × T

三、减少KV cache

1. 为什么减少KV cache?

1.1 更长的上下文

众所周知,LLM 的推理过程通常在 GPU 上进行。然而,单张 GPU 的显存容量是有限的。一部分容量需要用来存放模型的参数和前向计算的激活值,这部分的大小主要取决于模型的体量,选定模型后就成为了一个固定值。另一部分容量则用来存放模型的 KV Cache,它的大小不仅与模型体量有关,还随着模型输入长度的增加而动态增长。当 Context 长度足够长时,KV Cache 的大小将占据主导地位,甚至可能超出一张GPU甚至一台机器(包含8张GPU)的总显存容量。

1.2 提升推理效率

在GPU上部署模型时,我们遵循的原则是:能在一张卡上部署的,就不要跨多张卡;能在一台机器上部署的,就不要跨多台机器。这是因为"卡内通信带宽 > 卡间通信带宽 > 机间通信带宽"。由于"木桶效应",模型部署时跨的设备越多,受到设备间通信带宽的制约就越大。

因此,减少 KV Cache 的目的是为了在更少的设备上推理更长的 Context,或者在相同的 Context 长度下实现更大的推理 batch size,从而提升推理速度或增加吞吐总量。最终目的都是为了降低推理成本

2. MHA(multi-Head Attention)

原始架构:每个注意力头都有独立的 Query、Key 和 Value 投影矩阵。

  • 优点:模型表达能力最强,能捕捉最丰富的特征。

  • 缺点:

    • KV Cache 显存占用最高:每个头都维护自己的 K 和 V,显存占用 = num_heads × head_dim × seq_len。

    • 训练和推理成本高

3. MQA(Multi-Query Attention)

核心思想:所有注意力头共享同一组 KeyValue,只有 Query 是每个头独立的。

  • 优点:

    • KV Cache 显存大幅降低:从 num_heads 份 K/V 变为 1 份,显存占用接近 MHA 的 1/num_heads

    • 推理速度显著提升,尤其在长序列生成时。

  • 缺点:模型表达能力略有下降,精度可能微降。

  • 应用:Google 的 PaLM 模型采用 MQA; 适合对推理延迟要求极高的场景。

4. GQA(Grouped-Query Attention)

核心思想:MHA 和 MQA 的折中方案。将 num_heads 分为 num_groups 组,每组内共享 K/V。

  • 优点:

    • 显存占用和推理速度介于 MHA 和 MQA 之间。

    • 在速度和精度之间取得良好平衡。

  • 缺点:实现比 MHA 复杂。

  • 应用:Llama-2-70B、Llama-3 系列使用 GQA。

四、Q&A

1. Q:为什么GQA在70B大模型上的优势比7B更明显?

→ 大模型的注意力头维度更高(d_model=8192),KV投影参数占比更大,分组共享的收益呈超线性增长。

2. Q:如何验证某层是否适合改用GQA?

计算该层注意力矩阵的熵值,若熵值<3.5(高确定性注意力),可安全改用GQA

2.1 理论依据:为什么注意力熵与GQA相关?
  • GQA的核心:将多个Query头(Query Heads)映射到更少的Key/Value头(KV Heads)上进行共享。

  • 潜在风险:如果不同Query头关注的是完全不同的信息源(即注意力分布差异大),强行共享KV头会导致信息混淆,损害模型性能。

  • 注意力熵:衡量注意力权重分布的"集中度"。

    • 低熵(如 < 3.5):注意力集中在少数几个token上,分布尖锐。说明模型决策很"确定",对KV信息的多样性要求低,适合GQA。

    • 高熵(如 > 5.0):注意力分布平坦,分散在很多token上。说明模型在"犹豫",需要更丰富的KV信息,不适合GQA或需谨慎使用。

🔍 类比:如果一个老师(Query)只关注班里最优秀的一个学生(高确定性),那么多个老师共享这个"优秀学生名单"(KV)是合理的。但如果每个老师关注的学生都不同,共享名单就会出错。

2.2 如何计算注意力熵?

在模型推理过程中,获取某一层的注意力权重 attn_weights(形状为 [batch_size, num_heads, seq_len, seq_len]),然后计算其熵:

2.3 阈值设定(如 < 3.5)的合理性
  • 阈值 3.5 是一个经验性阈值。对于一个长度为 n 的均匀分布,其最大熵为 log(n)

    • 例如,seq_len=1024 时,最大熵为 log(1024) ≈ 6.93

    • 熵值 3.5 大约是最大熵的一半,表明分布已经相当集中。

  • 经验法则:

    • 熵 < 3.5:注意力高度集中,强烈推荐使用GQA。

    • 3.5 ≤ 熵 < 5.0:中等集中度,可以尝试GQA,但需评估性能。

    • 熵 ≥ 5.0:注意力分散,不建议使用GQA,应保持MHA。

相关推荐
ILL11IIL15 天前
Mysql 集群技术
数据库·mysql·mha
dawdo2222 个月前
自己动手从头开始编写LLM推理引擎(9)-KV缓存实现和优化
缓存·llm·transformer·qwen·kv cache
被制作时长两年半的个人练习生2 个月前
KV Cache
kv cache
enjoy编程3 个月前
Spring AI 大模型工程核心:效率的极限博弈
注意力机制·flashattention·kv cache·pd分离·pagedattention·epd分离·radixattention
每天都要写算法(努力版)3 个月前
【混合注意力模型的 KV Cache 设计与统一管理实践解析】
llm·vllm·kv cache
缘友一世5 个月前
LLama 3分组查询注意力与KV缓存机制
人工智能·深度学习·缓存·transformer·llama·gqa·kv缓存
奥尔特星云大使5 个月前
mysql高可用架构之MHA部署(三)——故障转移后邮件告警配置(保姆级)
mysql·dba·高可用·mha·邮件告警
奥尔特星云大使5 个月前
mysql高可用架构之MHA部署(二)VIP漂移(保姆级)
android·mysql·架构·mha·ip漂移
张鱼小丸子7 个月前
MySQL企业级部署与高可用实战
运维·数据库·mysql·云原生·高可用·mha·组从复制