注意力机制
注意力机制,就是让模型在看一整段内容时,知道 "该重点看哪里"。
比如你读一句话:
"我喜欢在傍晚看湖边的落日"
当模型要理解 "落日" 时,它会自动把注意力放在:看、傍晚、湖边 这些相关词上。而忽略不重要的词
注意力机制 = 给每个词分配一个权重,重要的词权重高,不重要的权重低,让模型只关注关键信息。
注意力机制的核心计算公式涉及查询(Query)、键(Key)、值(Value)的交互以及权重分配。以下是关键公式的详细说明:
注意力权重计算
注意力权重通过查询(Q)与键(K)的相似度计算,通常使用缩放点积注意力(Scaled Dot-Product Attention)。公式如下:
其中:
- (
) 分别表示查询、键和值矩阵。
- (
) 是键的维度,缩放因子 (
) 用于防止点积过大导致梯度消失。
多头注意力机制
多头注意力通过并行计算多个注意力头并拼接结果,增强模型的表达能力。公式如下:
其中:
- 是第 ( i ) 个头的投影矩阵。
是输出投影矩阵。、
自注意力机制
自注意力中,查询、键、值均来自同一输入序列。公式与缩放点积注意力一致,但输入来源相同:
掩码注意力
在解码器中,为避免未来信息泄露,使用掩码(通常为上三角矩阵)屏蔽未来位置:
其中 ( M ) 为掩码矩阵,未来位置设为 ( -\infty ),当前及过去位置设为 ( 0 )。
加性注意力
另一种计算方式为加性注意力(Additive Attention),适用于查询和键维度不一致的情况:
其中 ( ) 和 (
) 是可学习参数矩阵。
实现
注意力机制核心代码实现
以下是一个基于PyTorch的注意力机制实现示例,关键代码部分添加详细注释:
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attention(nn.Module):
def __init__(self, embed_dim, hidden_dim=None, output_dim=None):
super().__init__()
# 如果未指定hidden_dim,默认使用embed_dim
hidden_dim = hidden_dim or embed_dim
# 如果未指定output_dim,默认使用embed_dim
output_dim = output_dim or embed_dim
# 线性变换层,用于生成Q、K、V矩阵
self.q_proj = nn.Linear(embed_dim, hidden_dim)
self.k_proj = nn.Linear(embed_dim, hidden_dim)
self.v_proj = nn.Linear(embed_dim, output_dim)
# 缩放因子,防止点积结果过大导致softmax梯度消失
self.scale = hidden_dim ** -0.5
def forward(self, x):
# x形状: (batch_size, seq_len, embed_dim)
batch_size, seq_len, embed_dim = x.shape
# 生成Q、K、V矩阵
q = self.q_proj(x) # (batch_size, seq_len, hidden_dim)
k = self.k_proj(x) # (batch_size, seq_len, hidden_dim)
v = self.v_proj(x) # (batch_size, seq_len, output_dim)
# 计算注意力分数
# q @ k.transpose(-2, -1) 计算点积
attn_scores = (q @ k.transpose(-2, -1)) * self.scale # (batch_size, seq_len, seq_len)
# 应用softmax得到注意力权重
attn_weights = F.softmax(attn_scores, dim=-1) # (batch_size, seq_len, seq_len)
# 加权求和得到输出
output = attn_weights @ v # (batch_size, seq_len, output_dim)
return output
多头注意力实现
python
class MultiHeadAttention(nn.Module):
def __init__(self, embed_dim, num_heads=8):
super().__init__()
assert embed_dim % num_heads == 0, "embed_dim必须能被num_heads整除"
self.embed_dim = embed_dim
self.num_heads = num_heads
self.head_dim = embed_dim // num_heads
# 合并的QKV投影矩阵
self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim)
# 输出投影
self.out_proj = nn.Linear(embed_dim, embed_dim)
self.scale = self.head_dim ** -0.5
def forward(self, x):
batch_size, seq_len, embed_dim = x.shape
# 生成Q、K、V并分割为多头
qkv = self.qkv_proj(x)
qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, batch_size, num_heads, seq_len, head_dim)
q, k, v = qkv[0], qkv[1], qkv[2]
# 计算注意力分数
attn_scores = (q @ k.transpose(-2, -1)) * self.scale # (batch_size, num_heads, seq_len, seq_len)
# 应用softmax
attn_weights = F.softmax(attn_scores, dim=-1)
# 加权求和
output = attn_weights @ v # (batch_size, num_heads, seq_len, head_dim)
# 合并多头输出
output = output.transpose(1, 2) # (batch_size, seq_len, num_heads, head_dim)
output = output.reshape(batch_size, seq_len, embed_dim) # (batch_size, seq_len, embed_dim)
# 最终线性变换
output = self.out_proj(output)
return output
位置编码实现
python
class PositionalEncoding(nn.Module):
def __init__(self, embed_dim, max_len=5000):
super().__init__()
# 计算位置编码
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
pe = torch.zeros(max_len, embed_dim)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
pe = pe.unsqueeze(0) # (1, max_len, embed_dim)
self.register_buffer('pe', pe)
def forward(self, x):
# x形状: (batch_size, seq_len, embed_dim)
# 添加位置编码
x = x + self.pe[:, :x.size(1)]
return x
使用示例
python
# 创建注意力层实例
attention = Attention(embed_dim=512)
# 创建输入张量 (batch_size=2, seq_len=10, embed_dim=512)
x = torch.randn(2, 10, 512)
# 前向传播
output = attention(x)
print(output.shape) # torch.Size([2, 10, 512])
这些代码展示了PyTorch中实现注意力机制的核心组件,包括基础注意力、多头注意力和位置编码。注释详细解释了每个关键步骤的作用和实现原理。