深入浅出 Transformers:自注意力和多头注意力的那些事儿

Transformers 的自注意力和多头注意力机制是它的灵魂所在。如果你对这些概念感到困惑,别担心!今天我们不仅用通俗易懂的语言解释,还会深入到算法细节,带你彻底搞懂它们的原理。准备好了吗?让我们开始吧!


自注意力:模型的"千里眼"

1. 自注意力的核心问题

在自然语言处理中,句子中的每个词都可能与其他词有关系。比如:

"小明喜欢吃苹果,因为苹果很甜。"

当模型处理"苹果"时,它需要知道:

  • "苹果"是"小明喜欢的"。
  • "苹果"是"很甜的"。

自注意力的目标就是解决这个问题:让每个词根据上下文,动态地决定该关注哪些词,以及关注的程度


2. 自注意力的算法步骤

自注意力的计算可以分为以下几个步骤:

Step 1: 输入表示

假设我们有一个句子,包含 ( n ) 个词,每个词用一个向量表示。输入可以表示为一个矩阵 ( X ): <math xmlns="http://www.w3.org/1998/Math/MathML"> [ X = [ x 1 , x 2 , ... , x n ] ] [ X = [x_1, x_2, \dots, x_n] ] </math>[X=[x1,x2,...,xn]] 其中,( x_i ) 是每个词的向量表示(比如 512 维)。

Step 2: 生成 Query、Key 和 Value

每个词会生成三个向量:Query(查询向量)Key(键向量)Value(值向量) 。这些向量是通过三个不同的线性变换得到的: <math xmlns="http://www.w3.org/1998/Math/MathML"> [ Q = X W Q , K = X W K , V = X W V ] [ Q = XW_Q, \quad K = XW_K, \quad V = XW_V ] </math>[Q=XWQ,K=XWK,V=XWV]

  • ( W_Q, W_K, W_V ) 是可学习的权重矩阵。
  • ( Q, K, V ) 分别是 Query、Key 和 Value 的矩阵表示。
Step 3: 计算注意力分数

每个词的 Query 会与所有词的 Key 进行点积,计算它们的相似度(即注意力分数): <math xmlns="http://www.w3.org/1998/Math/MathML"> [ Attention Score i j = Q i ⋅ K j T ] [ \text{Attention Score}_{ij} = Q_i \cdot K_j^T ] </math>[Attention Scoreij=Qi⋅KjT]

  • ( Q_i ) 是第 ( i ) 个词的 Query 向量。
  • ( K_j ) 是第 ( j ) 个词的 Key 向量。

为了让分数更稳定,我们会对点积结果进行缩放(防止数值过大): <math xmlns="http://www.w3.org/1998/Math/MathML"> [ Scaled Score i j = Q i ⋅ K j T d k ] [ \text{Scaled Score}_{ij} = \frac{Q_i \cdot K_j^T}{\sqrt{d_k}} ] </math>[Scaled Scoreij=dk Qi⋅KjT]其中,( d_k ) 是 Key 向量的维度。

Step 4: 归一化分数(Softmax)

将每个词的注意力分数通过 Softmax 转换为概率分布: <math xmlns="http://www.w3.org/1998/Math/MathML"> [ Attention Weights ∗ i j = Softmax ( Scaled Score ∗ i j ) ] [ \text{Attention Weights} *{ij} = \text{Softmax}(\text{Scaled Score}* {ij}) ] </math>[Attention Weights∗ij=Softmax(Scaled Score∗ij)] 这样可以确保所有分数加起来等于 1,表示每个词对当前词的重要程度。

Step 5: 加权求和

用注意力权重对 Value 向量进行加权求和,得到每个词的最终表示: <math xmlns="http://www.w3.org/1998/Math/MathML"> [ Output ∗ i = ∑ ∗ j = 1 n Attention Weights i j ⋅ V j ] [ \text{Output}*i = \sum*{j=1}^n \text{Attention Weights}_{ij} \cdot V_j ] </math>[Output∗i=∑∗j=1nAttention Weightsij⋅Vj]

  • ( <math xmlns="http://www.w3.org/1998/Math/MathML"> V j V_j </math>Vj ) 是第 ( j ) 个词的 Value 向量。
  • ( <math xmlns="http://www.w3.org/1998/Math/MathML"> Output i \text{Output}_i </math>Outputi) 是第 ( i ) 个词的最终表示,包含了它对其他词的注意力信息。

3. 自注意力的直观理解

自注意力的过程可以类比为一个"开会讨论"的场景:

  • 每个人(词)都有自己的问题(Query)。
  • 每个人也有自己的身份信息(Key)和观点(Value)。
  • 每个人会根据别人的身份信息(Key)来决定是否听取他们的观点(Value)。
  • 最后,每个人综合所有人的观点,形成自己的最终决策。

多头注意力:模型的"多线程处理器"

1. 为什么需要多头注意力?

自注意力虽然强大,但它有一个局限:一次只能关注一种关系。比如:

  • "苹果"和"甜"是一个关系。
  • "小明"和"喜欢"是另一个关系。

如果只有一个注意力机制,模型可能会漏掉一些重要信息。于是,多头注意力应运而生。


2. 多头注意力的算法步骤

多头注意力的核心思想是:让模型同时关注多个关系,每个头独立计算注意力,然后将结果合并

Step 1: 分头计算

将输入向量 ( X ) 分成 ( h ) 个头,每个头独立计算 Query、Key 和 Value: <math xmlns="http://www.w3.org/1998/Math/MathML"> [ Q h = X W Q h , K h = X W K h , V h = X W V h ] [ Q_h = XW_Q^h, \quad K_h = XW_K^h, \quad V_h = XW_V^h ] </math>[Qh=XWQh,Kh=XWKh,Vh=XWVh]

  • ( <math xmlns="http://www.w3.org/1998/Math/MathML"> W Q h , W K h , W V h W_Q^h, W_K^h, W_V^h </math>WQh,WKh,WVh) 是每个头的权重矩阵。
  • 每个头的维度通常比原始维度小(比如原始维度是 512,每个头的维度是 64)。
Step 2: 独立计算注意力

每个头独立执行自注意力的计算,得到每个头的输出: <math xmlns="http://www.w3.org/1998/Math/MathML"> [ Output h = Attention ( Q h , K h , V h ) ] [ \text{Output}_h = \text{Attention}(Q_h, K_h, V_h) ] </math>[Outputh=Attention(Qh,Kh,Vh)]

Step 3: 合并结果

将所有头的输出拼接起来,形成一个新的矩阵: <math xmlns="http://www.w3.org/1998/Math/MathML"> [ Concat Output = [ Output 1 , Output 2 , ... , Output h ] ] [ \text{Concat Output} = [\text{Output}_1, \text{Output}_2, \dots, \text{Output}_h] ] </math>[Concat Output=[Output1,Output2,...,Outputh]]

Step 4: 线性变换

对拼接后的结果进行线性变换,得到最终的多头注意力输出: <math xmlns="http://www.w3.org/1998/Math/MathML"> [ Final Output = Concat Output ⋅ W O ] [ \text{Final Output} = \text{Concat Output} \cdot W_O ] </math>[Final Output=Concat Output⋅WO]

  • ( <math xmlns="http://www.w3.org/1998/Math/MathML"> W O W_O </math>WO ) 是一个可学习的权重矩阵。

3. 多头注意力的直观理解

多头注意力就像一个侦探团队,每个侦探负责调查不同的线索:

  • 侦探 A 关注"苹果"和"甜"的关系。
  • 侦探 B 关注"小明"和"喜欢"的关系。
  • 侦探 C 关注"因为"和"苹果"的关系。

最后,所有侦探的调查结果汇总成一份完整的报告,帮助模型更全面地理解句子。


为什么需要位置编码?

Transformers 的一个特点是,它不依赖传统的 RNN 或 CNN 结构,而是基于自注意力机制处理输入数据。自注意力的好处是可以并行计算,但它有一个问题:它对词语的顺序没有天然的感知能力

  • "小明喜欢吃苹果" 和 "苹果喜欢吃小明" 对人类来说完全不同,但对 Transformers 来说,它们的输入向量可能是一样的。

为了让模型知道"谁在前,谁在后",我们需要给每个词加上位置信息,这就是位置编码的作用。


位置编码的原理

位置编码的目标是为每个词引入一个独特的"位置信息",并将其与词向量结合,让模型能够感知词语的顺序。

1. 数学公式

位置编码通常使用正弦和余弦函数生成,公式如下:

<math xmlns="http://www.w3.org/1998/Math/MathML"> P E ( p o s , 2 i ) = sin ⁡ ( p o s 1000 0 2 i d ) PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{\frac{2i}{d}}}\right) </math>PE(pos,2i)=sin(10000d2ipos)

<math xmlns="http://www.w3.org/1998/Math/MathML"> P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 1000 0 2 i d ) PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{\frac{2i}{d}}}\right) </math>PE(pos,2i+1)=cos(10000d2ipos)

  • ( pos ):词语的位置(如第 1 个词、第 2 个词)。
  • ( i ):向量的维度索引。
  • ( d ):词向量的总维度。

2. 为什么用正弦和余弦?

  • 周期性:正弦和余弦函数是周期性的,能够捕捉词语之间的相对位置关系。
  • 平滑变化:随着位置 ( pos ) 的增加,编码值会平滑变化,便于模型学习。
  • 不同维度的变化速率 :通过 ( <math xmlns="http://www.w3.org/1998/Math/MathML"> 1000 0 2 i d 10000^{\frac{2i}{d}} </math>10000d2i ) 的缩放,不同维度的变化速率不同,确保每个位置的编码是唯一的。

总结:自注意力和多头注意力的优雅结合

  • 自注意力:让每个词都能动态地关注句子中的其他词,捕捉上下文关系。
  • 多头注意力:让模型从多个角度分析句子,信息更全面,表达能力更强。
  • 位置编码(Positional Encoding) :帮助模型理解词语的顺序关系。

Transformers 的强大之处就在于它能灵活地捕捉句子中的各种关系,而自注意力和多头注意力正是实现这一目标的关键。

相关推荐
IT古董5 分钟前
【漫话机器学习系列】208.标准差(Standard Deviation)
人工智能
CH3_CH2_CHO12 分钟前
DAY06:【pytorch】图像增强
人工智能·pytorch·计算机视觉
意.远14 分钟前
PyTorch实现权重衰退:从零实现与简洁实现
人工智能·pytorch·python·深度学习·神经网络·机器学习
兮兮能吃能睡18 分钟前
我的机器学习之路(初稿)
人工智能·机器学习
Java中文社群21 分钟前
SpringAI版本更新:向量数据库不可用的解决方案!
java·人工智能·后端
Shawn_Shawn27 分钟前
AI换装-OOTDiffusion使用教程
人工智能·llm
扉间79828 分钟前
探索图像分类模型的 Flask 应用搭建之旅
人工智能·分类·flask
鲜枣课堂40 分钟前
发力“5G-A x AI融智创新”,中国移动推出重要行动计划!打造“杭州Mobile AI第一城”!
人工智能·5g
爱的叹息1 小时前
AI应用开发平台 和 通用自动化工作流工具 的详细对比,涵盖定义、核心功能、典型工具、适用场景及优缺点分析
运维·人工智能·自动化
Dm_dotnet1 小时前
使用CAMEL创建第一个Agent Society
人工智能