1. 什么是自注意力?
想象一下,你正在阅读一本小说,每看到一个词语时,大脑会自动关注前文中与之相关的信息。这种"聚焦重点"的能力,正是自注意力机制的核心思想。
自注意力(Self-Attention)是一种让序列中的每个元素都能关注整个序列的机制。就像班级讨论时,每个同学发言(查询)都会考虑所有人的观点(键和值)。具体来说:
给定输入序列 <math xmlns="http://www.w3.org/1998/Math/MathML"> X = [ x 1 , x 2 , . . . , x n ] \mathbf{X} = [x_1, x_2, ..., x_n] </math>X=[x1,x2,...,xn],自注意力通过三个步骤生成输出:
- 生成问题纸条 :每个词元创建查询向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i = W q x i \mathbf{q}_i = W_q x_i </math>qi=Wqxi
- 制作答案卡 :每个词元生成键向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> k j = W k x j \mathbf{k}_j = W_k x_j </math>kj=Wkxj 和值向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> v j = W v x j \mathbf{v}_j = W_v x_j </math>vj=Wvxj
- 收集答案:每个查询收集所有键值对的加权和:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y i = ∑ j = 1 n softmax ( q i ⊤ k j d ) v j y_i = \sum_{j=1}^n \text{softmax}\left(\frac{\mathbf{q}_i^\top \mathbf{k}_j}{\sqrt{d}}\right) \mathbf{v}_j </math>yi=j=1∑nsoftmax(d qi⊤kj)vj
其中:
- <math xmlns="http://www.w3.org/1998/Math/MathML"> q \mathbf{q} </math>q(Query)是查询矩阵,大小为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( n × d ) (n \times d) </math>(n×d),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 是查询的数量, <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 是特征维度。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> k \mathbf{k} </math>k(Key)是键矩阵,大小为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( m × d ) (m \times d) </math>(m×d),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> m m </math>m 是键的数量, <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 是特征维度。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> v \mathbf{v} </math>v(Value)是值矩阵,大小为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( m × d v ) (m \times d_v) </math>(m×dv)。
- <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 d \frac{1}{\sqrt{d}} </math>d 1 是一个缩放因子,用于防止大数值导致 softmax 过于极端,从而影响梯度的稳定性。
示例:考虑句子"猫吃鱼",自注意力会让"吃"同时关注"猫"和"鱼",就像我们在理解动词时会自动联系主语和宾语。
下面的代码片段是基于多头注意力对一个张量完成自注意力的计算,输入张量 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 的形状为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( 批量大小 , 序列长度 , 特征维度 ) (\text{批量大小}, \text{序列长度}, \text{特征维度}) </math>(批量大小,序列长度,特征维度),经过自注意力计算后,输出张量与输入张量形状保持一致。
python
import torch
import d2l
num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
num_hiddens, num_heads, 0.5)
attention.eval()
print(attention)
"""输出:
MultiHeadAttention(
(attention): DotProductAttention(
(dropout): Dropout(p=0.5, inplace=False)
)
(W_q): Linear(in_features=100, out_features=100, bias=False)
(W_k): Linear(in_features=100, out_features=100, bias=False)
(W_v): Linear(in_features=100, out_features=100, bias=False)
(W_o): Linear(in_features=100, out_features=100, bias=False)
)
"""
batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))
print(attention(X, X, X, valid_lens).shape)
# 输出:torch.Size([2, 4, 100])
2. 三大序列模型的巅峰对决
2.1 参赛选手介绍
模型类型 | 工作方式 | 可视化类比 |
---|---|---|
CNN | 滑动窗口扫描 | 望远镜观察局部区域 |
RNN | 顺序传递信息 | 接力赛传递消息 |
自注意力 | 全局直接交互 | 电话会议全员讨论 |
2.2 性能参数对比
使用 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 个词元,每个维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d,卷积核大小 <math xmlns="http://www.w3.org/1998/Math/MathML"> k k </math>k:
指标 | CNN | RNN | 自注意力 |
---|---|---|---|
计算复杂度 | <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( k n d 2 ) \mathcal{O}(knd^2) </math>O(knd2) | <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n d 2 ) \mathcal{O}(nd^2) </math>O(nd2) | <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n 2 d ) \mathcal{O}(n^2d) </math>O(n2d) |
并行能力 | 高 | 低 | 极高 |
最大路径长度 | <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n / k ) \mathcal{O}(n/k) </math>O(n/k) | <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n ) \mathcal{O}(n) </math>O(n) | <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 1 ) \mathcal{O}(1) </math>O(1) |
图1 比较卷积神经网络(填充词元被忽略)、循环神经网络和自注意力三种架构
示例:处理100个词的句子时,自注意力需要100×100=10,000次交互计算,而CNN(假设k=3)只需3×100=300次局部计算。
3. 位置编码:给词语发"座位号"
3.1 为什么需要位置信息?
自注意力虽然强大,但有个致命缺陷------所有词语同时处理,就像把句子里的词全部平铺在桌面上,模型无法知道它们的原始顺序。这时就需要位置编码来标记每个词的位置。
3.2 神奇的三角函数编码
使用正弦和余弦函数的组合生成位置编码矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> P P </math>P,其中第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 行对应位置,第 <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 j 2j </math>2j 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 j + 1 2j+1 </math>2j+1 列使用:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> P i , 2 j = sin ( i 1000 0 2 j / d ) P i , 2 j + 1 = cos ( i 1000 0 2 j / d ) \begin{aligned} P_{i,2j} &= \sin\left(\frac{i}{10000^{2j/d}}\right) \\ P_{i,2j+1} &= \cos\left(\frac{i}{10000^{2j/d}}\right) \end{aligned} </math>Pi,2jPi,2j+1=sin(100002j/di)=cos(100002j/di)
示例 :当 <math xmlns="http://www.w3.org/1998/Math/MathML"> d = 4 d=4 </math>d=4 时,位置1的编码可能是: [sin(1/10000^0), cos(1/10000^0), sin(1/10000^(2/4)), cos(1/10000^(2/4))]
让我们在下面的PositionalEncoding
类中实现它这种编码方式:
python
class PositionalEncoding(nn.Module):
"""位置编码"""
def __init__(self, num_hiddens, dropout, max_len=1000):
super(PositionalEncoding, self).__init__()
self.dropout = nn.Dropout(dropout)
# 创建一个足够长的P
self.P = torch.zeros((1, max_len, num_hiddens))
X = torch.arange(max_len, dtype=torch.float32).reshape(
-1, 1) / torch.pow(10000, torch.arange(
0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
self.P[:, :, 0::2] = torch.sin(X)
self.P[:, :, 1::2] = torch.cos(X)
def forward(self, X):
X = X + self.P[:, :X.shape[1], :].to(X.device)
return self.dropout(X)
在位置嵌入矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> P P </math>P 中,行代表词元在序列中的位置,列代表位置编码的不同维度。
python
encoding_dim, num_steps = 32, 60
pos_encoding = d2l.PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
figsize=(6.18, 3.82), legend=["Col %d" % d for d in torch.arange(6, 10)])
从下面的例子中可以看到位置嵌入矩阵的第6列和第7列的频率高于第8列和第9列。第6列和第7列之间的偏移量(第8列和第9列相同)是由于正弦函数和余弦函数的交替。
3.3 编码特性揭秘
绝对位置感知
不同列对应不同频率的波形,就像钢琴键盘上从左到右音调逐渐降低。高频(左侧列)帮助捕捉相邻词语的位置关系,低频(右侧列)负责编码词语在序列中的整体位置。
python
P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
ylabel='Row (position)', figsize=(3.82, 6.18), cmap='Blues')
相对位置推理
关键公式:位置 <math xmlns="http://www.w3.org/1998/Math/MathML"> i + k i+k </math>i+k 的编码可以表示为位置 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 编码的线性变换:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> sin ( ω j ( i + k ) ) = sin ( ω j i ) cos ( ω j k ) + cos ( ω j i ) sin ( ω j k ) cos ( ω j ( i + k ) ) = cos ( ω j i ) cos ( ω j k ) − sin ( ω j i ) sin ( ω j k ) \begin{aligned} \sin(\omega_j(i+k)) &= \sin(\omega_j i)\cos(\omega_j k) + \cos(\omega_j i)\sin(\omega_j k) \\ \cos(\omega_j(i+k)) &= \cos(\omega_j i)\cos(\omega_j k) - \sin(\omega_j i)\sin(\omega_j k) \end{aligned} </math>sin(ωj(i+k))cos(ωj(i+k))=sin(ωji)cos(ωjk)+cos(ωji)sin(ωjk)=cos(ωji)cos(ωjk)−sin(ωji)sin(ωjk)
这就像通过三角函数公式,模型可以推导出词语之间的相对距离。
4. 关键知识点总结
- 自注意力的本质:让每个词元都能与序列中所有词元直接交互
- 三大模型对比 :
- CNN:局部感知,适合处理图像
- RNN:顺序处理,适合流式数据
- 自注意力:全局交互,适合长程依赖
- 位置编码的妙用 :
- 绝对位置:通过不同频率的正余弦函数编码
- 相对位置:利用三角恒等式实现位置偏移的线性表示
通过这个魔法般的组合,现代Transformer模型才能在机器翻译、文本生成等任务中展现出惊人的性能。理解这些基础原理,就是打开深度学习宝库的第一把钥匙!