目录
1. 概述
1.1 什么是注意力机制
注意力机制(Attention Mechanism)是一种让模型能够动态聚焦于输入序列中最相关部分的技术。它的核心思想是:在处理一个序列时,模型不应该平等对待所有位置的信息,而应该根据当前任务的需要,选择性地关注最相关的信息。
这种机制的灵感来源于人类的视觉和认知系统。当我们阅读一段文字时,我们的眼睛会自然地聚焦于关键词语,而不是平均地关注每一个字符。类似地,当我们翻译一个句子时,我们会在翻译每个词时关注源句子中对应的词。
1.2 注意力机制的重要性
注意力机制的出现彻底改变了深度学习领域,尤其是在自然语言处理和计算机视觉领域。它的核心贡献包括:
解决了长距离依赖问题:传统的循环神经网络(RNN)在处理长序列时,由于梯度消失问题,难以捕捉远距离的依赖关系。注意力机制允许模型直接建立任意两个位置之间的连接,无论它们之间的距离有多远。
实现了并行计算:RNN必须按顺序处理序列,无法并行。注意力机制可以同时处理所有位置,充分利用GPU的并行计算能力,大大提高了训练和推理效率。
提供了可解释性:通过可视化注意力权重,我们可以理解模型在做决策时关注了哪些信息,这为模型的可解释性提供了重要工具。
统一了多种模态:注意力机制可以自然地处理不同模态的数据(文本、图像、音频),成为多模态学习的核心技术。
1.3 注意力机制的核心组成
一个完整的注意力机制通常包含三个核心元素:
查询(Query):代表当前需要处理的位置的信息需求。在翻译任务中,查询可能是当前要生成的目标词的表示。
键(Key):代表每个位置的特征描述,用于与查询进行匹配。键可以理解为"这个位置提供了什么信息"的索引。
值(Value):代表每个位置的实际内容。当注意力权重计算完成后,值会根据权重进行加权求和,得到最终的输出。
用一个直观的比喻:想象你在图书馆找书。查询是你想找的书的特征(比如"机器学习"),键是每本书的标签(书脊上的标题),值是书的实际内容。你通过比较查询和键的相似度来决定关注哪些书,然后读取这些书的内容。
2. 人类注意力的启发
2.1 视觉注意力
人类的视觉系统具有选择性注意力的能力。当我们看一张图片时,我们不会同时关注图片的所有区域,而是会聚焦于最显著的部分。这种选择性注意力可以分为两种类型:
自下而上的注意力:由外部刺激驱动,自动被显著的物体吸引。例如,在一片绿色的草地上,一朵红色的花会自然吸引我们的注意力。
自上而下的注意力:由内部目标驱动,主动寻找特定的信息。例如,当我们在人群中寻找一个穿红衣服的朋友时,我们会主动关注红色的物体。
2.2 认知注意力
在认知科学中,注意力被理解为一种有限的认知资源。人类的大脑无法同时处理所有的信息,因此需要选择性地分配注意力资源。这种选择性体现在:
选择性:从众多信息中选择最重要的进行处理。
持续性:在一段时间内保持对特定信息的关注。
分配性:在多个任务之间分配注意力资源。
2.3 从人类到机器
深度学习中的注意力机制正是借鉴了人类注意力的这些特性。模型通过学习一个注意力函数,自动决定在处理每个位置时应该关注输入的哪些部分。这种机制使得模型能够:
动态调整关注点:根据输入内容和当前任务,模型可以灵活地调整其关注的区域。
处理变长输入:注意力机制可以处理任意长度的输入序列,不受固定窗口大小的限制。
捕捉层次化信息:通过多层注意力,模型可以捕捉从低级特征到高级语义的多层次信息。
3. 注意力机制的发展历史
3.1 早期探索(2014年之前)
注意力机制的思想可以追溯到早期的神经网络研究。在20世纪90年代,就有研究者探索让神经网络学习关注输入的不同部分。但这些早期工作受限于计算能力和模型架构,没有得到广泛应用。
3.2 Bahdanau注意力(2014)
2014年,Bahdanau等人在机器翻译任务中首次提出了现代意义上的注意力机制。在此之前,机器翻译主要使用编码器-解码器架构,编码器将整个源句子压缩为一个固定长度的向量,解码器基于这个向量生成翻译。
这种方法的问题是:一个固定长度的向量无法承载长句子的全部信息,特别是对于长句子,翻译质量会显著下降。
Bahdanau的解决方案是:在解码器生成每个词时,动态地关注源句子的不同部分。具体来说,在生成第t个词时,解码器会计算源句子每个位置的注意力权重,然后根据这些权重对编码器的隐藏状态进行加权求和,得到上下文向量。
3.3 Luong注意力(2015)
Luong等人在2015年提出了几种不同的注意力计算方式,包括点积注意力、一般注意力和拼接注意力。这些变体在不同的任务上有不同的表现,但核心思想与Bahdanau注意力相同。
3.4 自注意力机制(2016)
2016年,研究者开始探索自注意力(Self-Attention)机制。与传统的注意力机制不同,自注意力的查询、键和值都来自同一个序列。这使得模型能够捕捉序列内部的依赖关系,而不仅仅是源序列和目标序列之间的对应关系。
3.5 Transformer(2017)
2017年,Vaswani等人在论文"Attention Is All You Need"中提出了Transformer架构,完全基于注意力机制,抛弃了RNN和CNN。Transformer引入了多头注意力、位置编码等关键创新,成为现代NLP的基础架构。
3.6 视觉Transformer(2020)
2020年,Dosovitskiy等人提出了Vision Transformer(ViT),将Transformer应用于图像分类任务。这标志着注意力机制从NLP扩展到计算机视觉,开启了视觉Transformer的时代。
4. 基础注意力机制
4.1 加性注意力
加性注意力(Additive Attention)是Bahdanau等人提出的方法,也称为Bahdanau注意力。它通过一个前馈神经网络来计算注意力分数:
e i j = v T tanh ( W 1 s i − 1 + W 2 h j ) e_{ij} = v^T \tanh(W_1 s_{i-1} + W_2 h_j) eij=vTtanh(W1si−1+W2hj)
其中 s i − 1 s_{i-1} si−1 是解码器在时刻i-1的隐藏状态, h j h_j hj 是编码器在位置j的隐藏状态, W 1 W_1 W1、 W 2 W_2 W2 和 v v v 是可学习的参数。
注意力权重通过Softmax归一化:
α i j = exp ( e i j ) ∑ k = 1 T x exp ( e i k ) \alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k=1}^{T_x} \exp(e_{ik})} αij=∑k=1Txexp(eik)exp(eij)
上下文向量是编码器隐藏状态的加权和:
c i = ∑ j = 1 T x α i j h j c_i = \sum_{j=1}^{T_x} \alpha_{ij} h_j ci=j=1∑Txαijhj
实现:
python
class AdditiveAttention(nn.Module):
def __init__(self, encoder_dim, decoder_dim, attention_dim):
super().__init__()
self.W1 = nn.Linear(encoder_dim, attention_dim, bias=False)
self.W2 = nn.Linear(decoder_dim, attention_dim, bias=False)
self.v = nn.Linear(attention_dim, 1, bias=False)
def forward(self, encoder_outputs, decoder_hidden):
"""
Args:
encoder_outputs: [batch, src_len, encoder_dim]
decoder_hidden: [batch, decoder_dim]
"""
# 扩展decoder_hidden维度
decoder_hidden = decoder_hidden.unsqueeze(1).repeat(
1, encoder_outputs.size(1), 1
) # [batch, src_len, decoder_dim]
# 计算注意力分数
energy = torch.tanh(
self.W1(encoder_outputs) + self.W2(decoder_hidden)
) # [batch, src_len, attention_dim]
attention = self.v(energy).squeeze(2) # [batch, src_len]
# Softmax归一化
weights = F.softmax(attention, dim=1) # [batch, src_len]
# 加权求和
context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)
return context, weights
4.2 乘性注意力
乘性注意力(Multiplicative Attention)也称为Luong注意力或点积注意力。它通过计算查询和键的点积来得到注意力分数:
e i j = s i T W h j e_{ij} = s_i^T W h_j eij=siTWhj
其中 W W W 是可学习的权重矩阵。当 W W W 是单位矩阵时,退化为点积注意力:
e i j = s i T h j e_{ij} = s_i^T h_j eij=siThj
实现:
python
class MultiplicativeAttention(nn.Module):
def __init__(self, encoder_dim, decoder_dim):
super().__init__()
self.W = nn.Linear(encoder_dim, decoder_dim, bias=False)
def forward(self, encoder_outputs, decoder_hidden):
# 变换encoder_outputs
transformed = self.W(encoder_outputs) # [batch, src_len, decoder_dim]
# 计算注意力分数
scores = torch.bmm(transformed, decoder_hidden.unsqueeze(2)).squeeze(2)
# Softmax归一化
weights = F.softmax(scores, dim=1)
# 加权求和
context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)
return context, weights
4.3 缩放点积注意力
缩放点积注意力(Scaled Dot-Product Attention)是Transformer使用的核心注意力机制。它在点积注意力的基础上引入了缩放因子:
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
为什么要缩放?当 d k d_k dk 较大时, Q K T QK^T QKT 的值会很大,导致Softmax函数进入饱和区,梯度接近于0。除以 d k \sqrt{d_k} dk 可以将方差控制在合理范围内。
数学推导:
假设Q和K的每个元素都是独立同分布的,均值为0,方差为1。那么 Q ⋅ K = ∑ i = 1 d k q i k i Q \cdot K = \sum_{i=1}^{d_k} q_i k_i Q⋅K=∑i=1dkqiki,其方差为 d k d_k dk。除以 d k \sqrt{d_k} dk 后,方差变为1。
5. 自注意力机制
5.1 核心思想
自注意力(Self-Attention)是Transformer的核心创新。与传统的注意力机制不同,自注意力的查询、键和值都来自同一个序列。这意味着序列中的每个位置都能"看到"并"关注"序列中的所有其他位置。
自注意力的核心优势在于它能够直接建立任意两个位置之间的依赖关系,无论它们之间的距离有多远。这解决了RNN的长距离依赖问题。
5.2 数学公式
给定输入序列 X ∈ R n × d X \in \mathbb{R}^{n \times d} X∈Rn×d,自注意力的计算过程如下:
首先,通过三个线性变换得到查询、键和值:
Q = X W Q , K = X W K , V = X W V Q = XW_Q, \quad K = XW_K, \quad V = XW_V Q=XWQ,K=XWK,V=XWV
其中 W Q , W K ∈ R d × d k W_Q, W_K \in \mathbb{R}^{d \times d_k} WQ,WK∈Rd×dk, W V ∈ R d × d v W_V \in \mathbb{R}^{d \times d_v} WV∈Rd×dv。
然后,计算注意力权重:
A = softmax ( Q K T d k ) A = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) A=softmax(dk QKT)
最后,计算输出:
output = A V \text{output} = AV output=AV
5.3 直观理解
自注意力可以理解为一种"信息聚合"机制。对于序列中的每个位置,自注意力会计算该位置与所有其他位置的相关性,然后根据相关性对所有位置的信息进行加权求和。
例如,在句子"The animal didn't cross the street because it was too tired"中,当处理"it"这个词时,自注意力会计算"it"与其他所有词的相关性。由于"it"指代的是"animal",所以"animal"的注意力权重会最高,模型会更多地关注"animal"的信息。
5.4 与RNN的对比
| 特性 | RNN | 自注意力 |
|---|---|---|
| 计算路径长度 | O(n) | O(1) |
| 并行性 | 顺序 | 完全并行 |
| 长距离依赖 | 困难 | 直接建模 |
| 计算复杂度 | O(n·d²) | O(n²·d) |
RNN需要经过n步才能建立位置1和位置n之间的连接,而自注意力只需要一步。这使得自注意力在处理长序列时更加有效。
5.5 实现
python
class SelfAttention(nn.Module):
def __init__(self, d_model, d_k=None, d_v=None):
super().__init__()
self.d_k = d_k or d_model
self.d_v = d_v or d_model
self.W_Q = nn.Linear(d_model, self.d_k)
self.W_K = nn.Linear(d_model, self.d_k)
self.W_V = nn.Linear(d_model, self.d_v)
def forward(self, x, mask=None):
"""
Args:
x: [batch_size, seq_len, d_model]
mask: [batch_size, seq_len, seq_len] or None
"""
Q = self.W_Q(x) # [batch, seq_len, d_k]
K = self.W_K(x) # [batch, seq_len, d_k]
V = self.W_V(x) # [batch, seq_len, d_v]
# 计算注意力分数
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
# 应用掩码
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
# Softmax归一化
weights = F.softmax(scores, dim=-1)
# 加权求和
output = torch.matmul(weights, V)
return output, weights
6. 多头注意力机制
6.1 核心思想
多头注意力(Multi-Head Attention)是Transformer的关键创新之一。它的核心思想是:与其使用单一的注意力函数,不如使用多个独立的注意力函数,每个注意力函数关注不同方面的信息。
这类似于人类的注意力方式:当我们阅读一段文字时,我们可能同时关注语法关系、语义关系、上下文关系等不同方面的信息。多头注意力让模型能够同时学习这些不同类型的注意力模式。
6.2 数学公式
多头注意力的计算过程如下:
MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO
其中每个注意力头:
head i = Attention ( Q W i Q , K W i K , V W i V ) \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) headi=Attention(QWiQ,KWiK,VWiV)
参数维度:
- W i Q ∈ R d m o d e l × d k W_i^Q \in \mathbb{R}^{d_{model} \times d_k} WiQ∈Rdmodel×dk
- W i K ∈ R d m o d e l × d k W_i^K \in \mathbb{R}^{d_{model} \times d_k} WiK∈Rdmodel×dk
- W i V ∈ R d m o d e l × d v W_i^V \in \mathbb{R}^{d_{model} \times d_v} WiV∈Rdmodel×dv
- W O ∈ R h d v × d m o d e l W^O \in \mathbb{R}^{hd_v \times d_{model}} WO∈Rhdv×dmodel
通常 d k = d v = d m o d e l / h d_k = d_v = d_{model} / h dk=dv=dmodel/h,这样总计算量与单头注意力相当。
6.3 注意力头的功能分化
研究发现,不同的注意力头会学习到不同的功能:
语法头:关注主谓关系、修饰关系等语法结构。例如,某些头会学习关注形容词修饰的名词。
位置头:关注相对位置关系。例如,某些头会学习关注前一个或后一个位置的词。
语义头:关注语义相关的词。例如,某些头会学习关注同义词或相关概念。
稀疏头:有些头几乎不关注任何位置,可能是冗余的。
6.4 实现
python
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads, dropout=0.1):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# 线性变换
Q = self.W_Q(Q) # [batch, seq_len, d_model]
K = self.W_K(K)
V = self.W_V(V)
# 分割成多头
Q = Q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
K = K.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
V = V.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
# 计算注意力
scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)
weights = self.dropout(weights)
# 加权求和
output = torch.matmul(weights, V) # [batch, n_heads, seq_len, d_k]
# 合并多头
output = output.transpose(1, 2).contiguous().view(
batch_size, -1, self.d_model
)
# 输出投影
output = self.W_O(output)
return output, weights
7. 位置编码
7.1 为什么需要位置编码
自注意力机制是"置换等变"的(permutation equivariant),即打乱输入序列的顺序,输出也会相应打乱,但不会改变内容。这意味着自注意力无法区分"猫追狗"和"狗追猫",因为它们包含相同的词,只是顺序不同。
为了解决这个问题,需要向输入中注入位置信息,这就是位置编码的作用。
7.2 正弦位置编码
Transformer使用正弦和余弦函数生成位置编码:
P E ( p o s , 2 i ) = sin ( p o s / 10000 2 i / d m o d e l ) PE_{(pos, 2i)} = \sin(pos / 10000^{2i/d_{model}}) PE(pos,2i)=sin(pos/100002i/dmodel)
P E ( p o s , 2 i + 1 ) = cos ( p o s / 10000 2 i / d m o d e l ) PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i/d_{model}}) PE(pos,2i+1)=cos(pos/100002i/dmodel)
其中pos是位置索引,i是维度索引。
设计直觉:
- 唯一性:每个位置有唯一的编码向量。
- 有界性:编码值在-1, 1之间,不会发散。
- 相对位置 : P E p o s + k PE_{pos+k} PEpos+k 可以表示为 P E p o s PE_{pos} PEpos 的线性函数,使得模型能够学习相对位置关系。
- 泛化性:可以处理比训练时更长的序列。
7.3 可学习位置编码
另一种方法是使用可学习的位置编码,即为每个位置学习一个嵌入向量。BERT和GPT都使用这种方法。
与正弦位置编码相比,可学习位置编码的参数更多,但通常效果相当。正弦位置编码的优势是可以泛化到更长的序列,而可学习位置编码受限于训练时的最大长度。
7.4 旋转位置编码 (RoPE)
旋转位置编码(Rotary Position Embedding, RoPE)是近年来被广泛采用的位置编码方法,被LLaMA、Mistral等现代大语言模型使用。
RoPE的核心思想是将位置信息编码为旋转矩阵,使得两个位置的向量的内积自然地包含相对位置信息:
f ( q , m ) = R Θ , m q f(q, m) = R_{\Theta, m} q f(q,m)=RΘ,mq
其中 R Θ , m R_{\Theta, m} RΘ,m 是旋转矩阵。关键性质是:
⟨ f ( q , m ) , f ( k , n ) ⟩ = g ( q , k , m − n ) \langle f(q, m), f(k, n) \rangle = g(q, k, m-n) ⟨f(q,m),f(k,n)⟩=g(q,k,m−n)
即两个向量的内积只依赖于它们的相对位置 m − n m-n m−n。
8. 注意力的变体
8.1 交叉注意力
交叉注意力(Cross-Attention)的查询来自一个序列,键和值来自另一个序列。它用于建立两个不同序列之间的对应关系。
在机器翻译中,交叉注意力让解码器在生成每个词时关注源句子的相关部分。在图像描述生成中,交叉注意力让语言模型关注图像的不同区域。
8.2 因果注意力
因果注意力(Causal Attention)也称为掩码自注意力,它确保每个位置只能关注之前的位置,不能看到未来的信息。这在自回归语言模型中是必要的,因为生成过程是顺序的,不应该看到尚未生成的词。
实现方式是在注意力分数上应用一个上三角掩码,将未来位置的分数设为负无穷。
8.3 稀疏注意力
全注意力的计算复杂度是 O ( n 2 ) O(n^2) O(n2),对于长序列来说计算量很大。稀疏注意力通过只计算部分位置之间的注意力来降低复杂度。
常见的稀疏注意力模式包括:
局部窗口注意力:每个位置只关注附近的k个位置。
全局注意力:部分特殊位置(如CLS)关注所有位置。
随机注意力:每个位置随机关注r个位置。
组合模式:结合多种模式,如Longformer使用局部窗口+全局注意力。
8.4 线性注意力
线性注意力将Softmax替换为核函数,使得计算复杂度从 O ( n 2 d ) O(n^2d) O(n2d) 降低到 O ( n d 2 ) O(nd^2) O(nd2):
Attention ( Q , K , V ) = ϕ ( Q ) ( ϕ ( K ) T V ) ϕ ( Q ) ∑ i ϕ ( k i ) \text{Attention}(Q, K, V) = \frac{\phi(Q)(\phi(K)^T V)}{\phi(Q)\sum_i \phi(k_i)} Attention(Q,K,V)=ϕ(Q)∑iϕ(ki)ϕ(Q)(ϕ(K)TV)
其中 ϕ \phi ϕ 是特征映射函数(如elu+1)。
9. 高效注意力机制
9.1 Flash Attention
Flash Attention是一种IO感知的注意力算法,通过分块计算和在线Softmax技术,避免存储完整的注意力矩阵,将内存复杂度从 O ( n 2 ) O(n^2) O(n2) 降低到 O ( n ) O(n) O(n)。
Flash Attention的核心思想是:
- 将Q、K、V分成小块
- 在SRAM(高速缓存)中计算注意力块
- 使用在线softmax算法,不需要存储完整的注意力矩阵
- 将结果写回HBM(主存)
这种方法不仅减少了内存使用,还提高了计算速度,因为它减少了对HBM的访问次数。
9.2 分组查询注意力 (GQA)
分组查询注意力(Grouped Query Attention, GQA)是多头注意力和多查询注意力之间的折中方案。它让多个查询头共享同一组键值头,从而减少KV缓存的大小,提高推理效率。
在GQA中,如果有h个查询头和g个键值头(g < h),则每h/g个查询头共享一组键值头。当g=1时,退化为多查询注意力(MQA);当g=h时,退化为标准多头注意力。
LLaMA 2、Mistral等现代大语言模型都采用了GQA。
9.3 滑动窗口注意力
滑动窗口注意力(Sliding Window Attention)限制每个位置只关注其前后固定窗口内的位置。这将计算复杂度从 O ( n 2 ) O(n^2) O(n2) 降低到 O ( n w ) O(nw) O(nw),其中w是窗口大小。
Mistral使用滑动窗口注意力来处理长序列,同时保持较低的计算成本。
10. 视觉注意力机制
10.1 通道注意力
通道注意力(Channel Attention)学习不同特征通道的重要性权重。SENet(Squeeze-and-Excitation Network)是通道注意力的代表工作。
它首先通过全局平均池化将每个通道压缩为一个标量,然后通过两个全连接层学习通道间的依赖关系,最后用学到的权重对通道进行加权。
10.2 空间注意力
空间注意力(Spatial Attention)学习图像中不同空间位置的重要性权重。它关注"应该关注图像的哪个位置"。
实现方式是:首先计算每个通道的平均值和最大值,得到两个空间描述符,然后通过卷积层学习空间注意力权重。
10.3 CBAM
CBAM(Convolutional Block Attention Module)将通道注意力和空间注意力串联起来,先进行通道注意力,再进行空间注意力。
10.4 视觉Transformer
Vision Transformer(ViT)将Transformer直接应用于图像。它将图像分割为固定大小的patch,每个patch作为一个token,然后使用标准的Transformer编码器处理。
11. 注意力的可解释性
11.1 注意力可视化
通过可视化注意力权重,我们可以理解模型在做决策时关注了哪些信息。这为模型的可解释性提供了重要工具。
例如,在机器翻译中,我们可以看到解码器在生成每个词时关注源句子的哪些位置。在图像分类中,我们可以看到模型关注图像的哪些区域。
11.2 注意力与可解释性的关系
研究表明,注意力权重与模型的决策有一定的相关性,但不能完全等同于特征重要性。注意力权重高不一定意味着该位置对决策更重要,因为还需要考虑值向量的大小。
11.3 注意力头的分析
通过分析不同注意力头的功能,我们可以理解模型学到了什么样的知识。例如,有些头专注于语法关系,有些头专注于语义关系,有些头关注位置信息。
12. 完整代码实现
12.1 完整的Transformer实现
python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
def __init__(self, d_model, n_heads, dropout=0.1):
super().__init__()
assert d_model % n_heads == 0
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
self.W_Q = nn.Linear(d_model, d_model)
self.W_K = nn.Linear(d_model, d_model)
self.W_V = nn.Linear(d_model, d_model)
self.W_O = nn.Linear(d_model, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
Q = self.W_Q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_K(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_V(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
weights = F.softmax(scores, dim=-1)
weights = self.dropout(weights)
output = torch.matmul(weights, V)
output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
output = self.W_O(output)
return output, weights
class PositionWiseFeedForward(nn.Module):
def __init__(self, d_model, d_ff, dropout=0.1):
super().__init__()
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
return self.fc2(self.dropout(F.gelu(self.fc1(x))))
class TransformerBlock(nn.Module):
def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
super().__init__()
self.attention = MultiHeadAttention(d_model, n_heads, dropout)
self.ffn = PositionWiseFeedForward(d_model, d_ff, dropout)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self, x, mask=None):
# 自注意力
attn_output, attn_weights = self.attention(x, x, x, mask)
x = self.norm1(x + self.dropout1(attn_output))
# 前馈网络
ffn_output = self.ffn(x)
x = self.norm2(x + self.dropout2(ffn_output))
return x, attn_weights
class Transformer(nn.Module):
def __init__(self, vocab_size, d_model=512, n_heads=8, n_layers=6,
d_ff=2048, max_len=5000, dropout=0.1):
super().__init__()
self.embedding = nn.Embedding(vocab_size, d_model)
self.positional_encoding = self._create_positional_encoding(max_len, d_model)
self.layers = nn.ModuleList([
TransformerBlock(d_model, n_heads, d_ff, dropout)
for _ in range(n_layers)
])
self.norm = nn.LayerNorm(d_model)
self.dropout = nn.Dropout(dropout)
def _create_positional_encoding(self, max_len, d_model):
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
return pe.unsqueeze(0)
def forward(self, x, mask=None):
seq_len = x.size(1)
# 嵌入 + 位置编码
x = self.embedding(x) * math.sqrt(self.d_model)
x = x + self.positional_encoding[:, :seq_len, :].to(x.device)
x = self.dropout(x)
# Transformer层
attention_weights = []
for layer in self.layers:
x, weights = layer(x, mask)
attention_weights.append(weights)
x = self.norm(x)
return x, attention_weights
13. 应用场景
13.1 机器翻译
注意力机制在机器翻译中的应用是最经典的成功案例。通过注意力,翻译模型能够在生成每个目标词时动态关注源句子的相关部分,显著提高了翻译质量。
13.2 文本摘要
在文本摘要任务中,注意力机制帮助模型关注原文中最关键的信息,生成准确、简洁的摘要。
13.3 图像描述生成
在图像描述生成任务中,注意力机制让语言模型在生成每个词时关注图像的不同区域,实现图像和文本的对齐。
13.4 语音识别
在语音识别中,注意力机制帮助模型对齐音频信号和文字序列,处理不同长度的输入输出。
13.5 推荐系统
在推荐系统中,注意力机制用于建模用户行为序列中不同交互的重要性,提高推荐的准确性。
14. 参考资料
核心论文
- Bahdanau Attention: "Neural Machine Translation by Jointly Learning to Align and Translate" (Bahdanau et al., 2014)
- Luong Attention: "Effective Approaches to Attention-based Neural Machine Translation" (Luong et al., 2015)
- Transformer: "Attention Is All You Need" (Vaswani et al., 2017)
- Vision Transformer: "An Image is Worth 16x16 Words" (Dosovitskiy et al., 2020)
- Flash Attention: "FlashAttention: Fast and Memory-Efficient Exact Attention" (Dao et al., 2022)
- RoPE: "RoFormer: Enhanced Transformer with Rotary Position Embedding" (Su et al., 2021)
开源库
- Transformers: https://github.com/huggingface/transformers
- Flash Attention: https://github.com/Dao-AILab/flash-attention