拆解 Transformer 的灵魂:全景解析 Attention 家族 (Self, Cross, Masked & GQA)

拆解 Transformer 的灵魂:全景解析 Attention 家族 (Self, Cross, Masked & GQA)

作者: Duang777
标签: Deep Learning Transformer LLM Internals PyTorch


在 2017 年之前,NLP 还是 LSTM 和 RNN 的天下。那时候我们处理长文本就像患了"健忘症",读了后面忘了前面。直到 Google 团队扔出了那篇《Attention Is All You Need》,整个 AI 世界被重塑了。

但"注意力"到底是什么?简单来说,注意力机制是一种资源分配方案。在处理海量信息时,它告诉模型:"别试图记住所有东西,只关注那些对当前任务最相关的信息。"

今天,我们将深入源码和原理,彻底搞懂 Q、K、V 的爱恨情仇。


1. 核心原语:Query, Key, Value (Q, K, V)

所有注意力机制的数学本质,都可以看作是一个**数据库查询(Database Retrieval)**过程。

想象你去图书馆找书:

  • Query ( Q Q Q):你手中的书单("我要找关于 Python 的书")
  • Key ( K K K):书脊上的标签("计算机科学/编程语言")
  • Value ( V V V):书里的实际内容
  • Score (权重) :你的书单 ( Q Q Q) 和 标签 ( K K K) 的匹配程度

注意力计算公式(Scaled Dot-Product Attention)

A t t e n t i o n ( Q , K , V ) = softmax ( Q K T d k ) V Attention(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

步骤详解:

  1. 点积 ( Q K T QK^T QKT):计算相似度。向量越相似,点积越大。
  2. 缩放 ( 1 d k \frac{1}{\sqrt{d_k}} dk 1):这是为了防止点积结果过大,导致 Softmax 进入饱和区(梯度消失)。
  3. Softmax:将分数归一化为概率分布(所有权重之和为 1)。
  4. 乘 V V V:根据权重,加权提取信息。

为什么需要缩放因子?

当维度 d k d_k dk 较大时,点积的值会变得很大。例如,假设向量的每个维度都在 [ − 1 , 1 ] [-1, 1] [−1,1] 范围内,那么 d k d_k dk 维向量的点积期望值大约为 d k d_k dk。当 d k = 512 d_k = 512 dk=512 时,点积值可能达到数百级别。

Softmax 函数在输入值较大时会进入饱和区,导致梯度接近 0,模型难以训练。除以 d k \sqrt{d_k} dk 可以将点积值控制在合理范围内,保持梯度的稳定性。


2. 家族族谱:注意力的不同形态

虽然公式一样,但在不同的场景下,Q、K、V 的来源不同,就构成了不同的注意力变体。

2.1 自注意力 (Self-Attention)

别名: 上帝视角

定义: Q , K , V Q, K, V Q,K,V 全部来自同一个输入源( X X X)。

作用: 让序列中的每一个词都能"看见"其他所有的词,并计算它们之间的关联。

例如在句子 "The animal didn't cross the street because it was too tired" 中,Self-Attention 帮助模型将 "it" 关联到 "animal" 而不是 "street"。

场景: Transformer 的 Encoder 内部,BERT。

2.2 掩码注意力 (Masked Self-Attention)

别名: 严守时间线的"预言家"

痛点: 在生成式任务(如 GPT)中,预测第 t t t 个词时,绝不能看到 t + 1 t+1 t+1 之后的词(否则就是作弊)。

实现: 在 Softmax 之前,将矩阵的右上三角区域(未来信息)填充为 − ∞ -\infty −∞。这样 Softmax 后这些位置的概率就变成了 0。

场景: GPT 系列,所有 Decoder-only 模型。

掩码矩阵示例:

python 复制代码
# 对于序列长度为 4 的情况,掩码矩阵如下:
[[1, 0, 0, 0],   # 位置 0 只能看到自己
 [1, 1, 0, 0],   # 位置 1 可以看到 0 和 1
 [1, 1, 1, 0],   # 位置 2 可以看到 0, 1, 2
 [1, 1, 1, 1]]   # 位置 3 可以看到所有位置

2.3 交叉注意力 (Cross-Attention)

别名: 翻译官

定义:

  • Query 来自 Decoder(当前正在生成的句子)
  • Key & Value 来自 Encoder(源句子的编码)

作用: 让生成部分去"查阅"源输入的信息。

场景: 机器翻译(Seq2Seq),Stable Diffusion(文本控制图像生成)。

工作流程:

  1. Encoder 处理源语言句子,生成编码表示(作为 K 和 V)
  2. Decoder 逐词生成目标语言句子(作为 Q)
  3. 在生成每个目标词时,通过 Cross-Attention 查询源语言的相关信息

3. 进阶架构:多头与分组 (Multi-Head & GQA)

为什么一个注意力头(Head)不够用?

如果只用一组 Q、K、V,模型可能只学到一种语义关系(比如语法结构)。我们希望模型能像用"多棱镜"一样看问题:

  • Head 1 关注指代关系(it -> animal)
  • Head 2 关注动宾关系(cross -> street)
  • Head 3 关注时态信息
  • Head 4 关注句法结构
  • Head 5 关注语义相似性
  • Head 6 关注位置关系

3.1 多头注意力 (Multi-Head Attention)

实现: 将 embedding 维度切割成 h h h 份,分别计算 Attention,最后拼接(Concat)起来。

数学表达:

MultiHead ( Q , K , V ) = Concat ( head 1 , ... , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO

其中:

head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV)

参数量: 假设模型维度为 d m o d e l d_{model} dmodel,头数为 h h h,每个头的维度为 d k = d m o d e l / h d_k = d_{model} / h dk=dmodel/h,则:

  • W i Q ∈ R d m o d e l × d k W_i^Q \in \mathbb{R}^{d_{model} \times d_k} WiQ∈Rdmodel×dk
  • W i K ∈ R d m o d e l × d k W_i^K \in \mathbb{R}^{d_{model} \times d_k} WiK∈Rdmodel×dk
  • W i V ∈ R d m o d e l × d k W_i^V \in \mathbb{R}^{d_{model} \times d_k} WiV∈Rdmodel×dk
  • W O ∈ R h d k × d m o d e l W^O \in \mathbb{R}^{hd_k \times d_{model}} WO∈Rhdk×dmodel

3.2 现代变体:MQA 与 GQA (显存优化的救星)

随着模型越来越大(上下文窗口达到 128k 甚至 1M),KV Cache 占用的显存成为了推理瓶颈。

MHA (Multi-Head Attention) - 标准版
  • 每个头都有自己的 Q, K, V
  • 显存占用最大
  • 性能最好
MQA (Multi-Query Attention) - 极端版
  • 所有头共享同一份 K 和 V,只有 Q 不同
  • 速度极快,但性能有损失
  • 用于 PaLM 等模型
GQA (Grouped-Query Attention) - 平衡版
  • 将头分成若干组,每组共享 K 和 V
  • 在性能和速度之间取得平衡
  • 用于 Llama 2/3, Mistral, Gemma 等现代模型

对比表:

架构 Q 头数 K/V 头数 显存占用 速度 性能
MHA h h 最大 最佳
MQA h 1 最小 最快 较差
GQA h g (1 < g < h) 中等 良好

4. PyTorch 实现详解

下面是一个完整的因果自注意力(Causal Self-Attention)实现,这是 GPT 类模型的核心组件。

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class CausalSelfAttention(nn.Module):
    def __init__(self, d_model, n_head, max_len=1024):
        super().__init__()
        assert d_model % n_head == 0, "Embedding dim must be divisible by n_head"
        
        self.d_head = d_model // n_head
        self.n_head = n_head
        
        # 1. 定义 Q, K, V 的线性投影
        # 通常我们将它们合并为一个大矩阵运算,效率更高
        self.c_attn = nn.Linear(d_model, 3 * d_model)
        
        # 2. 输出投影
        self.c_proj = nn.Linear(d_model, d_model)
        
        # 3. 注册 Mask (因果掩码 / 下三角矩阵)
        # 这里的 register_buffer 保证 mask 是模型状态的一部分但不是可训练参数
        self.register_buffer(
            "bias", 
            torch.tril(torch.ones(max_len, max_len))
                 .view(1, 1, max_len, max_len)
        )

    def forward(self, x):
        # x shape: (Batch_Size, Seq_Len, D_Model)
        B, T, C = x.size()
        
        # 1. 计算 Q, K, V
        # qkv shape: (B, T, 3 * D_Model)
        qkv = self.c_attn(x)
        
        # 分离 Q, K, V 并重塑维度为多头
        # shape 变换: (B, T, 3 * C) -> (B, T, n_head, 3 * d_head) -> (B, n_head, T, 3 * d_head)
        q, k, v = qkv.view(B, T, self.n_head, 3 * self.d_head).transpose(1, 2).split(self.d_head, dim=-1)
        
        # 2. 计算 Attention Score (Scaled Dot-Product)
        # (B, h, T, d) @ (B, h, d, T) -> (B, h, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
        
        # 3. 应用 Mask (Causal Masking)
        # 将 Mask 为 0 的位置替换为 -inf
        att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
        
        # 4. Softmax 归一化
        att = F.softmax(att, dim=-1)
        
        # 5. 聚合 Value
        # (B, h, T, T) @ (B, h, T, d) -> (B, h, T, d)
        y = att @ v
        
        # 6. 拼接多头并输出
        # (B, h, T, d) -> (B, T, h, d) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        
        return self.c_proj(y)

# 测试代码
if __name__ == "__main__":
    # 假设 Batch=2, 序列长度=10, 维度=32, 4个头
    batch_size = 2
    seq_len = 10
    d_model = 32
    n_head = 4
    
    model = CausalSelfAttention(d_model, n_head)
    input_tensor = torch.randn(batch_size, seq_len, d_model)
    
    output = model(input_tensor)
    print(f"Input shape: {input_tensor.shape}")
    print(f"Output shape: {output.shape}")  # 应该与输入保持一致

代码解析重点

1. 维度变换 (view & transpose)

这是实现多头的关键。我们将 (Batch, Time, Dim) 变成了 (Batch, Head, Time, Head_Dim),这样 PyTorch 就会把 Head 维度视为 Batch 的一部分进行并行计算。

python 复制代码
# 原始形状: (B, T, 3*C)
# view 后: (B, T, h, 3*d_head)
# transpose 后: (B, h, T, 3*d_head)
# split 后: q, k, v 各为 (B, h, T, d_head)
2. masked_fill

这就是 GPT 不能"偷看"未来的核心实现。通过将未来位置填充为 -inf,Softmax 后这些位置的概率变为 0。

3. contiguous

由于 transpose 打乱了内存布局,必须调用 contiguous() 才能进行下一次 view 操作。这是因为 PyTorch 的张量在内存中是连续存储的,某些操作(如 transpose)会创建视图而非副本,导致内存不连续。

4. 效率优化

将 Q, K, V 的三个线性层合并为一个 c_attn 层(输出 3*d_model 维度),这样可以利用 GPU 的并行计算能力,减少 kernel 调用次数。


5. KV Cache:推理加速的关键

在生成式推理中,我们是一个词一个词生成的。如果不使用 KV Cache,每次生成新词时都需要重新计算之前所有位置的 K 和 V,造成巨大的计算浪费。

KV Cache 原理

python 复制代码
# 假设已经生成了前 t 个词,现在要生成第 t+1 个词
# 不使用 KV Cache: 需要计算所有 0 到 t 位置的 K, V
# 使用 KV Cache: 只需计算第 t 位置的 K, V,然后从缓存中读取 0 到 t-1 的 K, V

# 伪代码示例
class CachedAttention:
    def __init__(self):
        self.k_cache = []  # 缓存所有历史 K
        self.v_cache = []  # 缓存所有历史 V
    
    def forward(self, x):
        # 只计算当前 token 的 Q, K, V
        q, k, v = self.compute_qkv(x)
        
        # 将 K, V 加入缓存
        self.k_cache.append(k)
        self.v_cache.append(v)
        
        # 使用缓存的 K, V 计算注意力
        k_all = torch.cat(self.k_cache, dim=1)
        v_all = torch.cat(self.v_cache, dim=1)
        
        att = (q @ k_all.transpose(-2, -1)) / math.sqrt(k.size(-1))
        # ... 后续计算

显存占用分析

KV Cache 的显存占用与序列长度和模型大小成正比:

显存占用 = 2 × Batch × Layers × Heads × SeqLen × d h e a d × Bytes \text{显存占用} = 2 \times \text{Batch} \times \text{Layers} \times \text{Heads} \times \text{SeqLen} \times d_{head} \times \text{Bytes} 显存占用=2×Batch×Layers×Heads×SeqLen×dhead×Bytes

这就是为什么 GQA(分组查询注意力)如此重要------它通过减少 K/V 的头数来大幅降低 KV Cache 的显存占用。


6. 总结:Attention 机制速查表

类型 查询源 (Q) 键值源 (K, V) 掩码 (Mask) 典型应用
Self-Attention 输入 X 输入 X BERT, ViT, Encoder
Masked Self-Attention 输入 X 输入 X 有 (下三角) GPT, Llama, Decoder
Cross-Attention Decoder 输出 Encoder 输出 Transformer 翻译, Stable Diffusion
Grouped Query (GQA) 独立 Heads 分组共享 取决于场景 Llama 3, Mistral (推理加速)

7. 思考

Attention 机制虽然强大,但它是 O ( N 2 ) O(N^2) O(N2) 的复杂度(序列长度翻倍,计算量翻四倍)。这就是为什么 2024 年大家都在卷:

  • FlashAttention:通过平铺 IO 优化读写,减少内存访问次数
  • Linear Attention :试图打破平方级诅咒,将复杂度降至 O ( N ) O(N) O(N)
  • Ring Attention:支持超长上下文(百万级 token)
  • Mamba/SSM:状态空间模型,试图完全取代 Transformer 的注意力机制

但在新的架构完全成熟之前,Attention 依然是 AI 领域的"牛顿定律"------它可能不是最终的答案,但它是目前最可靠、最强大的基础。

未来展望

  1. 更长的上下文:从 128k 到 1M 甚至更长
  2. 更高效的计算:FlashAttention 2.0, 3.0 持续优化
  3. 更灵活的架构:混合注意力、稀疏注意力
  4. 更低的延迟:KV Cache 压缩、量化、蒸馏

8. 推荐阅读

相关推荐
磊-2 小时前
AI Agent 学习计划(一)
人工智能·学习
不会打球的摄影师不是好程序员2 小时前
dify实战-个人知识库搭建
人工智能
xixixi777772 小时前
对 两种不同AI范式——Transformer 和 LSTM 进行解剖和对比
人工智能·深度学习·大模型·lstm·transformer·智能·前沿
lfPCB2 小时前
聚焦机器人算力散热:PCB液冷集成的工程化现实阻碍
人工智能·机器人
sunxunyong2 小时前
CC-Ralph实测
人工智能·自然语言处理
IT_陈寒2 小时前
Vite 5分钟性能优化实战:从3秒到300ms的冷启动提速技巧(附可复用配置)
前端·人工智能·后端
十六年开源服务商2 小时前
WordPress集成GoogleAnalytics最佳实践指南
前端·人工智能·机器学习
市象2 小时前
石头把科技摔掉了
人工智能
子午2 小时前
【2026原创】水稻植物病害识别系统~Python+深度学习+人工智能+resnet50算法+TensorFlow+图像识别
人工智能·python·深度学习