拆解 Transformer 的灵魂:全景解析 Attention 家族 (Self, Cross, Masked & GQA)
作者: Duang777
标签: Deep Learning Transformer LLM Internals PyTorch
在 2017 年之前,NLP 还是 LSTM 和 RNN 的天下。那时候我们处理长文本就像患了"健忘症",读了后面忘了前面。直到 Google 团队扔出了那篇《Attention Is All You Need》,整个 AI 世界被重塑了。
但"注意力"到底是什么?简单来说,注意力机制是一种资源分配方案。在处理海量信息时,它告诉模型:"别试图记住所有东西,只关注那些对当前任务最相关的信息。"
今天,我们将深入源码和原理,彻底搞懂 Q、K、V 的爱恨情仇。
1. 核心原语:Query, Key, Value (Q, K, V)
所有注意力机制的数学本质,都可以看作是一个**数据库查询(Database Retrieval)**过程。
想象你去图书馆找书:
- Query ( Q Q Q):你手中的书单("我要找关于 Python 的书")
- Key ( K K K):书脊上的标签("计算机科学/编程语言")
- Value ( V V V):书里的实际内容
- Score (权重) :你的书单 ( Q Q Q) 和 标签 ( K K K) 的匹配程度
注意力计算公式(Scaled Dot-Product Attention)
A t t e n t i o n ( Q , K , V ) = softmax ( Q K T d k ) V Attention(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
步骤详解:
- 点积 ( Q K T QK^T QKT):计算相似度。向量越相似,点积越大。
- 缩放 ( 1 d k \frac{1}{\sqrt{d_k}} dk 1):这是为了防止点积结果过大,导致 Softmax 进入饱和区(梯度消失)。
- Softmax:将分数归一化为概率分布(所有权重之和为 1)。
- 乘 V V V:根据权重,加权提取信息。
为什么需要缩放因子?
当维度 d k d_k dk 较大时,点积的值会变得很大。例如,假设向量的每个维度都在 [ − 1 , 1 ] [-1, 1] [−1,1] 范围内,那么 d k d_k dk 维向量的点积期望值大约为 d k d_k dk。当 d k = 512 d_k = 512 dk=512 时,点积值可能达到数百级别。
Softmax 函数在输入值较大时会进入饱和区,导致梯度接近 0,模型难以训练。除以 d k \sqrt{d_k} dk 可以将点积值控制在合理范围内,保持梯度的稳定性。
2. 家族族谱:注意力的不同形态
虽然公式一样,但在不同的场景下,Q、K、V 的来源不同,就构成了不同的注意力变体。
2.1 自注意力 (Self-Attention)
别名: 上帝视角
定义: Q , K , V Q, K, V Q,K,V 全部来自同一个输入源( X X X)。
作用: 让序列中的每一个词都能"看见"其他所有的词,并计算它们之间的关联。
例如在句子 "The animal didn't cross the street because it was too tired" 中,Self-Attention 帮助模型将 "it" 关联到 "animal" 而不是 "street"。
场景: Transformer 的 Encoder 内部,BERT。
2.2 掩码注意力 (Masked Self-Attention)
别名: 严守时间线的"预言家"
痛点: 在生成式任务(如 GPT)中,预测第 t t t 个词时,绝不能看到 t + 1 t+1 t+1 之后的词(否则就是作弊)。
实现: 在 Softmax 之前,将矩阵的右上三角区域(未来信息)填充为 − ∞ -\infty −∞。这样 Softmax 后这些位置的概率就变成了 0。
场景: GPT 系列,所有 Decoder-only 模型。
掩码矩阵示例:
python
# 对于序列长度为 4 的情况,掩码矩阵如下:
[[1, 0, 0, 0], # 位置 0 只能看到自己
[1, 1, 0, 0], # 位置 1 可以看到 0 和 1
[1, 1, 1, 0], # 位置 2 可以看到 0, 1, 2
[1, 1, 1, 1]] # 位置 3 可以看到所有位置
2.3 交叉注意力 (Cross-Attention)
别名: 翻译官
定义:
- Query 来自 Decoder(当前正在生成的句子)
- Key & Value 来自 Encoder(源句子的编码)
作用: 让生成部分去"查阅"源输入的信息。
场景: 机器翻译(Seq2Seq),Stable Diffusion(文本控制图像生成)。
工作流程:
- Encoder 处理源语言句子,生成编码表示(作为 K 和 V)
- Decoder 逐词生成目标语言句子(作为 Q)
- 在生成每个目标词时,通过 Cross-Attention 查询源语言的相关信息
3. 进阶架构:多头与分组 (Multi-Head & GQA)
为什么一个注意力头(Head)不够用?
如果只用一组 Q、K、V,模型可能只学到一种语义关系(比如语法结构)。我们希望模型能像用"多棱镜"一样看问题:
- Head 1 关注指代关系(it -> animal)
- Head 2 关注动宾关系(cross -> street)
- Head 3 关注时态信息
- Head 4 关注句法结构
- Head 5 关注语义相似性
- Head 6 关注位置关系
3.1 多头注意力 (Multi-Head Attention)
实现: 将 embedding 维度切割成 h h h 份,分别计算 Attention,最后拼接(Concat)起来。
数学表达:
MultiHead ( Q , K , V ) = Concat ( head 1 , ... , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \dots, \text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO
其中:
head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV)
参数量: 假设模型维度为 d m o d e l d_{model} dmodel,头数为 h h h,每个头的维度为 d k = d m o d e l / h d_k = d_{model} / h dk=dmodel/h,则:
- W i Q ∈ R d m o d e l × d k W_i^Q \in \mathbb{R}^{d_{model} \times d_k} WiQ∈Rdmodel×dk
- W i K ∈ R d m o d e l × d k W_i^K \in \mathbb{R}^{d_{model} \times d_k} WiK∈Rdmodel×dk
- W i V ∈ R d m o d e l × d k W_i^V \in \mathbb{R}^{d_{model} \times d_k} WiV∈Rdmodel×dk
- W O ∈ R h d k × d m o d e l W^O \in \mathbb{R}^{hd_k \times d_{model}} WO∈Rhdk×dmodel
3.2 现代变体:MQA 与 GQA (显存优化的救星)
随着模型越来越大(上下文窗口达到 128k 甚至 1M),KV Cache 占用的显存成为了推理瓶颈。
MHA (Multi-Head Attention) - 标准版
- 每个头都有自己的 Q, K, V
- 显存占用最大
- 性能最好
MQA (Multi-Query Attention) - 极端版
- 所有头共享同一份 K 和 V,只有 Q 不同
- 速度极快,但性能有损失
- 用于 PaLM 等模型
GQA (Grouped-Query Attention) - 平衡版
- 将头分成若干组,每组共享 K 和 V
- 在性能和速度之间取得平衡
- 用于 Llama 2/3, Mistral, Gemma 等现代模型
对比表:
| 架构 | Q 头数 | K/V 头数 | 显存占用 | 速度 | 性能 |
|---|---|---|---|---|---|
| MHA | h | h | 最大 | 慢 | 最佳 |
| MQA | h | 1 | 最小 | 最快 | 较差 |
| GQA | h | g (1 < g < h) | 中等 | 快 | 良好 |
4. PyTorch 实现详解
下面是一个完整的因果自注意力(Causal Self-Attention)实现,这是 GPT 类模型的核心组件。
python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class CausalSelfAttention(nn.Module):
def __init__(self, d_model, n_head, max_len=1024):
super().__init__()
assert d_model % n_head == 0, "Embedding dim must be divisible by n_head"
self.d_head = d_model // n_head
self.n_head = n_head
# 1. 定义 Q, K, V 的线性投影
# 通常我们将它们合并为一个大矩阵运算,效率更高
self.c_attn = nn.Linear(d_model, 3 * d_model)
# 2. 输出投影
self.c_proj = nn.Linear(d_model, d_model)
# 3. 注册 Mask (因果掩码 / 下三角矩阵)
# 这里的 register_buffer 保证 mask 是模型状态的一部分但不是可训练参数
self.register_buffer(
"bias",
torch.tril(torch.ones(max_len, max_len))
.view(1, 1, max_len, max_len)
)
def forward(self, x):
# x shape: (Batch_Size, Seq_Len, D_Model)
B, T, C = x.size()
# 1. 计算 Q, K, V
# qkv shape: (B, T, 3 * D_Model)
qkv = self.c_attn(x)
# 分离 Q, K, V 并重塑维度为多头
# shape 变换: (B, T, 3 * C) -> (B, T, n_head, 3 * d_head) -> (B, n_head, T, 3 * d_head)
q, k, v = qkv.view(B, T, self.n_head, 3 * self.d_head).transpose(1, 2).split(self.d_head, dim=-1)
# 2. 计算 Attention Score (Scaled Dot-Product)
# (B, h, T, d) @ (B, h, d, T) -> (B, h, T, T)
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
# 3. 应用 Mask (Causal Masking)
# 将 Mask 为 0 的位置替换为 -inf
att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
# 4. Softmax 归一化
att = F.softmax(att, dim=-1)
# 5. 聚合 Value
# (B, h, T, T) @ (B, h, T, d) -> (B, h, T, d)
y = att @ v
# 6. 拼接多头并输出
# (B, h, T, d) -> (B, T, h, d) -> (B, T, C)
y = y.transpose(1, 2).contiguous().view(B, T, C)
return self.c_proj(y)
# 测试代码
if __name__ == "__main__":
# 假设 Batch=2, 序列长度=10, 维度=32, 4个头
batch_size = 2
seq_len = 10
d_model = 32
n_head = 4
model = CausalSelfAttention(d_model, n_head)
input_tensor = torch.randn(batch_size, seq_len, d_model)
output = model(input_tensor)
print(f"Input shape: {input_tensor.shape}")
print(f"Output shape: {output.shape}") # 应该与输入保持一致
代码解析重点
1. 维度变换 (view & transpose)
这是实现多头的关键。我们将 (Batch, Time, Dim) 变成了 (Batch, Head, Time, Head_Dim),这样 PyTorch 就会把 Head 维度视为 Batch 的一部分进行并行计算。
python
# 原始形状: (B, T, 3*C)
# view 后: (B, T, h, 3*d_head)
# transpose 后: (B, h, T, 3*d_head)
# split 后: q, k, v 各为 (B, h, T, d_head)
2. masked_fill
这就是 GPT 不能"偷看"未来的核心实现。通过将未来位置填充为 -inf,Softmax 后这些位置的概率变为 0。
3. contiguous
由于 transpose 打乱了内存布局,必须调用 contiguous() 才能进行下一次 view 操作。这是因为 PyTorch 的张量在内存中是连续存储的,某些操作(如 transpose)会创建视图而非副本,导致内存不连续。
4. 效率优化
将 Q, K, V 的三个线性层合并为一个 c_attn 层(输出 3*d_model 维度),这样可以利用 GPU 的并行计算能力,减少 kernel 调用次数。
5. KV Cache:推理加速的关键
在生成式推理中,我们是一个词一个词生成的。如果不使用 KV Cache,每次生成新词时都需要重新计算之前所有位置的 K 和 V,造成巨大的计算浪费。
KV Cache 原理
python
# 假设已经生成了前 t 个词,现在要生成第 t+1 个词
# 不使用 KV Cache: 需要计算所有 0 到 t 位置的 K, V
# 使用 KV Cache: 只需计算第 t 位置的 K, V,然后从缓存中读取 0 到 t-1 的 K, V
# 伪代码示例
class CachedAttention:
def __init__(self):
self.k_cache = [] # 缓存所有历史 K
self.v_cache = [] # 缓存所有历史 V
def forward(self, x):
# 只计算当前 token 的 Q, K, V
q, k, v = self.compute_qkv(x)
# 将 K, V 加入缓存
self.k_cache.append(k)
self.v_cache.append(v)
# 使用缓存的 K, V 计算注意力
k_all = torch.cat(self.k_cache, dim=1)
v_all = torch.cat(self.v_cache, dim=1)
att = (q @ k_all.transpose(-2, -1)) / math.sqrt(k.size(-1))
# ... 后续计算
显存占用分析
KV Cache 的显存占用与序列长度和模型大小成正比:
显存占用 = 2 × Batch × Layers × Heads × SeqLen × d h e a d × Bytes \text{显存占用} = 2 \times \text{Batch} \times \text{Layers} \times \text{Heads} \times \text{SeqLen} \times d_{head} \times \text{Bytes} 显存占用=2×Batch×Layers×Heads×SeqLen×dhead×Bytes
这就是为什么 GQA(分组查询注意力)如此重要------它通过减少 K/V 的头数来大幅降低 KV Cache 的显存占用。
6. 总结:Attention 机制速查表
| 类型 | 查询源 (Q) | 键值源 (K, V) | 掩码 (Mask) | 典型应用 |
|---|---|---|---|---|
| Self-Attention | 输入 X | 输入 X | 无 | BERT, ViT, Encoder |
| Masked Self-Attention | 输入 X | 输入 X | 有 (下三角) | GPT, Llama, Decoder |
| Cross-Attention | Decoder 输出 | Encoder 输出 | 无 | Transformer 翻译, Stable Diffusion |
| Grouped Query (GQA) | 独立 Heads | 分组共享 | 取决于场景 | Llama 3, Mistral (推理加速) |
7. 思考
Attention 机制虽然强大,但它是 O ( N 2 ) O(N^2) O(N2) 的复杂度(序列长度翻倍,计算量翻四倍)。这就是为什么 2024 年大家都在卷:
- FlashAttention:通过平铺 IO 优化读写,减少内存访问次数
- Linear Attention :试图打破平方级诅咒,将复杂度降至 O ( N ) O(N) O(N)
- Ring Attention:支持超长上下文(百万级 token)
- Mamba/SSM:状态空间模型,试图完全取代 Transformer 的注意力机制
但在新的架构完全成熟之前,Attention 依然是 AI 领域的"牛顿定律"------它可能不是最终的答案,但它是目前最可靠、最强大的基础。
未来展望
- 更长的上下文:从 128k 到 1M 甚至更长
- 更高效的计算:FlashAttention 2.0, 3.0 持续优化
- 更灵活的架构:混合注意力、稀疏注意力
- 更低的延迟:KV Cache 压缩、量化、蒸馏