在当今主流大模型(LLM)架构中,**注意力机制(Attention Mechanism)**是最核心的计算模块之一。无论是文本生成模型、视觉模型,还是多模态模型,几乎都建立在以注意力为基础的结构之上。自从 Ashish Vaswani 等人在 2017 年提出 Google 的论文《Attention Is All You Need》以来,Transformer 架构已成为大模型的标准范式。
本文将系统总结注意力机制的原理、结构演化与工程优化方向。
一、为什么需要注意力机制?
在传统序列模型(如 RNN、LSTM)中,模型需要将前文压缩到一个固定长度的隐状态中,这种"信息瓶颈"会导致:
- 长距离依赖难以建模
- 梯度消失或爆炸
- 并行计算效率低
注意力机制的核心思想是:
在处理当前 token 时,动态地关注序列中的不同位置,并为其分配不同权重。
换句话说,模型不再依赖单一隐向量,而是对所有历史信息进行加权聚合。
二、注意力的基本计算公式
1. Query / Key / Value
注意力机制的输入通常是三组向量:
- Query (Q):当前查询向量
- Key (K):用于匹配的索引向量
- Value (V):实际信息载体
2. Scaled Dot-Product Attention
标准注意力公式为:
Attention(Q,K,V)=softmax(QKTdk)V Attention(Q,K,V)=softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
解释:
- Q 与 K 做点积,计算相似度
- 除以 dk\sqrt{d_k}dk 做缩放(防止梯度过大)
- 经 softmax 得到注意力权重
- 对 V 加权求和
本质上,这是一个可学习的加权平均机制。
PyTorch 实现(基础版)
python
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V, mask=None):
"""
Q, K, V: (batch, heads, seq_len, dim)
"""
d_k = Q.size(-1)
# 1. 计算相似度
scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
# 2. mask(用于因果或padding)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# 3. softmax
attn_weights = F.softmax(scores, dim=-1)
# 4. 加权求和
output = torch.matmul(attn_weights, V)
return output, attn_weights
三、自注意力(Self-Attention)
在大模型中,最重要的是 Self-Attention。
特点:
- Q、K、V 来自同一个序列
- 每个 token 都可以"看到"其他 token
- 能直接建模长距离依赖
例如:
"The animal didn't cross the street because it was tired."
模型可以通过注意力将 it 指向 animal。
四、多头注意力(Multi-Head Attention)
单头注意力可能只捕获某一类关系,因此 Transformer 引入:
MultiHead(Q,K,V)=Concat(head1,...,headh)WO MultiHead(Q,K,V)=Concat(head_1,...,head_h)W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO
每个 head 在不同的线性空间中计算注意力。
优势:
- 不同 head 关注不同语义模式
- 增强表达能力
- 类似"多视角观察"
Multi-Head Attention 实现
python
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
assert d_model % num_heads == 0
self.num_heads = num_heads
self.head_dim = d_model // num_heads
self.qkv_proj = nn.Linear(d_model, 3 * d_model)
self.out_proj = nn.Linear(d_model, d_model)
def forward(self, x, mask=None):
B, T, C = x.shape
qkv = self.qkv_proj(x)
qkv = qkv.reshape(B, T, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4)
Q, K, V = qkv[0], qkv[1], qkv[2]
out, attn = scaled_dot_product_attention(Q, K, V, mask)
out = out.transpose(1, 2).reshape(B, T, C)
return self.out_proj(out)
五、注意力优化方向
所有注意力优化,本质都围绕三件事:
- 少算(减少计算量)
- 少存(减少显存占用)
- 让长度变长(支持更长上下文)
方向一:少算(降低计算复杂度)
🎯 目标
把 Attention 的计算复杂度从:
O(n2)→O(nlogn) 或 O(n) O(n^2) \rightarrow O(n\log n) \text{ 或 } O(n) O(n2)→O(nlogn) 或 O(n)
1️⃣ 稀疏注意力(Sparse Attention)
代表模型:
- Longformer
- BigBird
核心原理
不再计算完整 N×N 注意力矩阵,而是只计算:
- 局部窗口
- 块内 attention
- 少量全局 token
📊 原理图
标准 Attention(满矩阵)
█ █ █ █ █
█ █ █ █ █
█ █ █ █ █
█ █ █ █ █
█ █ █ █ █
稀疏 Attention(带状 + 全局)
█ █ . . .
█ █ █ . .
. █ █ █ .
. . █ █ █
. . . █ █
. 表示不计算。
✅ 优点
- 理论复杂度下降
- 显存占用减少
❌ 缺点
- 远距离依赖能力下降
- 稀疏模式难设计
- GPU 不友好
🏭 落地情况
- 文档任务、长文本建模
- ❌ 主流 LLM 已基本不用纯稀疏结构
2️⃣ 线性注意力(Linear Attention)
核心原理
把:
softmax(QKT)V softmax(QK^T)V softmax(QKT)V
改写为:
(Qϕ(K)T)(ϕ(V)) (Q\phi(K)^T)(\phi(V)) (Qϕ(K)T)(ϕ(V))
避免构造 N×N 矩阵。
📊 原理图
标准流程
Q × Kᵀ → N×N矩阵 → softmax → ×V
线性 Attention
φ(K)ᵀ × φ(V) → 中间聚合
Q × 上述结果 → 输出
等价于:
先把序列维度"压掉"
✅ 优点
- 理论 O(n)
- 超长序列潜力大
❌ 缺点
- softmax 被近似
- 精度下降明显
- 训练不稳定
🏭 落地情况
- 学术研究活跃
- ❌ 商业 LLM 几乎不用
方向二:少存(降低显存 & IO)
这是当前工业界真正成功的方向。
1️⃣ FlashAttention
提出团队:Stanford University
核心原理
不改变数学结构,而是改变执行方式:
- 分块计算
- 不存完整 attention matrix
- 在 SRAM 内做 softmax
FlashAttention 的"图"和前面完全不一样
因为它 结构没变,只是计算方式变了
逻辑结构(和标准 Attention 一样)
Q × Kᵀ → softmax → × V
物理执行图(关键)
HBM (显存)
├─ Load Q block
├─ Load K block
├─ Load V block
└─ Compute softmax + V inside SRAM
↓
Write Output
Blocked Attention 示意
[Q1] x [K1] → partial softmax → partial output
[Q1] x [K2] → partial softmax → accumulate
[Q1] x [K3] → ...
减少了什么?
- ❌ 没减少计算量
- ✅ 大幅减少显存 IO
- ✅ 不存整张 attention matrix
为什么是王者?
- 数学完全等价
- 精度 0 损失
- GPU 友好
落地
- PyTorch 2.x 默认
- 所有主流 LLM
2️⃣ KV Cache(推理期核心优化)
无 KV Cache(每步都重算)
Step t:
[Token1 ... Token t] → Q K V → Attention
有 KV Cache
Cache:
K1 K2 K3 ... K(t-1)
V1 V2 V3 ... V(t-1)
Step t:
Qt × [K1...K(t-1)] → output
示意图
┌───────────────┐
New Q ──▶│ KV Cache │──▶ Attention
│ K1 K2 ... Kt │
│ V1 V2 ... Vt │
└───────────────┘
减少了什么?
- ❌ 历史 K/V 不再重复计算
代价
- KV cache 显存线性增长
落地
- 所有 GPT 类模型
- 推理必备
3️⃣ MQA / GQA(减少 KV 的"宽度")
代表模型:
- LLaMA
标准 Multi-Head
Q1 K1 V1
Q2 K2 V2
Q3 K3 V3
Q4 K4 V4
MQA(极端)
Q1 ┐
Q2 ├── shared K, V
Q3 ┤
Q4 ┘
GQA(折中)
(Q1 Q2) ── K1 V1
(Q3 Q4) ── K2 V2
减少了什么?
-
KV Cache 从:
heads × seq_len × dim变成:
groups × seq_len × dim
代价
- 轻微表达能力下降
落地
- LLaMA 2 / 3
- 长上下文推理必用
三、方向三:让长度变长(扩展上下文)
1️⃣ RoPE 外推
原理
使用旋转位置编码,使位置关系具有外推能力。
📊 示意
标准位置编码:
pos=1 pos=2 pos=3 ...
RoPE:
向量按角度旋转
θ = pos × 频率
角度可扩展。
🏭 落地
- 几乎所有开源 LLM
2️⃣ 分布式 / Ring Attention
原理,把序列分布在多 GPU 上:
GPU0: tokens 0--8k
GPU1: tokens 8k--16k
GPU2: tokens 16k--24k
通过 ring 传递 K/V。
📊 示意图
GPU0 → GPU1 → GPU2 → GPU3
↑ ↓
└──────── ring ────────┘
🏭 落地
- 100k+ context
- 企业级长文本系统
大模型中的注意力优化,本质上都围绕三个方向展开:
- 少算 ------ 降低计算复杂度
- 少存 ------ 降低显存占用和内存访问
- 变长 ------ 在资源可控的前提下支持更长上下文
进一步抽象来看,这些优化都是在平衡四种资源约束:
计算(FLOPs)、显存(Memory)、带宽(IO)、并行结构(Parallelism)。
在当前硬件条件下,带宽和显存往往比算力更稀缺,因此工程上最成功的优化通常集中在"少存"和"少搬运"上,而不是单纯减少理论计算量。
结论:
注意力机制的演进,已经从"算法问题"逐渐转向"系统工程问题"。