Scaled_dot_product_attention(SDPA)使用详解

在学习huggingFace的Transformer库时,我们不可避免会遇到scaled_dot_product_attention(SDPA)这个函数,它被用来加速大模型的Attention计算,本文就详细介绍一下它的使用方法,核心内容主要参考了torch.nn.functional中该函数的注释。

1. Attention计算公式

Attention的计算主要涉及三个矩阵:Q、K、V。我们先不考虑multi-head attention,只考虑one head的self attention。在大模型的prefill阶段,这三个矩阵的维度均为N x d,N即为上下文的长度;在decode阶段,Q的维度为1 x d, KV还是N x d。然后通过下面的公式计算attention矩阵:
O = A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d ) V O=Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt d})V O=Attention(Q,K,V)=softmax(d QKT)V

在真正使用attention的时候,我们往往采用multi-head attention(MHA)。MHA的计算公式和one head attention基本一致,它改变了Q、K、V每一行的定义:将维度d的向量分成h组变成一个h x dk的矩阵,Q、K、V此时成为了 N ∗ h ∗ d k N * h * d_k N∗h∗dk的三维矩阵(不考虑batch维)。分别将Q、K、V的第一和第二维进行转置得到三个维度为 h ∗ N ∗ d k h * N * d_k h∗N∗dk的三维矩阵。此时的三个矩阵就是具有h个头的Q、K、V,我们就可以按照self attention的定义计算h个头的attention值。

不过,在真正进行大模型推理的时候就会发现KV Cache是非常占显存的,所以大家尝试各种手段压缩KV Cache,具体可以参考《大模型推理--KV Cache压缩》。一种手段就是将MHA替换成group query attention(GQA),这块在torch2.5以上的SDPA中也已经得到了支持。

2. SDPA伪代码

在SDPA的注释中,给出了伪代码:

python 复制代码
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,
                is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:
	L, S = query.size(-2), key.size(-2)
	scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
	
    attn_bias = torch.zeros(L, S, dtype=query.dtype)
	if is_causal:
    	 assert attn_mask is None
         temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
         attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
         attn_bias.to(query.dtype)

    if attn_mask is not None:
    	 if attn_mask.dtype == torch.bool:
         	 attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
         else:
             attn_bias += attn_mask

    if enable_gqa:
    	 key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)
         value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    
	return attn_weight @ value

可以看出,我们实际在使用SDPA时除了query、key和value之外,还有另外几个参数:attn_mask、dropout_p、is_causal、scale和enable_gqa。scale就是计算Attention时的缩放因子,一般无需传递。dropout_p表示Dropout概率,在推理阶段也不需要传递,不过官方建议如下输入:dropout_p=(self.p if self.training else 0.0)。我们着重看一下另外三个参数在使用时该如何设置。

先看enable_gqa。前面提到GQA是一种KV Cache压缩方法,MHA的KV和Q一样,也会有h个头,GQA则将KV的h个头进行压缩来减小KV Cache的大小。比如Qwen2-7B-Instruct这个模型,Q的h等于28,KV的h等于4,相当于把KV Cache压缩到之前的七分之一。GQA虽然压缩了KV Cache,但是真正要计算Attention的时候还是需要对齐KV与Q的head数,所以我们可以看到HF Transformer库中的qwen2.py在Attention计算时会有一个repeat_kv的操作,目的就是将QKV的head数统一。在torch2.5以后的版本中,我们无需再手动去执行repeat_kv,直接将SDPA的enable_gqa设置为True即可自动完成repeat_kv,而且速度比自己去做repaet_kv还要更快。

attn_mask和is_causal两个参数的作用相同,目的都是要给softmax之前的QKT矩阵添加mask。只不过attn_mask是自己在外面构造mask矩阵,is_causal则是根据大模型推理的阶段属于prefill还是decode来进行设置。通过看伪代码可以看出,SDPA会首先构造一个L x S的零矩阵attn_bias,L表示Q的上下文长度,S表示KV Cache的长度。在prefill阶段,L和S相等,在decode阶段,L为1,S还是N。所以在prefill阶段,attn_bias就是一个N x N的矩阵,将is_causal设置为True时就会构造一个下三角为0,上三角为负无穷的矩阵作为attn_bias,然后将其加到QKT矩阵上,这样就实现了因果关系的Attention计算。在decode阶段,attn_bias就是一个1 x N的向量,此时可以将is_causal设置为False,attn_bias始终为0就不会对 Q K T QK^T QKT行向量产生影响,表示KV Cache所有的行都参与计算,因果关系保持正确。

attn_mask作用和is_causal一样,但是需要我们自行构造,如果你对如何构造不了解建议就使用is_causal选项,prefill阶段设置为True,decode阶段设置为False,attn_mask设置为None。不过,如果prefill按照chunk来执行也即chunk_prefill阶段,我们会发现is_causal设置为True时的attn_bias设置的不正确,我们不是从左上角开始构造下三角矩阵,而是要从右下角开始构造下三角矩阵,这种情况下我们可以从外面自行构造attn_mask矩阵代替SDPA的构造。attn_mask有两种构造方式,一种是bool类型,True的位置会保持不变,False的位置会置为负无穷;一种是float类型,会直接将attn_mask加到SDPA内部的attn_bias上,和bool类型一样,我们一般是构造一个下三角为0上三角为负无穷的矩阵。总结来说,绝大多数情况下我们只需要设置is_causal选项,prefill阶段设置为True,decode阶段设置为False,attn_mask设置为None即可。如果推理阶段引入了chunk_prefill,则我们需要自行构造attn_mask,但是要注意构造的attn_mask矩阵是从右下角开始的下三角矩阵。

3. SDPA实现(翻译自SDPA注释)

目前SDPA有三种实现:

  1. 基于FlashAttention-2的实现;
  2. Memory-Efficient Attention(facebook xformers);
  3. Pytorch版本对上述伪代码的c++实现(对应MATH后端)。

针对CUDA后端,SDPA可能会调用经过优化的内核以提高性能。对于所有其他后端,将使用PyTorch实现。所有实现方式默认都是启用的,SDPA会尝试根据输入自动选择最优的实现方式。为了对使用哪种实现方式提供更细粒度的控制,torch提供了以下函数来启用和禁用各种实现方式:

  1. torch.nn.attention.sdpa_kernel:一个上下文管理器,用于启用或禁用任何一种实现方式;
  2. torch.backends.cuda.enable_flash_sdp:全局启用或禁用FlashAttention
  3. torch.backends.cuda.enable_mem_efficient_sdp:全局启用或禁用memory efficient attention
  4. torch.backends.cuda.enable_math_sdp:全局启用或禁用PyTorch的C++实现。

每个融合内核都有特定的输入限制。如果用户需要使用特定的融合实现方式,请使用torch.nn.attention.sdpa_kernel 禁用PyTorch 的C++实现。如果某个融合实现方式不可用,将会发出警告,说明该融合实现方式无法运行的原因。由于融合浮点运算的特性,此函数的输出可能会因所选择的后端内核而异。C++ 实现支持torch.float64,当需要更高精度时可以使用。对于math后端,如果输入是torch.half或torch.bfloat16类型,那么所有中间计算结果都会保持为torch.float类型。

4. SDPA使用示例

首先强调一点,灌入SDPA的QKV都是做过转置的,也即维度为batch x head x N x d,在老版本的torch中还需要QKV都是contiguous的,新版本下无此要求。SDPA注释中还给了两个示例,我们在此也给出:

python 复制代码
# Optionally use the context manager to ensure one of the fused kernels is run
 query = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
 key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
 value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
 with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):
 	F.scaled_dot_product_attention(query,key,value)

上述示例中,给定的输入为batch等于32,head等于8,上下文长度128,embedding维度64,然后通过sdpa_kernel选择使用FlashAttention。

示例二:

python 复制代码
# Sample for GQA for llama3
query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
with sdpa_kernel(backends=[SDPBackend.MATH]):
	F.scaled_dot_product_attention(query,key,value,enable_gqa=True)

示例二演示了GQA的用法,给定的query head数为32,key和value均为8,此时我们可以通过enable_gqa选项来实现对GQA的支持,此外代码还通过sdpa_kernel选项使用了MATH后端。

5. 参考

  1. FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
  2. Memory-Efficient Attention
  3. Grouped-Query Attention
  4. Attention Is All You Need
相关推荐
京东零售技术10 分钟前
WWW2025论文解读【前瞻技术布局】京东零售广告创意:引入场域目标的创意图片生成
人工智能
GIS数据转换器12 分钟前
智慧城市运行管理服务平台建设方案
人工智能·智慧城市
jonyleek24 分钟前
【JVS更新日志】智能BI、智能排产、低代码、视频会议3.12更新说明!JVS-AI助手即将上线!
java·人工智能·低代码·团队开发·制造·软件需求·erp
带电的小王25 分钟前
【大模型基础_毛玉仁】2.3 基于 Encoder-only 架构的大语言模型
人工智能·语言模型·自然语言处理
北京青翼科技26 分钟前
【TES817】基于XCZU19EG FPGA的高性能实时信号处理平台
图像处理·人工智能·ai·fpga开发·信号处理
数智大号28 分钟前
Net5.5G引领未来:企业如何布局新一代互联网战略
人工智能
芷栀夏28 分钟前
零成本本地化搭建开源AI神器LocalAI支持CPU推理运行部署方案
linux·人工智能·ai·开源
蚝油菜花28 分钟前
MM-StoryAgent:交大阿里联合开源!多模态AI一键生成儿童故事绘本+配音
人工智能·开源
飞哥数智坊30 分钟前
Manus还没发,先复刻一套OpenManus吧!
人工智能
Coovally AI模型快速验证31 分钟前
何恺明团队新突破:用“物理直觉“重构AI视觉系统,去噪神经网络让机器看懂世界规律
人工智能·深度学习·神经网络·机器学习·计算机视觉·目标跟踪·重构