注意力机制(个人理解)

GPT2Block/多头注意力

含义:将输入投影到多个头,每个头计算缩放点积注意力,然后拼接并投影结果。

python 复制代码
class _GELU(nn.Module):
    def forward(self, x):
        return x * 0.5 * (1.0 + torch.erf(x / (2.0 ** 0.5)))

class GPT2Block(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model),
            _GELU(),
            nn.Linear(4 * d_model, d_model),
        )

    def _attn(self, x):
        B, S, _ = x.shape
        q = self.W_q(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)
        mask = torch.triu(torch.ones(S, S, device=x.device, dtype=torch.bool), diagonal=1)
        scores = scores.masked_fill(mask, float('-inf'))
        weights = torch.softmax(scores, dim=-1)
        attn = torch.matmul(weights, v)
        return self.W_o(attn.transpose(1, 2).contiguous().view(B, S, -1))

    def forward(self, x):
        x = x + self._attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

注意:

  1. Q和K的序列长度可能不一致
  2. 缩放因子 / sqrt(d_k):防止点积过大导致 softmax 梯度消失
  3. 所有头的注意力计算通过批量矩阵乘法并行完成,效率高。
  4. masked_fill确保在预测第 t 个 token 时,只能看到前 t-1 个 token
常见问题

1、大模型为什么一般不用dropout

  • Dropout 的核心目的是防止过拟合。大模型训练的数据足够大,使用反而可能导致欠拟合。
  • 会损害模型的"记忆能力"和训练稳定性
  • 替代的正则化手段更有效:权重衰减、layernorm、早停等
  • 推理阶段的不一致性

2、FFN层的作用/为什么要先升维,后降维/非线性的作用

  • 线性层等价于一个单一的线性变换,会导致模型拟合复杂函数的能力受限
  • 模型在一个更广阔的高维空间中进行非线性变换,足够宽的隐藏层可以近似任何连续函数。扩维让网络有能力学习更复杂的特征组合。
  • Transformer 的注意力机制(Attention)主要负责捕捉序列中不同位置之间的关系(即"谁关注谁"),而 FFN 主要负责对每个位置的特征进行独立处理和深化。
  • 计算效率与表达能力的平衡

组查询注意力GQA

含义:GQA 使用比查询头更少的 KV 头,每个 KV 头在一组查询头之间共享,在保持质量的同时减少 KV 缓存大小。。

python 复制代码
class GroupQueryAttention(nn.Module):
    def __init__(self, d_model, num_heads, num_kv_heads):
        super().__init__()
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.d_k = d_model // num_heads
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_v = nn.Linear(d_model, num_kv_heads * self.d_k)
        self.W_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, S, _ = x.shape
        q = self.W_q(x).view(B, S, self.num_heads, self.d_k).transpose(1, 2)
        k = self.W_k(x).view(B, S, self.num_kv_heads, self.d_k).transpose(1, 2)
        v = self.W_v(x).view(B, S, self.num_kv_heads, self.d_k).transpose(1, 2)
        repeats = self.num_heads // self.num_kv_heads
        k = k.repeat_interleave(repeats, dim=1)
        v = v.repeat_interleave(repeats, dim=1)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        weights = torch.softmax(scores, dim=-1)
        attn = torch.matmul(weights, v)
        out = attn.transpose(1, 2).contiguous().view(B, S, -1)
        return self.W_o(out)

注意:

  1. W_k和W_v的大小和W_q不一致

滑动窗口注意力

含义:滑动窗口注意力限制每个位置只关注固定窗口内的位置,在保持局部上下文的同时降低长序列的复杂度。

python 复制代码
def sliding_window_attention(Q, K, V, window_size):
    d_k = K.size(-1)
    scores = torch.bmm(Q, K.transpose(1, 2)) / math.sqrt(d_k)
    S = Q.size(1)
    idx = torch.arange(S, device=Q.device)
    mask = (idx.unsqueeze(0) - idx.unsqueeze(1)).abs() > window_size
    scores = scores.masked_fill(mask.unsqueeze(0), float('-inf'))
    weights = torch.softmax(scores, dim=-1)
    return torch.bmm(weights, V)

注意:

  1. 用 -inf 掩盖 |i - j| > window_size 的位置
  2. 大窗口等同于全注意力

差分注意力

含义:将 Q 和 K 各自分成两半,分别计算两个 softmax 注意力图,然后相减(乘以可学习的 lambda)以消除噪声,提升对相关上下文的聚焦能力。

python 复制代码
def diff_attention(Q, K, V, lambda_val):
    B, S, D2 = Q.shape
    D_h = D2 // 2
    Q1, Q2 = Q[..., :D_h], Q[..., D_h:]
    K1, K2 = K[..., :D_h], K[..., D_h:]
    scale = D_h ** -0.5
    A1 = torch.softmax(Q1 @ K1.transpose(-2, -1) * scale, dim=-1)
    A2 = torch.softmax(Q2 @ K2.transpose(-2, -1) * scale, dim=-1)
    return (A1 - lambda_val * A2) @ V
  1. 增强对比性/去噪能力,不同部分捕捉不同的空间信息
  2. 差值 A1 - lambda*A2 可能出现负值。比标准注意力(只能加权求和,不能减)更具表达能力,允许模型主动"忽略"或"抵消"某些上下文信息。

多头潜在注意力(MLA)

含义:不缓存完整的 K 和 V 张量,而是将其压缩为低秩潜在向量 c_kv,推理时再即时解压。这大幅降低了推理时的 KV 缓存内存占用。

python 复制代码
def mla_attention(X, W_dkv, W_uk, W_uv, W_q, num_heads):
    B, S, D = X.shape
    D_h = W_q.shape[1] // num_heads
    # Compress KV into low-rank latent
    c_kv = X @ W_dkv                          # (B, S, kv_rank)
    K = c_kv @ W_uk                            # (B, S, num_heads*D_h)
    V = c_kv @ W_uv                            # (B, S, num_heads*D_h)
    Q = X @ W_q                                # (B, S, num_heads*D_h)
    # Reshape to multi-head format
    def split_heads(t):
        return t.view(B, S, num_heads, D_h).transpose(1, 2)
    Q, K, V = split_heads(Q), split_heads(K), split_heads(V)
    scale = D_h ** -0.5
    attn = torch.softmax(Q @ K.transpose(-2, -1) * scale, dim=-1)
    out = (attn @ V).transpose(1, 2).reshape(B, S, num_heads * D_h)
    return out
  1. D_h = W_q.shape[1] // num_heads:W_q的输入不一定是num_heads * D_h,但输出一定是num_heads * D_h
  2. 计算复杂度与标准注意力相同,但是所需缓存更小
  3. 表达能力近似完整,取决于kv_rank的大小
相关推荐
Flying pigs~~2 小时前
LoRA 面试完全指南:低秩分解原理 + Transformer 应用
人工智能·深度学习·lora·大模型·微调·transformer
iwhitney2 小时前
【次方量化】3分钟搞懂什么是量化策略
python
高洁013 小时前
大模型部署资源不足?轻量化部署解决方案
python·深度学习·机器学习·数据挖掘·transformer
机械X人3 小时前
Encoder-Decoder PLM
人工智能·深度学习
阿里云大数据AI技术3 小时前
MaxFrame 视频帧智能分析:从视频到语义向量的端到端分布式处理
人工智能·python
淘矿人3 小时前
从0到1:用Claude启动你的第一个项目
开发语言·人工智能·git·python·github·php·pygame
嘻嘻哈哈樱桃3 小时前
牛客经典101题题解集--动态规划
java·数据结构·python·算法·职场和发展·动态规划
gmaajt3 小时前
Golang怎么做国际化多语言_Golang i18n教程【核心】
jvm·数据库·python
maqr_1103 小时前
CSS如何利用Sass定义全局阴影方案_通过变量实现统一CSS风格
jvm·数据库·python