【LLM基础】3.大模型前沿注意力机制优化笔记 (以 Qwen3.5-MoE 为例)

LLM 前沿架构笔记:多头注意力机制(Attention)的进阶优化策略

在现代 LLM(比如 Qwen、LLaMA 家族等)中,Transformer 的标准多头注意力(MHA)结构正面临表征能力显存消耗的双重挑战。主流解决思路主要有以下两类:

核心痛点

  1. 表征维度受限

    • 标准 MHA 将隐藏层特征均分到各头,单头维度(d_{head})小于主干宽度(hidden_size),易造成信息压缩损失。
  2. 显存墙(Memory Wall)

    • 推理阶段(自回归生成时),KV Cache 需存储长序列全部的 Key/Value,特别是高上下文窗口(128K+)下,显存压力骤增。

优化策略一:特征维度解耦(Decoupled Head Dimension)

核心思想

打破「单头维度 = hidden_size / num_heads」的传统,将每个 Attention Head 的特征维度设为独立可控的更大值,提升表征能力。

具体示例(hidden_size = 2048, num_heads = 16)
  • 传统做法(强绑定):

dhead, old=hidden_sizenum_heads=204816=128d_{head,\,old} = \frac{hidden\_size}{num\_heads} = \frac{2048}{16} = 128dhead,old=num_headshidden_size=162048=128

  • 解耦做法(升维举例:head_dim=256):

dhead, new=256 (显式指定)d_{head,\,new} = 256\ \text{(显式指定)}dhead,new=256 (显式指定)

总内部特征提升为:

Internal Attention Dim=num_heads×dhead, new=16×256=4096\text{Internal Attention Dim} = num\heads \times d{head,\,new} = 16 \times 256 = 4096Internal Attention Dim=num_heads×dhead,new=16×256=4096

计算过程

输入特征矩阵 X(形状 L × 2048),经投影后:

Q=X⋅WQQ = X \cdot W_QQ=X⋅WQ

其中,

WQ∈R2048×4096W_Q \in \mathbb{R}^{2048 \times 4096}WQ∈R2048×4096

变换后,

Q∈RL×4096Q \in \mathbb{R}^{L \times 4096}Q∈RL×4096

Reshape 为多个头:

Qreshaped∈RL×16×256Q_{reshaped} \in \mathbb{R}^{L \times 16 \times 256}Qreshaped∈RL×16×256

优势

  • 提升了注意力内部的表达容量
  • 不增加主干(残差流)参数量和算力压力

优化策略二:分组查询注意力(Grouped Query Attention, GQA)

核心思想

让多个 Query 头共用较少数量的 Key/Value 头,是 MHA(每个 Q 头配一个 KV 头)与 MQA(所有 Q 头共用 1 组 KV)之间的"折中":

机制细节(以 16 个 Attention 头,2 个 KV 头举例)
  1. 输入 X(L × 2048),投影生成 K、V:

    K=X⋅WK K = X \cdot W_KK=X⋅WK

    V=X⋅WVV = X \cdot W_VV=X⋅WV

    其中

    WK,WV∈R2048×(2×256)W_K, W_V \in \mathbb{R}^{2048 \times (2 \times 256)}WK,WV∈R2048×(2×256)

    输出

    K,V∈RL×2×256 K, V \in \mathbb{R}^{L \times 2 \times 256}K,V∈RL×2×256

  2. Query 头分组后,按如下方式共享 KV(假设 16 Q 头/2 组,每组 8 个):

    • 第 0~7 个 Q 头,对应第 0 号 K/V 头
    • 第 8~15 个 Q 头,对应第 1 号 K/V 头
  3. 组内做标准缩放点积注意力:

    Attention(Q,K,V)=softmax(QKTdk)V Attention(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) VAttention(Q,K,V)=softmax(dk QKT)V

    其中,dk=256d_k = 256dk=256

优势
  • 极致显存优化:KV Cache 从 16 头降为 2 头,显存为原来的 1/8
  • 性能无损:大模型实验表明,GQA 下推理/生成能力几乎无损

总结

策略类型 优化点 主要公式 (代码块) 效果
维度解耦 表征能力提升 dhead⟶更大可控d_{head} \longrightarrow \text{更大可控}dhead⟶更大可控 特征自由度大幅提升
分组查询(GQA) KV Cache 显存极致缩减 原KV头数分组合并数\frac{\text{原KV头数}}{\text{分组合并数}}分组合并数原KV头数 推理显存消耗降低 N 倍
注意力计算(通用) - Attention(Q,K,V)=softmax(QKTdk)V\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{Q K^T}{\sqrt{d_k}}\right) VAttention(Q,K,V)=softmax(dk QKT)V -

公式用法注意:如需插入 LateX 代码块,形式为:

Q=X⋅WQQ = X \cdot W_QQ=X⋅WQ

相关推荐
seven97_top2 小时前
第一批被龙虾气到的人出现了
人工智能
AC赳赳老秦2 小时前
国产化AI运维新趋势:DeepSeek赋能国产算力部署的高效故障排查
大数据·人工智能·python·django·去中心化·ai-native·deepseek
Never_every992 小时前
5 个批量抠图工具,提升 10 倍效率
大数据·前端·ai
1941s2 小时前
01-LLM 基础与提示词工程:从 API 调用到 Prompt 优化技巧
人工智能·python·prompt
愚公搬代码2 小时前
【粉丝福利社】AI时代硬核竞争力:这个数学书单传疯了
人工智能
超级学长2 小时前
光学神经网络:进展与挑战(Optical Neural Networks: Progress and Challenges)
人工智能·深度学习·光学神经网络
咚咚王者2 小时前
人工智能之语言领域 自然语言处理 第九章 文本相似度计算
人工智能·自然语言处理
研究点啥好呢2 小时前
每日GitHub热门项目推荐 | 2026年3月9日
人工智能·ai·自动化·github·openclaw
itwangyang5202 小时前
GitHub Push Protection 报错解决指南(检测到 Token / Secret)
人工智能·python·github