This article takes a deep dive into the core innovation of the Transformer: the Self-Attention mechanism. Through mathematical derivations, code implementations, and visualizations, it covers the Query/Key/Value concepts, the principle of Scaled Dot-Product Attention, and the implementation details of Multi-Head Attention.
1. Self-Attention: A Revolution in Sequence Modeling
1.1 Limitations of Traditional Sequence Models
```mermaid
graph LR
A[RNN/LSTM] --> B[Sequential processing]
B --> C[No parallelism]
C --> D[Decaying long-range dependencies]
D --> E[Vanishing/exploding gradients]
```
1.2 The Core Idea of Self-Attention
```python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

# Input sequence (batch_size=1, seq_length=4, embedding_dim=8)
x = torch.tensor([[
    [1.0, 0.5, 0.8, 2.0, 0.1, 1.5, 0.3, 1.2],
    [0.7, 1.2, 0.4, 1.8, 0.9, 0.6, 1.1, 0.2],
    [1.3, 0.3, 1.7, 0.6, 1.4, 0.8, 0.5, 1.9],
    [0.2, 1.5, 1.1, 0.7, 0.3, 1.8, 1.6, 0.4]
]])
print("Input sequence shape:", x.shape)
```
The three core vectors of Self-Attention:
- Query (Q): the vector for the position that is currently attending (what it is looking for)
- Key (K): the identifier vector that queries are matched against
- Value (V): the vector that carries the actual information being passed along
```python
class SelfAttention(nn.Module):
    def __init__(self, embed_size):
        super().__init__()
        self.embed_size = embed_size
        # Linear projection layers
        self.Wq = nn.Linear(embed_size, embed_size)
        self.Wk = nn.Linear(embed_size, embed_size)
        self.Wv = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        Q = self.Wq(x)  # Query
        K = self.Wk(x)  # Key
        V = self.Wv(x)  # Value
        return Q, K, V

# Generate Q, K, V
attention = SelfAttention(embed_size=8)
Q, K, V = attention(x)
print("Query shape:", Q.shape)
print("Key shape:", K.shape)
print("Value shape:", V.shape)
```
Core advantages of Self-Attention:
- Global dependencies: relationships between any pair of positions are captured directly
- Parallel computation: attention for all positions is computed at the same time
- Long-range modeling: information flows without decaying with distance
- Interpretability: the attention weights give a visualizable account of the model's decisions
2. Scaled Dot-Product Attention: The Core of Attention Computation

2.1 The Mathematics in Detail

The computation breaks down into four steps:
1. Similarity: compute $QK^T$ (dot product between queries and keys)
2. Scaling: divide by $\sqrt{d_k}$ (keeps the softmax out of its saturated regime, where gradients vanish)
3. Normalization: apply softmax to obtain the attention weights
4. Weighted sum: multiply the weights by the Value vectors
```python
def scaled_dot_product_attention(Q, K, V, mask=None):
    # Step 1: dot product between Q and K
    matmul_qk = torch.matmul(Q, K.transpose(-2, -1))
    # Step 2: scale by sqrt(d_k)
    d_k = K.size(-1)
    scaled_attention_logits = matmul_qk / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))
    # Optional additive mask (e.g. a causal mask with -inf at disallowed positions)
    if mask is not None:
        scaled_attention_logits = scaled_attention_logits + mask
    # Step 3: softmax normalization
    attention_weights = torch.softmax(scaled_attention_logits, dim=-1)
    # Step 4: weighted sum of the values
    output = torch.matmul(attention_weights, V)
    return output, attention_weights

# Compute attention
output, attn_weights = scaled_dot_product_attention(Q, K, V)
print("Attention output shape:", output.shape)
print("Attention weights shape:", attn_weights.shape)
```
2.2 Visualizing the Attention Weights
```python
# Visualize the attention weight matrix
plt.figure(figsize=(10, 8))
plt.imshow(attn_weights.detach().squeeze().numpy(), cmap='viridis')
plt.title('Self-Attention weight matrix')
plt.xlabel('Key position')
plt.ylabel('Query position')
plt.colorbar()
plt.xticks(range(4), ['word 1', 'word 2', 'word 3', 'word 4'])
plt.yticks(range(4), ['word 1', 'word 2', 'word 3', 'word 4'])
# Annotate each cell with its weight
for i in range(attn_weights.shape[-2]):
    for j in range(attn_weights.shape[-1]):
        plt.text(j, i, f"{attn_weights[0, i, j].item():.2f}",
                 ha="center", va="center", color="w")
plt.show()
```
The mathematical meaning of the scaling factor $\sqrt{d_k}$:
Assume the components of $q$ and $k$ are independent random variables with mean 0 and variance 1.
Then the dot product $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ has:
- Mean: $E[q \cdot k] = 0$
- Variance: $\mathrm{Var}(q \cdot k) = d_k$
Dividing by $\sqrt{d_k}$ brings the variance back to 1, keeping the softmax inputs well scaled and the gradients stable.
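A quick numerical check of this argument (a minimal sketch; the sample count of 10,000 and d_k = 64 are arbitrary choices):

```python
# Empirically verify Var(q·k) ≈ d_k, and ≈ 1 after dividing by sqrt(d_k)
d_k = 64
q = torch.randn(10000, d_k)  # components with mean 0, variance 1
k = torch.randn(10000, d_k)
dots = (q * k).sum(dim=-1)
print("Variance of q·k:        ", dots.var().item())                 # ≈ 64
print("Variance after scaling: ", (dots / d_k ** 0.5).var().item())  # ≈ 1
```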

3. Multi-Head Attention: Attention from Multiple Perspectives

3.1 How Multi-Head Attention Works
```mermaid
graph LR
A[Input vectors] --> B[Linear projections]
B --> C1[Head 1 QKV]
B --> C2[Head 2 QKV]
B --> C3[Head n QKV]
C1 --> D1[Scaled Dot-Product Attention]
C2 --> D2[Scaled Dot-Product Attention]
C3 --> Dn[Scaled Dot-Product Attention]
D1 --> E[Concatenate outputs]
D2 --> E
Dn --> E
E --> F[Linear projection]
F --> G[Final output]
```
3.2 A Complete Multi-Head Attention Implementation
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, embed_size, num_heads):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.head_dim = embed_size // num_heads
        assert self.head_dim * num_heads == embed_size, "embed_size must be divisible by num_heads"
        # Linear projection layers
        self.Wq = nn.Linear(embed_size, embed_size)
        self.Wk = nn.Linear(embed_size, embed_size)
        self.Wv = nn.Linear(embed_size, embed_size)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def split_heads(self, x):
        """Split the embedding dimension into multiple heads."""
        batch_size, seq_length, _ = x.size()
        return x.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, Q, K, V, mask=None):
        batch_size = Q.size(0)
        # Linear projections
        Q = self.Wq(Q)
        K = self.Wk(K)
        V = self.Wv(V)
        # Split into heads
        Q = self.split_heads(Q)  # (batch_size, num_heads, seq_len, head_dim)
        K = self.split_heads(K)
        V = self.split_heads(V)
        # Scaled dot-product attention (the optional mask broadcasts over the head dimension)
        attn_output, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        # Concatenate the heads
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.embed_size)
        # Final linear projection
        output = self.fc_out(attn_output)
        return output, attn_weights

# Test multi-head attention
embed_size = 8
num_heads = 2
multihead_attn = MultiHeadAttention(embed_size, num_heads)
output, attn_weights = multihead_attn(x, x, x)
print("Multi-head attention output shape:", output.shape)
print("Multi-head attention weights shape:", attn_weights.shape)  # (batch_size, num_heads, seq_len, seq_len)
```
3.3 Visualizing the Heads
```python
# Visualize the attention weights of each head
fig, axes = plt.subplots(1, num_heads, figsize=(15, 5))
for i in range(num_heads):
    ax = axes[i]
    head_weights = attn_weights[0, i].detach().numpy()
    im = ax.imshow(head_weights, cmap='viridis')
    ax.set_title(f'Head {i+1} attention weights')
    ax.set_xlabel('Key position')
    ax.set_ylabel('Query position')
    fig.colorbar(im, ax=ax)
    # Annotate each cell with its weight
    for row in range(head_weights.shape[0]):
        for col in range(head_weights.shape[1]):
            ax.text(col, row, f"{head_weights[row, col]:.2f}",
                    ha="center", va="center", color="w", fontsize=8)
plt.tight_layout()
plt.show()
```
Advantages of multi-head attention:
- Multiple perspectives: each head attends to a different feature subspace
- Parallel computation: the heads are computed independently and simultaneously
- Richer representations: information from different subspaces is combined
- Better interpretability: different heads can learn different kinds of relationships
4. Attention in the Transformer
4.1 Encoder-Decoder Attention

In the Transformer decoder, cross-attention lets each target position attend to the encoder output: the decoder states supply the queries, while the encoder output supplies the keys and values (the third pattern in Section 4.2 below).
4.2 Three Attention Patterns
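The three snippets below are sketches rather than a full encoder-decoder model: `src` and `tgt` stand for already-embedded source and target sequences, which the original leaves undefined. A minimal stand-in so they run (shapes chosen arbitrarily):

```python
# Hypothetical embedded source / target sequences: (batch, seq_len, embed_size)
src = torch.randn(1, 4, embed_size)
tgt = torch.randn(1, 3, embed_size)
```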
Encoder self-attention: relationships within the source sequence

```python
encoder_self_attn = MultiHeadAttention(embed_size, num_heads)
encoder_output, _ = encoder_self_attn(src, src, src)
```
Decoder self-attention: relationships within the target sequence (with a causal mask)

```python
# Build a causal (lower-triangular) additive mask:
# 0 at positions a query may attend to, -inf at future positions
def create_mask(size):
    allowed = torch.tril(torch.ones(size, size, dtype=torch.bool))
    return torch.zeros(size, size).masked_fill(~allowed, float('-inf'))

mask = create_mask(tgt.size(1))
decoder_self_attn = MultiHeadAttention(embed_size, num_heads)
decoder_output, _ = decoder_self_attn(tgt, tgt, tgt, mask)
```
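To confirm the mask is doing its job, no query should place any weight on later positions:

```python
# Re-run to grab the weights and check the upper triangle is (numerically) zero
_, dec_attn_weights = decoder_self_attn(tgt, tgt, tgt, mask)
future = torch.triu(torch.ones(tgt.size(1), tgt.size(1), dtype=torch.bool), diagonal=1)
print("Future positions masked out:", bool((dec_attn_weights[..., future] < 1e-6).all()))
```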
Encoder-decoder (cross) attention: relationships between the source and target sequences

```python
cross_attn = MultiHeadAttention(embed_size, num_heads)
cross_output, _ = cross_attn(decoder_output, encoder_output, encoder_output)
```
4.3 A Complete Transformer Layer
```python
class TransformerBlock(nn.Module):
    """A complete Transformer encoder layer."""
    def __init__(self, embed_size, num_heads, ff_dim, dropout=0.1):
        super().__init__()
        # Multi-head attention
        self.attention = MultiHeadAttention(embed_size, num_heads)
        # Feed-forward network
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_size)
        )
        # Layer normalization
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        # Dropout
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Residual connection 1: attention sub-layer
        attn_output, _ = self.attention(x, x, x)
        x = self.norm1(x + self.dropout(attn_output))
        # Residual connection 2: feed-forward sub-layer
        ff_output = self.feed_forward(x)
        x = self.norm2(x + self.dropout(ff_output))
        return x

# Test the Transformer layer
transformer_block = TransformerBlock(
    embed_size=8,
    num_heads=2,
    ff_dim=32
)
output = transformer_block(x)
print("Transformer layer output shape:", output.shape)
```
5. Advanced Attention Techniques
5.1 Comparing Attention Variants

5.2 Fusing Self-Attention with Convolution
```python
class ConvAttention(nn.Module):
    """Self-attention augmented with a convolutional branch."""
    def __init__(self, embed_size, num_heads, kernel_size=3):
        super().__init__()
        self.attention = MultiHeadAttention(embed_size, num_heads)
        self.conv = nn.Conv1d(
            in_channels=embed_size,
            out_channels=embed_size,
            kernel_size=kernel_size,
            padding=kernel_size // 2
        )
        self.norm = nn.LayerNorm(embed_size)

    def forward(self, x):
        # Attention branch
        attn_out, _ = self.attention(x, x, x)
        # Convolution branch (Conv1d expects channels before sequence length)
        conv_out = self.conv(x.transpose(1, 2)).transpose(1, 2)
        # Fuse the two branches and normalize
        combined = attn_out + conv_out
        return self.norm(combined)

# Test the convolution-augmented attention
conv_attn = ConvAttention(embed_size=8, num_heads=2)
output = conv_attn(x)
print("ConvAttention output shape:", output.shape)
```
5.3 An Efficient Attention Implementation
```python
class EfficientAttention(nn.Module):
    """A linear-complexity attention variant: O(n·d²) instead of O(n²·d)."""
    def __init__(self, embed_size):
        super().__init__()
        self.embed_size = embed_size
        # Feature projections
        self.to_query = nn.Linear(embed_size, embed_size)
        self.to_key = nn.Linear(embed_size, embed_size)
        self.to_value = nn.Linear(embed_size, embed_size)

    def forward(self, x):
        Q = self.to_query(x)
        K = self.to_key(x)
        V = self.to_value(x)
        # Efficient computation: normalize the keys over the sequence dimension,
        # then aggregate K and V into a small d×d context matrix instead of
        # forming the n×n matrix QK^T explicitly.
        K = K.softmax(dim=1)
        context = torch.einsum('bnd,bne->bde', K, V)   # (batch, d, d)
        output = torch.einsum('bnd,bde->bne', Q, context)
        return output

# Test efficient attention
eff_attn = EfficientAttention(embed_size=8)
output = eff_attn(x)
print("Efficient attention output shape:", output.shape)
```
6. Self-Attention in Practice: Text Classification
6.1 Data Preparation
```python
from torchtext.datasets import IMDB
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import DataLoader

# Load the IMDB dataset
train_iter = IMDB(split='train')
tokenizer = get_tokenizer('basic_english')

# Build the vocabulary
def yield_tokens(data_iter):
    for _, text in data_iter:
        yield tokenizer(text)

vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=['<unk>', '<pad>'])
vocab.set_default_index(vocab['<unk>'])

# Convert text to token indices
def text_pipeline(text):
    return vocab(tokenizer(text))

# Batch collation: encode labels, then truncate/pad each text to max_len
def collate_batch(batch, max_len=512):
    label_list, text_list = [], []
    for label, text in batch:
        # Note: depending on the torchtext version, IMDB labels are 'pos'/'neg' or 1/2
        label_list.append(1 if label in ('pos', 2) else 0)
        processed_text = text_pipeline(text)[:max_len]
        processed_text += [vocab['<pad>']] * (max_len - len(processed_text))
        text_list.append(processed_text)
    return torch.tensor(label_list), torch.tensor(text_list)

# Create the data loader
train_loader = DataLoader(
    list(IMDB(split='train')),
    batch_size=32,
    collate_fn=collate_batch
)
```
6.2 A Classification Model Based on Self-Attention
```python
class AttentionClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_heads, hidden_dim, num_classes):
        super().__init__()
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        # Learned positional encoding (maximum sequence length 512)
        self.pos_encoding = nn.Parameter(torch.randn(1, 512, embed_dim))
        # Self-attention layer
        self.attention = MultiHeadAttention(embed_dim, num_heads)
        # Classification head
        self.classifier = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, num_classes)
        )

    def forward(self, x):
        # Token embeddings
        x = self.embedding(x)  # (batch, seq, embed_dim)
        # Add positional encodings
        seq_len = x.size(1)
        x = x + self.pos_encoding[:, :seq_len, :]
        # Self-attention
        attn_output, _ = self.attention(x, x, x)
        # Global average pooling over the sequence
        pooled = attn_output.mean(dim=1)
        # Classification
        return self.classifier(pooled)

# Initialize the model
vocab_size = len(vocab)
model = AttentionClassifier(
    vocab_size=vocab_size,
    embed_dim=128,
    num_heads=4,
    hidden_dim=256,
    num_classes=2
)

# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
```
6.3 Training and Attention Visualization
```python
# Training loop
for epoch in range(5):
    total_loss = 0
    correct = 0
    total = 0
    for labels, texts in train_loader:
        optimizer.zero_grad()
        outputs = model(texts)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
    accuracy = 100. * correct / total
    print(f"Epoch {epoch+1}, Loss: {total_loss/len(train_loader):.4f}, Acc: {accuracy:.2f}%")

# Visualize attention for a sample text
def visualize_attention(text):
    # Preprocess the text
    tokens = tokenizer(text)
    indexed = [vocab[token] for token in tokens][:512]
    input_tensor = torch.tensor([indexed])
    # Get embeddings and attention weights from the trained model
    model.eval()
    with torch.no_grad():
        embeddings = model.embedding(input_tensor)
        _, attn_weights = model.attention(embeddings, embeddings, embeddings)
        attn_weights = attn_weights.mean(dim=1)  # average over the heads
    # Plot
    plt.figure(figsize=(12, 6))
    plt.imshow(attn_weights.squeeze().numpy(), cmap='viridis')
    plt.title('Attention weights over the text')
    plt.xlabel('Key position')
    plt.ylabel('Query position')
    plt.xticks(range(len(tokens)), tokens, rotation=90)
    plt.yticks(range(len(tokens)), tokens)
    plt.colorbar()
    plt.tight_layout()
    plt.show()

# Example
sample_text = "This movie is absolutely fantastic and captivating from start to finish"
visualize_attention(sample_text)
```
Key Takeaways

The core Self-Attention formula:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$$
The Multi-Head Attention pipeline:
```mermaid
flowchart LR
A[Input] --> B[Linear projections]
B --> C[Split into heads]
C --> D[Scaled Dot-Product Attention]
D --> E[Concatenate outputs]
E --> F[Linear projection]
F --> G[Output]
```
The three uses of attention in the Transformer (see Section 4.2): encoder self-attention over the source, masked decoder self-attention over the target, and encoder-decoder cross-attention between the two.
Choosing attention hyperparameters:

Mastering the Self-Attention mechanism is the foundation for understanding modern large models. With the mathematical derivations and hands-on code in this article, you now have the core skills needed to implement and optimize attention models. More videos and materials on AI large-model application development are available at 聚客AI学院.