Attention/注意力机制:AI的"聚光灯"
这篇文章深入解析Attention的工作原理,让你真正理解AI是怎么"看"文本的。
前言
上一篇文章,我们认识了Transformer这个"变形金刚"。
今天,我们来深入它的核心------Attention(注意力机制)
这是现代AI最重要的发明之一,理解了它,你就理解了AI的"灵魂"。
一、黑话原文 vs 人话翻译
场景模拟
arduino
🎯 AI算法组会议:
工程师A:"我们试试加个Cross-Attention"
工程师B:"Self-Attention的计算复杂度是O(n²)"
工程师C:"可以用Flash Attention优化"
工程师A:"Masked Attention也加上吧"
工程师B:"那得注意位置编码的配合"
人话翻译表
| 黑话 | 人话翻译 | 一句话理解 |
|---|---|---|
| Self-Attention | 自注意力 | 自己看自己 |
| Cross-Attention | 交叉注意力 | 这个看那个 |
| Masked Attention | 掩码注意力 | 不许偷看后面 |
| Flash Attention | 闪电注意力 | 快速版Attention |
| Multi-Query Attention | 多查询注意力 | 省显存的技术 |
| 位置编码 | Positional Encoding | 给字标上序号 |
二、Attention的本质
2.1 一句话定义
Attention = 告诉AI"该关注谁、关注多少"
人话版:就像老师在课堂上,知道该点谁回答问题,也知道谁的答案更靠谱。
2.2 Attention权重是什么?
scss
┌─────────────────────────────────────────────────────────────┐
│ Attention权重示意 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 句子:"小明 喜欢吃 苹果" │
│ │
│ 当AI处理"吃"时,它看其他词的程度: │
│ │
│ 小明 ████████████████ 0.45 (谁吃?小明) │
│ 喜欢 ████████ 0.20 │
│ 吃 ──────────────── 当前词 │
│ 苹果 ████████████ 0.35 (吃什么?苹果) │
│ │
│ 权重总和 = 1.0 (100%) │
│ │
└─────────────────────────────────────────────────────────────┘
2.3 生活类比
diff
Attention就像:
🎬 电影导演的镜头
- 重要角色:给特写(高权重)
- 次要角色:给远景(低权重)
- 无关路人:不给镜头(零权重)
👨🏫 老师点名
- 知道谁成绩好:更信任他的答案
- 知道谁在走神:不用他回答问题
🔦 舞台聚光灯
- 照在主角身上
- 观众的注意力被引导
三、Attention怎么计算?
3.1 核心公式
markdown
Attention(Q, K, V) = softmax(QKᵀ / √dₖ) × V
翻译成人话:
1. Q和K点积 → 计算相似度(谁和谁有关系)
2. 除以√dₖ → 防止数值太大
3. Softmax → 变成概率(权重加起来=1)
4. 乘以V → 加权求和(得到结果)
3.2 图解计算过程
less
Step 1: 准备Q、K、V
句子: "我 爱 AI"
每个词都有三个向量:
┌─────────────────────────────────────────┐
│ 词 │ Q │ K │ V │
│────────┼──────────┼──────────┼─────────│
│ 我 │ [1,0,0] │ [1,1,0] │ [0,1,1] │
│ 爱 │ [0,1,0] │ [1,0,1] │ [1,0,1] │
│ AI │ [0,0,1] │ [0,1,1] │ [1,1,0] │
└─────────────────────────────────────────┘
Step 2: 计算Q×Kᵀ(相似度矩阵)
我 爱 AI
我 [ 1.0 0.5 0.0 ]
爱 [ 1.0 1.0 1.0 ]
AI [ 0.0 0.5 1.0 ]
Step 3: Softmax归一化
我 爱 AI
我 [ 0.62 0.38 0.00 ]
爱 [ 0.24 0.42 0.34 ]
AI [ 0.00 0.38 0.62 ]
Step 4: 乘以V得到输出
输出 = 权重矩阵 × V值矩阵
每个词的新表示都融合了其他词的信息
3.3 代码实现
python
import torch
import torch.nn.functional as F
def scaled_dot_product_attention(Q, K, V):
"""
缩放点积注意力
"""
d_k = Q.size(-1)
# 1. Q和K点积
scores = torch.matmul(Q, K.transpose(-2, -1))
# 2. 缩放(除以√d_k)
scores = scores / (d_k ** 0.5)
# 3. Softmax归一化
attention_weights = F.softmax(scores, dim=-1)
# 4. 加权求和
output = torch.matmul(attention_weights, V)
return output, attention_weights
# 测试
d_model = 64
seq_len = 10
batch_size = 2
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)
output, weights = scaled_dot_product_attention(Q, K, V)
print(f"输入形状: {Q.shape}")
print(f"输出形状: {output.shape}")
print(f"权重形状: {weights.shape}")
四、不同类型的Attention
4.1 Self-Attention vs Cross-Attention
css
┌─────────────────────────────────────────────────────────────┐
│ Self-Attention │
├─────────────────────────────────────────────────────────────┤
│ │
│ Q、K、V都来自同一个序列 │
│ │
│ 输入序列 ──→ 生成Q ──┐ │
│ ──→ 生成K ──┼──→ 计算Attention │
│ ──→ 生成V ──┘ │
│ │
│ 用途:理解句子内部关系 │
│ │
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ Cross-Attention │
├─────────────────────────────────────────────────────────────┤
│ │
│ Q来自一个序列,K和V来自另一个序列 │
│ │
│ 序列A ──→ 生成Q ──┐ │
│ ├──→ 计算Attention │
│ 序列B ──→ 生成K ──┤ │
│ ──→ 生成V ──┘ │
│ │
│ 用途:翻译、检索(一个看另一个) │
│ │
└─────────────────────────────────────────────────────────────┘
4.2 Masked Attention
css
问题:生成文本时,不能"偷看"后面的字
例如生成"我爱AI":
- 生成"我"时:只能看"我"
- 生成"爱"时:只能看"我爱"
- 生成"AI"时:才能看"我爱AI"
解决方案:Masked Attention
原始权重:
我 爱 AI
我 [ 0.33 0.33 0.33 ]
爱 [ 0.33 0.33 0.33 ]
AI [ 0.33 0.33 0.33 ]
Mask后(-∞表示不能看):
我 爱 AI
我 [ 0.33 -∞ -∞ ]
爱 [ 0.50 0.50 -∞ ]
AI [ 0.33 0.33 0.33 ]
Softmax后:
我 爱 AI
我 [ 1.00 0.00 0.00 ]
爱 [ 0.50 0.50 0.00 ]
AI [ 0.33 0.33 0.33 ]
实现代码:
python
def masked_attention(Q, K, V):
d_k = Q.size(-1)
scores = torch.matmul(Q, K.transpose(-2, -1)) / (d_k ** 0.5)
# 创建mask:上三角为-inf
seq_len = Q.size(1)
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1) * (-1e9)
scores = scores + mask
attention_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attention_weights, V)
return output, attention_weights
4.3 高效Attention变体
scss
┌─────────────────────────────────────────────────────────────┐
│ Attention效率问题 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 标准Attention复杂度:O(n²) │
│ - 1000个词 → 1,000,000次计算 │
│ - 10000个词 → 100,000,000次计算 │
│ - 序列越长,越慢! │
│ │
│ 优化方案: │
│ │
│ 1. Flash Attention │
│ └── 优化内存访问,快2-4倍 │
│ │
│ 2. Multi-Query Attention (MQA) │
│ └── 共享K和V,省显存 │
│ │
│ 3. Grouped-Query Attention (GQA) │
│ └── MQA的折中版本 │
│ │
│ 4. Linear Attention │
│ └── 复杂度降到O(n),但效果略差 │
│ │
└─────────────────────────────────────────────────────────────┘
五、Attention可视化
5.1 可视化是什么样的?
css
句子:"The animal didn't cross the street because it was too tired"
当模型处理"it"时,Attention分布:
animal ████████████████████████ 0.85 ← it指的是animal
street ███ 0.10
tired ██ 0.05
可视化矩阵(简化版):
animal it street tired
animal [ 0.9 0.6 0.1 0.2 ]
it [ 0.1 0.9 0.8 0.7 ] ← it高度关注animal和street
street [ 0.2 0.7 0.9 0.3 ]
tired [ 0.1 0.6 0.2 0.9 ]
5.2 热力图示意
markdown
小 明 喜 欢 吃 苹 果
┌─────────────────────────────────────┐
小 │███ │ │ │ │ │ │ │
明 │ │███ │ │ │███ │ │ │
喜 │ │ │███ │███ │ │ │ │
欢 │ │ │███ │███ │ │ │ │
吃 │ │███ │ │ │███ │███ │███ │
苹 │ │ │ │ │███ │███ │███ │
果 │ │ │ │ │███ │███ │███ │
└─────────────────────────────────────┘
颜色越深 = Attention权重越高
六、位置编码(Positional Encoding)
6.1 为什么需要位置编码?
arduino
问题:Attention本身不知道"顺序"
句子A:"小明喜欢小红"
句子B:"小红喜欢小明"
Attention看到的:
两个句子除了词序不同,其他都一样
但意思完全不同!
所以需要告诉模型:每个词在第几个位置
6.2 位置编码怎么做?
python
import math
def positional_encoding(seq_len, d_model):
"""
生成位置编码
"""
pe = torch.zeros(seq_len, d_model)
position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
# 用正弦和余弦函数生成编码
div_term = torch.exp(
torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term) # 偶数位置用sin
pe[:, 1::2] = torch.cos(position * div_term) # 奇数位置用cos
return pe
# 每个位置都有唯一的编码向量
# 位置1: [sin(1/d), cos(1/d), sin(1/d²), cos(1/d²), ...]
# 位置2: [sin(2/d), cos(2/d), sin(2/d²), cos(2/d²), ...]
6.3 位置编码变体
| 类型 | 代表 | 特点 |
|---|---|---|
| 正弦余弦 | 原始Transformer | 固定编码 |
| 可学习 | BERT、GPT | 训练时学习 |
| 旋转编码(RoPE) | LLaMA | 相对位置 |
| ALiBi | 某些模型 | 线性偏置 |
七、Attention的实际应用
7.1 文本理解
arduino
任务:情感分析
句子:"这家餐厅味道不错,但服务太差了"
Attention重点:
味道 ████████ 0.30
不错 ████████████████ 0.45 ← 正面
服务 ████████ 0.15
差 ██████████████ 0.40 ← 负面
模型综合考虑:正面vs负面的权重
7.2 机器翻译
sql
翻译:"I love you" → "我爱你"
Cross-Attention:
英语 → 中文
I → 我 (高权重)
love → 爱 (高权重)
you → 你 (高权重)
每个中文词"看"对应的英文词
7.3 代码生成
python
# 生成代码时的Attention
Prompt: "写一个计算斐波那契数列的函数"
Attention关注点:
1. "斐波那契" → 算法逻辑
2. "函数" → 函数定义格式
3. "计算" → 具体实现
生成过程中,Attention帮助模型记住:
- 变量名要一致
- 逻辑要连贯
- 语法要正确
小结
| 黑话 | 人话 | 记忆口诀 |
|---|---|---|
| Self-Attention | 自注意力 | 自己看自己 |
| Cross-Attention | 交叉注意力 | 这个看那个 |
| Masked Attention | 掩码注意力 | 不许偷看后面 |
| Flash Attention | 闪电注意力 | 快速版 |
| 位置编码 | Positional Encoding | 给字标序号 |
| Attention权重 | 关注程度 | 看得多=重要 |
关键认知:
- Attention是"聚光灯",告诉AI看哪里
- 权重通过Q和K的点积计算
- 不同类型适应不同场景
- 位置编码解决"顺序"问题
黑话等级
⭐⭐⭐⭐ 中级
├── 理解Attention计算过程
├── 知道不同Attention类型
└── 明白位置编码的作用
下一期预告:Embedding/嵌入 - 把文字变成数字的魔法
思考与练习
-
思考题:
- 为什么需要Masked Attention?
- Self-Attention和Cross-Attention的区别是什么?
-
动手练习:
- 实现一个简单的Attention函数
- 可视化一个句子的Attention权重
-
延伸探索:
- 了解Flash Attention的原理
- 研究RoPE位置编码
下期预告
下一篇文章,我们来聊:Embedding/嵌入 - 把文字变成数字的魔法
会解答这些问题:
- AI怎么把文字变成数字?
- 向量空间是什么?
- 为什么相似的词在空间中更近?
关注专栏,不错过后续更新!
作者:ECH00O00 本文首发于掘金专栏《AI黑话翻译官》 欢迎评论区交流讨论,点赞收藏就是最大的鼓励