scaled_dot_product_attention实现

一、先明确 SDPA 的核心原理

SDPA 是 Transformer 注意力机制的核心,公式如下:

关键要素:

  • 缩放(Scaled):除以(\sqrt{d_k})((d_k)是每个 head 的维度),避免(QK^T)数值过大导致 softmax 饱和;
  • 点积(Dot-Product):Q 和 K 的转置做点积,计算注意力分数;
  • 掩码(Mask):支持 padding 掩码 / 因果掩码,过滤无效 token 或未来 token;
  • Softmax:将注意力分数归一化为概率分布;
  • 加权求和:用归一化的分数对 V 加权,得到上下文感知的输出。

二、手动实现 SDPA(理解核心逻辑)

以下是纯 PyTorch 手动实现的 SDPA,包含缩放、注意力掩码、因果掩码核心逻辑,注释详细且适配新手理解:

python 复制代码
import torch
import torch.nn.functional as F

def scaled_dot_product_attention_manual(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    attn_mask: torch.Tensor = None,  # padding掩码:[batch_size, seq_len_q, seq_len_k]
    is_causal: bool = False,         # 是否启用因果掩码(下三角)
    dropout_p: float = 0.0           # dropout概率
) -> torch.Tensor:
    """
    手动实现SDPA,参数与PyTorch原生API对齐
    参数说明:
    - q: [batch_size, num_heads, seq_len_q, head_dim]  查询向量
    - k: [batch_size, num_heads, seq_len_k, head_dim]  键向量
    - v: [batch_size, num_heads, seq_len_k, head_dim]  值向量
    - attn_mask: 注意力掩码(0=有效,-inf=无效),None则无掩码
    - is_causal: 是否启用因果掩码(仅看前文)
    - dropout_p: dropout概率,0则不使用
    """
    # 1. 获取head_dim,计算缩放因子
    head_dim = q.size(-1)
    scale = torch.sqrt(torch.tensor(head_dim, dtype=torch.float32))
    
    # 2. 计算QK^T(点积)并缩放
    # Q: [bs, n_head, len_q, d_k] → K^T: [bs, n_head, d_k, len_k] → attn_scores: [bs, n_head, len_q, len_k]
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / scale
    
    # 3. 应用因果掩码(如果启用)
    if is_causal:
        # 生成下三角因果掩码:len_q × len_k,未来位置设为-∞
        causal_mask = torch.tril(torch.ones(q.size(-2), k.size(-2), dtype=torch.bool)).to(q.device)
        attn_scores = attn_scores.masked_fill(~causal_mask, float('-inf'))
    
    # 4. 应用注意力掩码(padding掩码/自定义掩码)
    if attn_mask is not None:
        # 适配掩码维度:如果是2D([bs, len_k]),扩展为4D([bs, 1, 1, len_k])
        if attn_mask.dim() == 2:
            attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)  # [bs, 1, 1, len_k]
        attn_scores = attn_scores + attn_mask  # 无效位置(-inf)叠加
    
    # 5. Softmax归一化,得到注意力权重
    attn_weights = F.softmax(attn_scores, dim=-1)
    
    # 6. 应用dropout(可选)
    if dropout_p > 0.0:
        attn_weights = F.dropout(attn_weights, p=dropout_p)
    
    # 7. 加权求和V,得到最终注意力输出
    attn_output = torch.matmul(attn_weights, v)  # [bs, n_head, len_q, d_k]
    
    return attn_output

手动实现的测试示例(模拟 Qwen3 单 head 场景)

python 复制代码
# 模拟Qwen3-7B的单head输入:batch_size=1,num_heads=1,seq_len=5,head_dim=128
bs, n_head, seq_len, head_dim = 1, 1, 5, 128
q = torch.randn(bs, n_head, seq_len, head_dim).to("cuda")
k = torch.randn(bs, n_head, seq_len, head_dim).to("cuda")
v = torch.randn(bs, n_head, seq_len, head_dim).to("cuda")

# 测试1:启用因果掩码(模拟自回归生成)
output_causal = scaled_dot_product_attention_manual(q, k, v, is_causal=True)
print("因果掩码输出形状:", output_causal.shape)  # torch.Size([1, 1, 5, 128])

# 测试2:添加padding掩码(模拟批量输入)
padding_mask = torch.tensor([[1,1,1,0,0]]).to("cuda")  # 后2个token是padding
padding_mask = padding_mask.masked_fill(padding_mask == 0, float('-inf'))  # 0→-inf
output_padding = scaled_dot_product_attention_manual(q, k, v, attn_mask=padding_mask)
print("Padding掩码输出形状:", output_padding.shape)  # torch.Size([1, 1, 5, 128])

mask作用

mask主要是屏蔽掉attention矩阵无效的权重,

  • 比如说padding值(attention_mask得来)
  • 防止前面的字符看到后面字符的值(casual_mask矩阵得来)
相关推荐
Luhui Dev2 小时前
2025 开源大模型生态回顾一览
人工智能·开源
CoderJia程序员甲2 小时前
GitHub 热榜项目 - 日榜(2025-12-26)
开源·大模型·llm·github·ai教程
木头左2 小时前
LSTM量化交易策略的环境适应性与入参稳定性评估
人工智能·rnn·lstm
longfei.li2 小时前
AI项目工程化落地如何降本30%?
人工智能·语言模型
燕双嘤2 小时前
LLM:RAG,设计模式,Agent框架
人工智能·机器学习·设计模式
汉克老师2 小时前
小学生0基础学大语言模型应用(第4课 《数字盒子与算数魔法》)
人工智能·语言模型·自然语言处理·小学生0基础学习大语言模型
雅欣鱼子酱2 小时前
Type-C受电端芯片ECP5702演示:串口发送电压电流,给外部MCU读取
c语言·人工智能·单片机·嵌入式硬件·芯片·电子元器件
ECT-OS-JiuHuaShan2 小时前
麻烦是第一推动力,不厌其烦就是负熵流
开发语言·人工智能·数学建模·学习方法·量子计算
AI大模型2 小时前
24页 大语言模型(LLM)入门指南:从核心定义、训练三步法到 Llama 3.1 实操部署
程序员·llm·agent