自注意力与位置编码:让模型理解序列的魔法

1. 什么是自注意力?

想象一下,你正在阅读一本小说,每看到一个词语时,大脑会自动关注前文中与之相关的信息。这种"聚焦重点"的能力,正是自注意力机制的核心思想。

自注意力(Self-Attention)是一种让序列中的每个元素都能关注整个序列的机制。就像班级讨论时,每个同学发言(查询)都会考虑所有人的观点(键和值)。具体来说:

给定输入序列 <math xmlns="http://www.w3.org/1998/Math/MathML"> X = [ x 1 , x 2 , . . . , x n ] \mathbf{X} = [x_1, x_2, ..., x_n] </math>X=[x1,x2,...,xn],自注意力通过三个步骤生成输出:

  1. 生成问题纸条 :每个词元创建查询向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> q i = W q x i \mathbf{q}_i = W_q x_i </math>qi=Wqxi
  2. 制作答案卡 :每个词元生成键向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> k j = W k x j \mathbf{k}_j = W_k x_j </math>kj=Wkxj 和值向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> v j = W v x j \mathbf{v}_j = W_v x_j </math>vj=Wvxj
  3. 收集答案:每个查询收集所有键值对的加权和:

<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> y i = ∑ j = 1 n softmax ( q i ⊤ k j d ) v j y_i = \sum_{j=1}^n \text{softmax}\left(\frac{\mathbf{q}_i^\top \mathbf{k}_j}{\sqrt{d}}\right) \mathbf{v}_j </math>yi=j=1∑nsoftmax(d qi⊤kj)vj

其中:

  • <math xmlns="http://www.w3.org/1998/Math/MathML"> q \mathbf{q} </math>q(Query)是查询矩阵,大小为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( n × d ) (n \times d) </math>(n×d),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 是查询的数量, <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 是特征维度。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> k \mathbf{k} </math>k(Key)是键矩阵,大小为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( m × d ) (m \times d) </math>(m×d),其中 <math xmlns="http://www.w3.org/1998/Math/MathML"> m m </math>m 是键的数量, <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d 是特征维度。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> v \mathbf{v} </math>v(Value)是值矩阵,大小为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( m × d v ) (m \times d_v) </math>(m×dv)。
  • <math xmlns="http://www.w3.org/1998/Math/MathML"> 1 d \frac{1}{\sqrt{d}} </math>d 1 是一个缩放因子,用于防止大数值导致 softmax 过于极端,从而影响梯度的稳定性。

示例:考虑句子"猫吃鱼",自注意力会让"吃"同时关注"猫"和"鱼",就像我们在理解动词时会自动联系主语和宾语。

下面的代码片段是基于多头注意力对一个张量完成自注意力的计算,输入张量 <math xmlns="http://www.w3.org/1998/Math/MathML"> X X </math>X 的形状为 <math xmlns="http://www.w3.org/1998/Math/MathML"> ( 批量大小 , 序列长度 , 特征维度 ) (\text{批量大小}, \text{序列长度}, \text{特征维度}) </math>(批量大小,序列长度,特征维度),经过自注意力计算后,输出张量与输入张量形状保持一致。

python 复制代码
import torch
import d2l

num_hiddens, num_heads = 100, 5
attention = d2l.MultiHeadAttention(num_hiddens, num_hiddens, num_hiddens,
                                   num_hiddens, num_heads, 0.5)
attention.eval()

print(attention)
"""输出:

MultiHeadAttention(
  (attention): DotProductAttention(
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (W_q): Linear(in_features=100, out_features=100, bias=False)
  (W_k): Linear(in_features=100, out_features=100, bias=False)
  (W_v): Linear(in_features=100, out_features=100, bias=False)
  (W_o): Linear(in_features=100, out_features=100, bias=False)
)
"""

batch_size, num_queries, valid_lens = 2, 4, torch.tensor([3, 2])
X = torch.ones((batch_size, num_queries, num_hiddens))

print(attention(X, X, X, valid_lens).shape)
# 输出:torch.Size([2, 4, 100])

2. 三大序列模型的巅峰对决

2.1 参赛选手介绍

模型类型 工作方式 可视化类比
CNN 滑动窗口扫描 望远镜观察局部区域
RNN 顺序传递信息 接力赛传递消息
自注意力 全局直接交互 电话会议全员讨论

2.2 性能参数对比

使用 <math xmlns="http://www.w3.org/1998/Math/MathML"> n n </math>n 个词元,每个维度 <math xmlns="http://www.w3.org/1998/Math/MathML"> d d </math>d,卷积核大小 <math xmlns="http://www.w3.org/1998/Math/MathML"> k k </math>k:

指标 CNN RNN 自注意力
计算复杂度 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( k n d 2 ) \mathcal{O}(knd^2) </math>O(knd2) <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n d 2 ) \mathcal{O}(nd^2) </math>O(nd2) <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n 2 d ) \mathcal{O}(n^2d) </math>O(n2d)
并行能力 极高
最大路径长度 <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n / k ) \mathcal{O}(n/k) </math>O(n/k) <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( n ) \mathcal{O}(n) </math>O(n) <math xmlns="http://www.w3.org/1998/Math/MathML"> O ( 1 ) \mathcal{O}(1) </math>O(1)


图1 比较卷积神经网络(填充词元被忽略)、循环神经网络和自注意力三种架构

示例:处理100个词的句子时,自注意力需要100×100=10,000次交互计算,而CNN(假设k=3)只需3×100=300次局部计算。


3. 位置编码:给词语发"座位号"

3.1 为什么需要位置信息?

自注意力虽然强大,但有个致命缺陷------所有词语同时处理,就像把句子里的词全部平铺在桌面上,模型无法知道它们的原始顺序。这时就需要位置编码来标记每个词的位置。

3.2 神奇的三角函数编码

使用正弦和余弦函数的组合生成位置编码矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> P P </math>P,其中第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 行对应位置,第 <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 j 2j </math>2j 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> 2 j + 1 2j+1 </math>2j+1 列使用:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> P i , 2 j = sin ⁡ ( i 1000 0 2 j / d ) P i , 2 j + 1 = cos ⁡ ( i 1000 0 2 j / d ) \begin{aligned} P_{i,2j} &= \sin\left(\frac{i}{10000^{2j/d}}\right) \\ P_{i,2j+1} &= \cos\left(\frac{i}{10000^{2j/d}}\right) \end{aligned} </math>Pi,2jPi,2j+1=sin(100002j/di)=cos(100002j/di)

示例 :当 <math xmlns="http://www.w3.org/1998/Math/MathML"> d = 4 d=4 </math>d=4 时,位置1的编码可能是: [sin(1/10000^0), cos(1/10000^0), sin(1/10000^(2/4)), cos(1/10000^(2/4))]

让我们在下面的PositionalEncoding类中实现它这种编码方式:

python 复制代码
class PositionalEncoding(nn.Module):
    """位置编码"""

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # 创建一个足够长的P
        self.P = torch.zeros((1, max_len, num_hiddens))
        X = torch.arange(max_len, dtype=torch.float32).reshape(
            -1, 1) / torch.pow(10000, torch.arange(
            0, num_hiddens, 2, dtype=torch.float32) / num_hiddens)
        self.P[:, :, 0::2] = torch.sin(X)
        self.P[:, :, 1::2] = torch.cos(X)

    def forward(self, X):
        X = X + self.P[:, :X.shape[1], :].to(X.device)
        return self.dropout(X)

在位置嵌入矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> P P </math>P 中,行代表词元在序列中的位置,列代表位置编码的不同维度。

python 复制代码
encoding_dim, num_steps = 32, 60
pos_encoding = d2l.PositionalEncoding(encoding_dim, 0)
pos_encoding.eval()
X = pos_encoding(torch.zeros((1, num_steps, encoding_dim)))
P = pos_encoding.P[:, :X.shape[1], :]
d2l.plot(torch.arange(num_steps), P[0, :, 6:10].T, xlabel='Row (position)',
         figsize=(6.18, 3.82), legend=["Col %d" % d for d in torch.arange(6, 10)])

从下面的例子中可以看到位置嵌入矩阵的第6列和第7列的频率高于第8列和第9列。第6列和第7列之间的偏移量(第8列和第9列相同)是由于正弦函数和余弦函数的交替。

3.3 编码特性揭秘

绝对位置感知

不同列对应不同频率的波形,就像钢琴键盘上从左到右音调逐渐降低。高频(左侧列)帮助捕捉相邻词语的位置关系,低频(右侧列)负责编码词语在序列中的整体位置。

python 复制代码
P = P[0, :, :].unsqueeze(0).unsqueeze(0)
d2l.show_heatmaps(P, xlabel='Column (encoding dimension)',
                  ylabel='Row (position)', figsize=(3.82, 6.18), cmap='Blues')

相对位置推理

关键公式:位置 <math xmlns="http://www.w3.org/1998/Math/MathML"> i + k i+k </math>i+k 的编码可以表示为位置 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i 编码的线性变换:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> sin ⁡ ( ω j ( i + k ) ) = sin ⁡ ( ω j i ) cos ⁡ ( ω j k ) + cos ⁡ ( ω j i ) sin ⁡ ( ω j k ) cos ⁡ ( ω j ( i + k ) ) = cos ⁡ ( ω j i ) cos ⁡ ( ω j k ) − sin ⁡ ( ω j i ) sin ⁡ ( ω j k ) \begin{aligned} \sin(\omega_j(i+k)) &= \sin(\omega_j i)\cos(\omega_j k) + \cos(\omega_j i)\sin(\omega_j k) \\ \cos(\omega_j(i+k)) &= \cos(\omega_j i)\cos(\omega_j k) - \sin(\omega_j i)\sin(\omega_j k) \end{aligned} </math>sin(ωj(i+k))cos(ωj(i+k))=sin(ωji)cos(ωjk)+cos(ωji)sin(ωjk)=cos(ωji)cos(ωjk)−sin(ωji)sin(ωjk)

这就像通过三角函数公式,模型可以推导出词语之间的相对距离。


4. 关键知识点总结

  1. 自注意力的本质:让每个词元都能与序列中所有词元直接交互
  2. 三大模型对比
    • CNN:局部感知,适合处理图像
    • RNN:顺序处理,适合流式数据
    • 自注意力:全局交互,适合长程依赖
  3. 位置编码的妙用
    • 绝对位置:通过不同频率的正余弦函数编码
    • 相对位置:利用三角恒等式实现位置偏移的线性表示

通过这个魔法般的组合,现代Transformer模型才能在机器翻译、文本生成等任务中展现出惊人的性能。理解这些基础原理,就是打开深度学习宝库的第一把钥匙!

相关推荐
海森大数据44 分钟前
AI矿工掘金材料新大陆:DeepMind如何用神经网络改写元素周期表?
人工智能·深度学习·神经网络
liruiqiang051 小时前
卷积神经网络 - LeNet-5
人工智能·深度学习·神经网络·机器学习·cnn
点我头像干啥1 小时前
计算机视觉的多模态模型:开启感知智能的新篇章
人工智能·深度学习·计算机视觉
gorgor在码农2 小时前
神经网络基础(NN)
人工智能·pytorch·python·深度学习·神经网络·机器学习
盼小辉丶3 小时前
TensorFlow深度学习实战——利用词嵌入实现垃圾邮件检测
人工智能·深度学习·tensorflow
Hi__3 小时前
深度学习基础-----神经⽹络与深度学习((美)MichaelNielsen )
人工智能·深度学习
视觉语言导航4 小时前
ICASSP-2025 | 国防科大具身导航高效记忆与推理!GAR:基于图感知推理与双向选择的视觉语言导航
人工智能·深度学习·具身智能
James. 常德 student5 小时前
深度学习之自动求导
人工智能·深度学习
船长@Quant5 小时前
PyTorch量化技术教程:第三章 PyTorch模型构建与训练
pytorch·python·深度学习·机器学习·量化交易·ta-lib
神经星星5 小时前
新加坡国立大学张阳团队开发第二代RNA结构预测算法,多项基准测试超越SOTA
人工智能·深度学习·机器学习