GPT2Block/多头注意力
含义:将输入投影到多个头,每个头计算缩放点积注意力,然后拼接并投影结果。
python
class _GELU(nn.Module):
def forward(self, x):
return x * 0.5 * (1.0 + torch.erf(x / (2.0 ** 0.5)))
class GPT2Block(nn.Module):
def __init__(self, d_model, num_heads):
super().__init__()
self.num_heads = num_heads
self.d_k = d_model // num_heads
self.ln1 = nn.LayerNorm(d_model)
self.ln2 = nn.LayerNorm(d_model)
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.mlp = nn.Sequential(
nn.Linear(d_model, 4 * d_model),
_GELU(),
nn.Linear(4 * d_model, d_model),
)
def _attn(self, x):
B, S, _ = x.shape
q = self.W_q(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
k = self.W_k(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
v = self.W_v(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)
mask = torch.triu(torch.ones(S, S, device=x.device, dtype=torch.bool), diagonal=1)
scores = scores.masked_fill(mask, float('-inf'))
weights = torch.softmax(scores, dim=-1)
attn = torch.matmul(weights, v)
return self.W_o(attn.transpose(1, 2).contiguous().view(B, S, -1))
def forward(self, x):
x = x + self._attn(self.ln1(x))
x = x + self.mlp(self.ln2(x))
return x
注意:
- Q和K的序列长度可能不一致
- 缩放因子 / sqrt(d_k):防止点积过大导致 softmax 梯度消失
- 所有头的注意力计算通过批量矩阵乘法并行完成,效率高。
- masked_fill确保在预测第 t 个 token 时,只能看到前 t-1 个 token
常见问题
1、大模型为什么一般不用dropout
- Dropout 的核心目的是防止过拟合。大模型训练的数据足够大,使用反而可能导致欠拟合。
- 会损害模型的"记忆能力"和训练稳定性
- 替代的正则化手段更有效:权重衰减、layernorm、早停等
- 推理阶段的不一致性
2、FFN层的作用/为什么要先升维,后降维/非线性的作用
- 线性层等价于一个单一的线性变换,会导致模型拟合复杂函数的能力受限
- 模型在一个更广阔的高维空间中进行非线性变换,足够宽的隐藏层可以近似任何连续函数。扩维让网络有能力学习更复杂的特征组合。
- Transformer 的注意力机制(Attention)主要负责捕捉序列中不同位置之间的关系(即"谁关注谁"),而 FFN 主要负责对每个位置的特征进行独立处理和深化。
- 计算效率与表达能力的平衡
组查询注意力GQA
含义:GQA 使用比查询头更少的 KV 头,每个 KV 头在一组查询头之间共享,在保持质量的同时减少 KV 缓存大小。。
python
class GroupQueryAttention(nn.Module):
def __init__(self, d_model, num_heads, num_kv_heads):
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads
self.d_k = d_model // num_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)
self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)
self.W_o = nn.Linear(d_model, d_model)
def forward(self, x):
B, S, _ = x.shape
q = self.W_q(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
k = self.W_k(x).view(B, S, self.num_kv_heads, self.d_k).transpose(1, 2)
v = self.W_v(x).view(B, S, self.num_kv_heads, self.d_k).transpose(1, 2)
repeats = self.num_heads // self.num_kv_heads
k = k.repeat_interleave(repeats, dim=1)
v = v.repeat_interleave(repeats, dim=1)
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
weights = torch.softmax(scores, dim=-1)
attn = torch.matmul(weights, v)
out = attn.transpose(1, 2).contiguous().view(B, S, -1)
return self.W_o(out)
注意:
- W_k和W_v的大小和W_q不一致
滑动窗口注意力
含义:滑动窗口注意力限制每个位置只关注固定窗口内的位置,在保持局部上下文的同时降低长序列的复杂度。
python
def sliding_window_attention(Q, K, V, window_size):
d_k = K.size(-1)
scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k)
S = Q.size(1)
idx = torch.arange(S, device=Q.device)
mask = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs() > window_size
scores = scores.masked_fill(mask.unsqueeze(0), float('-inf'))
weights = torch.softmax(scores, dim=-1)
return torch.bmm(weights, V)
注意:
- 用 -inf 掩盖 |i - j| > window_size 的位置
- 大窗口等同于全注意力
差分注意力
含义:将 Q 和 K 各自分成两半,分别计算两个 softmax 注意力图,然后相减(乘以可学习的 lambda)以消除噪声,提升对相关上下文的聚焦能力。
python
def diff_attention(Q, K, V, lambda_val):
B, S, D2 = Q.shape
D_h = D2 // 2
Q1, Q2 = Q[..., :D_h], Q[..., D_h:]
K1, K2 = K[..., :D_h], K[..., D_h:]
scale = D_h ** -0.5
A1 = torch.softmax(Q1 @ K1.transpose(-2, -1) * scale, dim=-1)
A2 = torch.softmax(Q2 @ K2.transpose(-2, -1) * scale, dim=-1)
return (A1 - lambda_val * A2) @ V
- 增强对比性/去噪能力,不同部分捕捉不同的空间信息
- 差值 A1 - lambda*A2 可能出现负值。比标准注意力(只能加权求和,不能减)更具表达能力,允许模型主动"忽略"或"抵消"某些上下文信息。
多头潜在注意力(MLA)
含义:不缓存完整的 K 和 V 张量,而是将其压缩为低秩潜在向量 c_kv,推理时再即时解压。这大幅降低了推理时的 KV 缓存内存占用。
python
def mla_attention(X, W_dkv, W_uk, W_uv, W_q, num_heads):
B, S, D = X.shape
D_h = W_q.shape[1] // num_heads
# Compress KV into low-rank latent
c_kv = X @ W_dkv # (B, S, kv_rank)
K = c_kv @ W_uk # (B, S, num_heads*D_h)
V = c_kv @ W_uv # (B, S, num_heads*D_h)
Q = X @ W_q # (B, S, num_heads*D_h)
# Reshape to multi-head format
def split_heads(t):
return t.view(B, S, num_heads, D_h).transpose(1, 2)
Q, K, V = split_heads(Q), split_heads(K), split_heads(V)
scale = D_h ** -0.5
attn = torch.softmax(Q @ K.transpose(-2, -1) * scale, dim=-1)
out = (attn @ V).transpose(1, 2).reshape(B, S, num_heads * D_h)
return out
- D_h = W_q.shape[1] // num_heads:W_q的输入不一定是num_heads * D_h,但输出一定是num_heads * D_h
- 计算复杂度与标准注意力相同,但是所需缓存更小
- 表达能力近似完整,取决于kv_rank的大小