文章目录
注意力残差这个概念前几天火了,还不太懂,先跟跟风。
示例
1、安装依赖
pip install torch
pip install numpy # numpy是多维数组操作相关的库
2、代码:
py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention ("Attention Is All You Need").

    Projects query/key/value into ``num_heads`` subspaces of size
    ``embed_dim // num_heads``, attends within each head independently,
    then concatenates the heads and applies an output projection.

    Args:
        embed_dim: model (embedding) dimension; must be divisible by ``num_heads``.
        num_heads: number of parallel attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == embed_dim, "embed_dim must be divisible by num_heads"
        # Separate Q/K/V projections. NOTE: creation order (q, k, v, out)
        # fixes the parameter-init RNG sequence — keep it stable.
        self.q_linear = nn.Linear(embed_dim, embed_dim)
        self.k_linear = nn.Linear(embed_dim, embed_dim)
        self.v_linear = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)
        self.out_linear = nn.Linear(embed_dim, embed_dim)  # merge-heads output projection

    def forward(self, query, key, value, mask=None):
        """Compute multi-head attention.

        Args:
            query, key, value: tensors of shape (batch, seq_len, embed_dim).
            mask: optional tensor broadcastable to the score shape
                (batch, num_heads, q_len, k_len); positions where ``mask == 0``
                are excluded from attention.

        Returns:
            Tensor of shape (batch, q_len, embed_dim).
        """
        n_batch = query.size(0)

        def project_and_split(proj, t):
            # (batch, seq, embed_dim) -> (batch, num_heads, seq, head_dim)
            return proj(t).view(n_batch, -1, self.num_heads, self.head_dim).transpose(1, 2)

        q = project_and_split(self.q_linear, query)
        k = project_and_split(self.k_linear, key)
        v = project_and_split(self.v_linear, value)

        # Scaled dot-product scores: (batch, num_heads, q_len, k_len).
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        if mask is not None:
            # Large negative fill so softmax sends masked positions to ~0.
            scores = scores.masked_fill(mask == 0, -1e9)
        weights = self.dropout(F.softmax(scores, dim=-1))

        # Weighted sum over values, then merge heads back to embed_dim.
        context = torch.matmul(weights, v)  # (batch, num_heads, q_len, head_dim)
        context = context.transpose(1, 2).contiguous().view(n_batch, -1, self.embed_dim)
        return self.out_linear(context)
class AttentionBlockWithResidual(nn.Module):
    """Self-attention sub-layer wrapped with a residual (skip) connection.

    Uses the Post-LN arrangement of the original Transformer:
    ``output = LayerNorm(x + Dropout(Attention(x)))``.

    Args:
        embed_dim: model dimension of the input and output.
        num_heads: number of attention heads.
        dropout: dropout probability for both the attention weights and
            the sub-layer output.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super(AttentionBlockWithResidual, self).__init__()
        self.attention = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm = nn.LayerNorm(embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        """Apply self-attention plus residual connection to ``x``.

        Args:
            x: tensor of shape (batch, seq_len, embed_dim).
            mask: optional attention mask forwarded to the attention module.

        Returns:
            Tensor with the same shape as ``x`` — the residual add requires
            input and output dimensions to match.
        """
        # --- Key part: the attention-residual connection ---
        # Variant A: Post-LN (original Transformer).
        # Self-attention: query, key and value all come from x.
        attended = self.dropout(self.attention(x, x, x, mask))
        # Residual add (x + F(x)) followed by layer normalization.
        return self.norm(x + attended)

        # Variant B: Pre-LN (common in modern Transformers; trains more stably):
        #   h = self.norm(x)
        #   attended = self.attention(h, h, h, mask)
        #   return x + self.dropout(attended)
# --- Demo / smoke test ---
if __name__ == "__main__":
    # Demo hyper-parameters.
    batch_size, seq_length = 2, 10
    embed_dim, num_heads = 512, 8

    # Random input batch of shape (batch, seq_len, embed_dim).
    x = torch.randn(batch_size, seq_length, embed_dim)

    # Build the residual attention block and run one forward pass.
    model = AttentionBlockWithResidual(embed_dim, num_heads)
    output = model(x)

    print(f"输入形状: {x.shape}")
    print(f"输出形状: {output.shape}")
    # Sanity check: output ≈ LayerNorm(x + Attention(x)); the residual
    # connection requires input and output shapes to match, so the shape
    # printed above should equal the input shape.
    print("注意力残差模块运行成功!")
输出内容:
输入形状: torch.Size([2, 10, 512])
输出形状: torch.Size([2, 10, 512])
注意力残差模块运行成功!
输出形状与输入相同,这是对的吗?
是对的:残差连接计算 X + F(X),要求 F(X) 与 X 维度一致,因此输出形状与输入完全相同,正说明模块工作正常。
维度守恒:
输入:batch_size=2, seq_len=10, embed_dim=512
输出:batch_size=2, seq_len=10, embed_dim=512
结论:残差连接要求 X 和 F(X)维度必须一致,你的代码完美满足了这一点。这是构建深层 Transformer 堆叠层的基础。