1. 什么是自注意力?
想象一下,你正在阅读一本小说,每看到一个词语时,大脑会自动关注前文中与之相关的信息。这种"聚焦重点"的能力,正是自注意力机制的核心思想。
自注意力(Self-Attention)是一种让序列中的每个元素都能关注整个序列的机制。就像班级讨论时,每个同学发言(查询)都会考虑所有人的观点(键和值)。具体来说:
给定输入序列 X=x1,x2,...,xn,自注意力通过三个步骤生成输出:
- 生成问题纸条 :每个词元创建查询向量 qi=Wqxi
- 制作答案卡 :每个词元生成键向量 kj=Wkxj 和值向量 vj=Wvxj
- 收集答案:每个查询收集所有键值对的加权和:
yi=j=1∑nsoftmax(d qi⊤kj)vj
其中:
- q(Query)是查询矩阵,大小为 (n×d),其中 n 是查询的数量, d 是特征维度。
- k(Key)是键矩阵,大小为 (m×d),其中 m 是键的数量, d 是特征维度。
- v(Value)是值矩阵,大小为 (m×dv)。
- d 1 是一个缩放因子,用于防止大数值导致 softmax 过于极端,从而影响梯度的稳定性。
示例:考虑句子"猫吃鱼",自注意力会让"吃"同时关注"猫"和"鱼",就像我们在理解动词时会自动联系主语和宾语。
下面的代码片段是基于多头注意力对一个张量完成自注意力的计算,输入张量 X 的形状为 (批量大小,序列长度,特征维度),经过自注意力计算后,输出张量与输入张量形状保持一致。
python
import torch
import d2l
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
attention.eval()
print(attention)
"""输出:
MultiHeadAttention(
(attention): DotProductAttention(
(dropout): Dropout(p=0.5, inplace=False)
)
(W_q): Linear(in_features=100, out_features=100, bias=False)
(W_k): Linear(in_features=100, out_features=100, bias=False)
(W_v): Linear(in_features=100, out_features=100, bias=False)
(W_o): Linear(in_features=100, out_features=100, bias=False)
)
"""
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
print(attention(X, X, X, valid_lens).shape)
# 输出:torch.Size([2, 4, 100])
2. 三大序列模型的巅峰对决
2.1 参赛选手介绍
| 模型类型 | 工作方式 | 可视化类比 |
|---|---|---|
| CNN | 滑动窗口扫描 | 望远镜观察局部区域 |
| RNN | 顺序传递信息 | 接力赛传递消息 |
| 自注意力 | 全局直接交互 | 电话会议全员讨论 |
2.2 性能参数对比
使用 n 个词元,每个维度 d,卷积核大小 k:
| 指标 | CNN | RNN | 自注意力 |
|---|---|---|---|
| 计算复杂度 | O(knd2) | O(nd2) | O(n2d) |
| 并行能力 | 高 | 低 | 极高 |
| 最大路径长度 | O(n/k) | O(n) | O(1) |
图1 比较卷积神经网络(填充词元被忽略)、循环神经网络和自注意力三种架构
示例:处理100个词的句子时,自注意力需要100×100=10,000次交互计算,而CNN(假设k=3)只需3×100=300次局部计算。
3. 位置编码:给词语发"座位号"
3.1 为什么需要位置信息?
自注意力虽然强大,但有个致命缺陷------所有词语同时处理,就像把句子里的词全部平铺在桌面上,模型无法知道它们的原始顺序。这时就需要位置编码来标记每个词的位置。
3.2 神奇的三角函数编码
使用正弦和余弦函数的组合生成位置编码矩阵 P,其中第 i 行对应位置,第 2j 和 2j+1 列使用:
Pi,2jPi,2j+1=sin(100002j/di)=cos(100002j/di)
示例 :当 d=4 时,位置1的编码可能是: sin(1/10000\^0), cos(1/10000\^0), sin(1/10000\^(2/4)), cos(1/10000\^(2/4))
让我们在下面的PositionalEncoding类中实现它这种编码方式:
python
class PositionalEncoding(nn.Module):
"""位置编码"""
def __init__(self, num_hiddens, dropout, max_len=1000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(dropout)
# 创建一个足够长的P
self.P = torch.zeros((1, max_len, num_hiddens))
X = torch.arange(max_len, dtype=torch.float32).reshape(
-1, 1) / torch.pow(10000, torch.arange(
0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
self.P[:, :, 0::2] = torch.sin(X)
self.P[:, :, 1::2] = torch.cos(X)
def forward(self, X):
X = X + self.P[:, :X.shape[1], :].to(X.device)
return self.dropout(X)
在位置嵌入矩阵 P 中,行代表词元在序列中的位置,列代表位置编码的不同维度。
python
encoding_dim, num_steps = 32, 60
pos_encoding = d2l.PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6.18, 3.82), legend=["Col %d" % d for d in torch.arange(6, 10)])
从下面的例子中可以看到位置嵌入矩阵的第6列和第7列的频率高于第8列和第9列。第6列和第7列之间的偏移量(第8列和第9列相同)是由于正弦函数和余弦函数的交替。
3.3 编码特性揭秘
绝对位置感知
不同列对应不同频率的波形,就像钢琴键盘上从左到右音调逐渐降低。高频(左侧列)帮助捕捉相邻词语的位置关系,低频(右侧列)负责编码词语在序列中的整体位置。
python
P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.82, 6.18), cmap='Blues')
相对位置推理
关键公式:位置 i+k 的编码可以表示为位置 i 编码的线性变换:
sin(ωj(i+k))cos(ωj(i+k))=sin(ωji)cos(ωjk)+cos(ωji)sin(ωjk)=cos(ωji)cos(ωjk)−sin(ωji)sin(ωjk)
这就像通过三角函数公式,模型可以推导出词语之间的相对距离。
4. 关键知识点总结
- 自注意力的本质:让每个词元都能与序列中所有词元直接交互
- 三大模型对比 :
- CNN:局部感知,适合处理图像
- RNN:顺序处理,适合流式数据
- 自注意力:全局交互,适合长程依赖
- 位置编码的妙用 :
- 绝对位置:通过不同频率的正余弦函数编码
- 相对位置:利用三角恒等式实现位置偏移的线性表示
通过这个魔法般的组合,现代Transformer模型才能在机器翻译、文本生成等任务中展现出惊人的性能。理解这些基础原理,就是打开深度学习宝库的第一把钥匙!