【PyTorch][chapter-33][transformer-5] MHA MQA GQA, KV-Cache

主要翻译外网: 解剖Deep Seek 系列,详细见参考部分。


目录:

  1. Multi-Head Attention (MHA)
  2. KV-Cache
  3. KV-Cache 公式
  4. Multi-Query Attention(MQA)
  5. Grouped-Query Attention(GQA)
  6. Multi-Head Latent Attention
  7. PyTorch Implementing MHA, MQA, and GQA

一 Multi-Head Attention (MHA)

输入: Q,K,V 通常为x ,其形状为:

batch_size, seq_len, d_model

++第一步: 进行线性空间变换:++

++第二步: 子空间投影(projection)++

其中为子空间头数量一般设置为8

++第三步: 做 self-attention++


二 KV-Cache

在Transformer的Decoder推理过程中,由于自注意力机制需要遍历整个先前输入的序列来计算每个新token的注意力权重,这导致了显著的计算负担。随着序列长度的延伸,计算复杂度急剧上升,不仅增加了延迟,还限制了模型处理长序列的能力。因此,优化Decoder的自注意力机制,减少不必要的计算开销,成为提升Transformer模型推理效率的关键所在。

KV 缓存发生在多个 token 生成步骤中,并且仅发生在解码器中 (即,在 GPT 等仅解码器模型中,或在 T5 等编码器-解码器模型的解码器部分中)。BERT 等模型不是生成式的,因此没有 KV 缓存。

这种自回归行为重复了一些操作,我们可以通过放大在解码器中计算的掩蔽缩放点积注意力计算来更好地理解这一点。

由于解码器是auto-regressive 的,因此在每个生成步骤中**,我们都在重新计算相同的先前标记的注意力,而实际上我们只是想计算新标记的注意力。**

这就是 KV 发挥作用的地方。通过缓存以前的 Keys 和 Values,我们可以专注于计算新 token 的注意力。

为什么这种优化很重要?如上图所示,使用 KV 缓存获得的矩阵要小得多,从而可以加快矩阵乘法的速度。唯一的缺点是它需要更多的 GPU VRAM(如果不使用 GPU,则需要 CPU RAM)来缓存 Key 和 Value 状态。


三 KV-Cache 公式

基本和MHA 过程差不多,区别是每次输入的是: 第t 时刻的token

3.1 对当前时刻的输入进行线性变换

d: embedding 的维度

3.2 进行子空间投影

其中:

attention head 数量

attention head 的维度

: 第i个头,t时刻的查询向量

3.3 做self-attention

我们把存储的K,V缓存叫做K-V Cache. 对于一个L层的模型,每个t个token 一共需要

缓存。

是一个head的size ,MLA 就是研究这个size 如何降维降低KV-Cache


四 Multi-Query Attention(MQA)

为了缓解多头注意力(MHA)中的键值缓存瓶颈问题,Shazeer在2019年提出了多查询注意力(MQA)机制。在该机制中,所有的不同注意力头共享相同的键和值,即除了不同的注意力头共享同一组键和值之外,其余部分与MHA相同。这大大减轻了键值缓存的负担,从而显著加快了解码器的推理速度。然而,MQA会导致质量下降和训练不稳定。


五 Grouped Query Attention --- (GQA)

分组查询注意力(GQA)通过在多头注意力(MHA)和多查询注意力(MQA)之间引入一定数量的查询头子组(少于总注意力头的数量),每个子组有一个单独的键头和值头,从而实现了一种插值。与MQA相比,随着模型规模的增加,GQA在内存带宽和容量上保持了相同比例的减少。中间数量的子组导致了一个插值模型,该模型的质量高于MQA但推理速度快于MHA。很明显,只有一个组的GQA等同于MQA。


六 PyTorch Implementing MHA, MQA, and GQA

num_kv_heads 和 num_heads 一样的时候就是 MHA

num_kv_heads= 1 就是MQA

num_kv_heads<num_heads 就是GQA

复制代码
# -*- coding: utf-8 -*-
"""
Created on Fri Feb 21 15:02:18 2025

@author: chengxf2
"""

import torch.nn as nn
import torch.nn.functional as F
import torch


def scaled_dot_product_attention(query, key, value, mask=None, dropout=None):
    # 获取 key 的维度大小,用于缩放
    d_k = query.size(-1)
    # 计算点积注意力得分
    scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    # 如果提供了 mask,将其应用到得分上
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)
    # 对得分进行 softmax 操作,得到注意力权重
    p_attention = F.softmax(scores, dim=-1)
    # 如果提供了 dropout,应用 dropout
    if dropout is not None:
        p_attention = dropout(p_attention)
    # 使用注意力权重对 value 进行加权求和
    return torch.matmul(p_attention, value)


class  Attention(nn.Module):
    def __init__(self, d_model=512,num_heads=8, num_kv_heads=2,dropout=0.5):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim =  d_model//num_heads
        self.num_kv_heads = num_kv_heads
        assert self.num_heads%self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads//self.num_kv_heads
        #Linear
        self.query =   nn.Linear(d_model, self.head_dim * self.num_heads)
        self.key =     nn.Linear(d_model,   self.head_dim * self.num_kv_heads)
        self.value =   nn.Linear(d_model, self.head_dim * self.num_kv_heads)
        #输出
        self.proj = nn.Linear(d_model, d_model)
        self.attn_dropout = nn.Dropout(dropout)
    
    def forward(self, inputs):
        
        batch, seq_len, d_model = inputs.shape
        q = self.query(inputs)
        k = self.key(inputs)
        v = self.value(inputs)
        # shape = (B, seq_len, num_heads, head_dim)
        q = q.view(batch, seq_len, -1,  self.head_dim)
        k = k.view(batch, seq_len, -1 , self.head_dim)  
        v = v.view(batch, seq_len, -1,  self.head_dim)
     
        print("default q.shape",q.shape)
        print("default k.shape",k.shape)
        print("default v.shape",v.shape)
        # Grouped Query Attention
        #[batch, seq_len, num_kv_heads, head_dim]->[batch, seq_len, num_heads, head_dim]
        if self.num_kv_heads != self.num_heads:
           k = torch.repeat_interleave(k, self.num_queries_per_kv, dim=2)
           v = torch.repeat_interleave(v, self.num_queries_per_kv, dim=2)
        # shape = (B, num_heads, seq_len, head_dim) 
        k = k.transpose(1, 2)  
        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        print("q.shape",q.shape)
        print("k.shape",k.shape)
        print("v.shape",v.shape)

        output = scaled_dot_product_attention(
            q,
            k,
            v,  # order impotent
            None,
            self.attn_dropout,
        )
        print("v.shape",v.shape)
        print("output.shape",output.shape)
        output = output.transpose(1, 2).contiguous().view(batch, seq_len, d_model)
        # final projection into the residual stream
        output = self.proj(output)
        return output
net = Attention()
batch_size =2
seq_len = 5
d_model =512
x = torch.randn(batch_size,seq_len, d_model)
net(x)

七 Multi-Head Latent Attention --- (MLA)

Multi-Head Latent Attention (MLA) achieves superior performance than MHA, as well as significantly reduces KV-cache boosting inference efficiency. Instead of reducing KV-heads as in MQA and GQA, MLA jointly compresses the Key and Value into a latent vector.

Low-Rank Key-Value Joint Compression

Instead of caching both the Key and Value matrices, MLA jointly compresses them in a low-rank vector which allows caching fewer items since the compression dimension is much less compared to the output projection matrix dimension in MHA.

Comparison of Deepseek's new Multi-latent head attention with MHA, MQA, and GQA.

参考:

https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf

缓存与效果的极限拉扯:从MHA、MQA、GQA到MLA - 科学空间|Scientific Spaces

DeepSeek-V3 Explained 1: Multi-head Latent Attention | Towards Data Science

https://medium.com/@zaiinn440/mha-vs-mqa-vs-gqa-vs-mla-c6cf8285bbec

https://medium.com/@zaiinn440/mha-vs-mqa-vs-gqa-vs-mla-c6cf8285bbec

deepseek技术解读(1)-彻底理解MLA(Multi-Head Latent Attention)-CSDN博客

怎么加快大模型推理?10分钟学懂VLLM内部原理,KV Cache,PageAttention_哔哩哔哩_bilibili

https://medium.com/@joaolages/kv-caching-explained-276520203249

相关推荐
阿坡RPA13 小时前
手搓MCP客户端&服务端:从零到实战极速了解MCP是什么?
人工智能·aigc
用户277844910499313 小时前
借助DeepSeek智能生成测试用例:从提示词到Excel表格的全流程实践
人工智能·python
机器之心13 小时前
刚刚,DeepSeek公布推理时Scaling新论文,R2要来了?
人工智能
算AI15 小时前
人工智能+牙科:临床应用中的几个问题
人工智能·算法
凯子坚持 c16 小时前
基于飞桨框架3.0本地DeepSeek-R1蒸馏版部署实战
人工智能·paddlepaddle
你觉得20516 小时前
哈尔滨工业大学DeepSeek公开课:探索大模型原理、技术与应用从GPT到DeepSeek|附视频与讲义下载方法
大数据·人工智能·python·gpt·学习·机器学习·aigc
8K超高清17 小时前
中国8K摄像机:科技赋能文化传承新图景
大数据·人工智能·科技·物联网·智能硬件
hyshhhh17 小时前
【算法岗面试题】深度学习中如何防止过拟合?
网络·人工智能·深度学习·神经网络·算法·计算机视觉
薛定谔的猫-菜鸟程序员17 小时前
零基础玩转深度神经网络大模型:从Hello World到AI炼金术-详解版(含:Conda 全面使用指南)
人工智能·神经网络·dnn
币之互联万物17 小时前
2025 AI智能数字农业研讨会在苏州启幕,科技助农与数据兴业成焦点
人工智能·科技