Transformer 注意力机制详解与近期改进
Transformer 注意力机制详解与近期改进
- [Transformer 注意力机制详解与近期改进](#Transformer 注意力机制详解与近期改进)
- [Transformer 整体架构:Encoder 与 Decoder](#Transformer 整体架构:Encoder 与 Decoder)
- [Positional Embedding:为什么需要位置信息](#Positional Embedding:为什么需要位置信息)
-
- [2.1 原始 Transformer 的位置编码](#2.1 原始 Transformer 的位置编码)
- [2.2 RoPE (Rotary Positional Embedding)](#2.2 RoPE (Rotary Positional Embedding))
- 核心优势对比
- [Self-Attention 计算细节:Q、K、V 矩阵](#Self-Attention 计算细节:Q、K、V 矩阵)
-
- [3.1 Q、K、V 的来源](#3.1 Q、K、V 的来源)
- [3.2 Scaled Dot-Product Attention 公式](#3.2 Scaled Dot-Product Attention 公式)
- [3.3 为什么除以 d k \sqrt{d_k} dk](#3.3 为什么除以 d k \sqrt{d_k} dk)
- [Multi-Head Attention:多视角捕获依赖](#Multi-Head Attention:多视角捕获依赖)
-
- [4.1 结构图](#4.1 结构图)
- [4.2 计算公式](#4.2 计算公式)
- [4.3 为什么需要 Multi-Head](#4.3 为什么需要 Multi-Head)
- [Mask 机制:Encoder 与 Decoder 的区别](#Mask 机制:Encoder 与 Decoder 的区别)
-
- [5.1 Encoder 的 Mask](#5.1 Encoder 的 Mask)
- [5.2 Decoder 的 Mask](#5.2 Decoder 的 Mask)
- [5.3 Cross-Attention 的 Mask](#5.3 Cross-Attention 的 Mask)
- 近期改进技术详解
-
- [6.1 GQA (Grouped Query Attention)](#6.1 GQA (Grouped Query Attention))
- [6.2 Flash Attention 2](#6.2 Flash Attention 2)
- [6.3 RoPE 长上下文扩展:YaRN / 动态缩放](#6.3 RoPE 长上下文扩展:YaRN / 动态缩放)
- [6.4 Qwen 的 Gate 机制(MoE 架构)](#6.4 Qwen 的 Gate 机制(MoE 架构))
- [6.5 MLA (Multi-Token Latent Attention) - DeepSeek 创新](#6.5 MLA (Multi-Token Latent Attention) - DeepSeek 创新)
- 各模型注意力配置对比
- 计算复杂度与内存占用分析
-
- [8.1 注意力计算复杂度](#8.1 注意力计算复杂度)
- [8.2 KV-Cache 内存占用](#8.2 KV-Cache 内存占用)
- 总结:注意力机制演进路线
- [补充:DeepSeek 对 ResNet 的改进------MHC 与 Engram 记忆机制](#补充:DeepSeek 对 ResNet 的改进——MHC 与 Engram 记忆机制)
-
- 1) MHC (Multi-Head Cross-attention) 模块 MHC (Multi-Head Cross-attention) 模块)
- 2) Engram 记忆机制 Engram 记忆机制)
- 3) MHC + Engram 集成架构 MHC + Engram 集成架构)
- 4) 改进效果对比 改进效果对比)
- 5) 核心优势 核心优势)
- 6) 理论分析 理论分析)
- 7) 应用场景 应用场景)
Transformer 整体架构:Encoder 与 Decoder
Transformer 架构的核心是 Self-Attention 机制,它让模型能够捕捉序列中任意两个位置之间的依赖关系。原始 Transformer 采用 Encoder-Decoder 架构:
┌─────────────────────────────────────────────────────────────┐
│ ENCODER (6 层) │
├─────────────────────────────────────────────────────────────┤
│ Input → Embedding → +Positional Encoding │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ Multi-Head Self-Attention → Add & Norm │ │
│ │ ↓ │ │
│ │ Feed Forward Network → Add & Norm │ │
│ └─────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
┌─────────────────────────────────────────────────────────────┐
│ DECODER (6 层) │
├─────────────────────────────────────────────────────────────┤
│ Input → Embedding → +Positional Encoding │
│ │ │
│ ▼ │
│ ┌─────────────────────────────────────────────────────┐ │
│ │ Masked Multi-Head Self-Attention → Add & Norm │ │
│ │ ↓ │ │
│ │ Multi-Head Cross-Attention → Add & Norm │ │
│ │ ↓ │ │
│ │ Feed Forward Network → Add & Norm │ │
│ └─────────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────┘
各组件对比
| 组件 | 输入来源 | 注意力类型 | Mask 需求 |
|---|---|---|---|
| Encoder Self-Attn | Encoder 输入 | 全注意力 | Padding Mask |
| Decoder Self-Attn | Decoder 输入 | Masked 注意力 | Padding + Look-ahead |
| Cross-Attention | Encoder 输出 + Decoder 输入 | Cross 注意力 | Padding Mask |
注意:Encoder 可以看到完整输入序列,Decoder 在自注意力层需要遮蔽未来位置(Look-ahead Mask)。
Positional Embedding:为什么需要位置信息
Self-Attention 本身是置换不变的,即打乱输入顺序输出不变。因此需要位置编码来注入顺序信息。
2.1 原始 Transformer 的位置编码
使用正弦/余弦函数:
P E ( p o s , 2 i ) = sin ( p o s 10000 2 i / d m o d e l ) PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i)=sin(100002i/dmodelpos)
P E ( p o s , 2 i + 1 ) = cos ( p o s 10000 2 i / d m o d e l ) PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i+1)=cos(100002i/dmodelpos)
参数说明:
- p o s pos pos:位置索引
- i i i:维度索引
- d m o d e l d_{model} dmodel:模型维度(通常 512)
代码实现示例:
python
def get_positional_encoding(seq_len, d_model):
position = np.arange(seq_len)[:, np.newaxis]
div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
pe = np.zeros((seq_len, d_model))
pe[:, 0::2] = np.sin(position * div_term)
pe[:, 1::2] = np.cos(position * div_term)
return pe
2.2 RoPE (Rotary Positional Embedding)
现代大模型(Qwen、Llama、DeepSeek 等)普遍采用 RoPE:
( q m ( 1 ) q m ( 2 ) ) = ( cos ( m θ ) − sin ( m θ ) sin ( m θ ) cos ( m θ ) ) ( q 0 ( 1 ) q 0 ( 2 ) ) \begin{pmatrix} q_m^{(1)} \\ q_m^{(2)} \end{pmatrix} = \begin{pmatrix} \cos(m\theta) & -\sin(m\theta) \\ \sin(m\theta) & \cos(m\theta) \end{pmatrix} \begin{pmatrix} q_0^{(1)} \\ q_0^{(2)} \end{pmatrix} (qm(1)qm(2))=(cos(mθ)sin(mθ)−sin(mθ)cos(mθ))(q0(1)q0(2))
( k n ( 1 ) k n ( 2 ) ) = ( cos ( n θ ) − sin ( n θ ) sin ( n θ ) cos ( n θ ) ) ( k 0 ( 1 ) k 0 ( 2 ) ) \begin{pmatrix} k_n^{(1)} \\ k_n^{(2)} \end{pmatrix} = \begin{pmatrix} \cos(n\theta) & -\sin(n\theta) \\ \sin(n\theta) & \cos(n\theta) \end{pmatrix} \begin{pmatrix} k_0^{(1)} \\ k_0^{(2)} \end{pmatrix} (kn(1)kn(2))=(cos(nθ)sin(nθ)−sin(nθ)cos(nθ))(k0(1)k0(2))
参数说明:
- m , n m, n m,n:Query 和 Key 的位置
- θ = 10000 − 2 i / d \theta = 10000^{-2i/d} θ=10000−2i/d:旋转频率
核心优势对比
| 特性 | 绝对位置编码 | RoPE |
|---|---|---|
| 相对位置感知 | ❌ 弱 | ✅ 强(内积只依赖 m − n m-n m−n) |
| 长度外推能力 | ❌ 差 | ✅ 好 |
| 数学性质 | 可学习参数 | 旋转矩阵(确定性) |
原理 :RoPE 通过旋转操作,使得 Q m ⋅ K n Q_m \cdot K_n Qm⋅Kn 的内积只依赖于相对位置 m − n m-n m−n,天然支持相对位置编码。
Self-Attention 计算细节:Q、K、V 矩阵
3.1 Q、K、V 的来源
输入 X ∈ R n × d m o d e l X \in \mathbb{R}^{n \times d_{model}} X∈Rn×dmodel 经过线性变换得到:
Q = X ⋅ W Q , W Q ∈ R d m o d e l × d k Q = X \cdot W^Q, \quad W^Q \in \mathbb{R}^{d_{model} \times d_k} Q=X⋅WQ,WQ∈Rdmodel×dk
K = X ⋅ W K , W K ∈ R d m o d e l × d k K = X \cdot W^K, \quad W^K \in \mathbb{R}^{d_{model} \times d_k} K=X⋅WK,WK∈Rdmodel×dk
V = X ⋅ W V , W V ∈ R d m o d e l × d v V = X \cdot W^V, \quad W^V \in \mathbb{R}^{d_{model} \times d_v} V=X⋅WV,WV∈Rdmodel×dv
矩阵作用说明:
| 矩阵 | 名称 | 作用 | 类比 |
|---|---|---|---|
| Q | Query | "我想找什么" | 搜索关键词 |
| K | Key | "我是什么" | 文档标签 |
| V | Value | "我的内容" | 文档内容 |
3.2 Scaled Dot-Product Attention 公式
Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V
计算步骤详解:
| 步骤 | 操作 | 维度变化 | 目的 |
|---|---|---|---|
| 1 | Q ⋅ K T Q \cdot K^T Q⋅KT | ( n × d k ) × ( d k × n ) = n × n (n \times d_k) \times (d_k \times n) = n \times n (n×dk)×(dk×n)=n×n | 计算词间相似度 |
| 2 | 除以 d k \sqrt{d_k} dk | n × n n \times n n×n | 缩放,防止梯度消失 |
| 3 | Softmax | n × n n \times n n×n | 归一化为概率分布 |
| 4 | × V \times V ×V | ( n × n ) × ( n × d v ) = n × d v (n \times n) \times (n \times d_v) = n \times d_v (n×n)×(n×dv)=n×dv | 加权求和 |
代码实现示例:
python
def scaled_dot_product_attention(Q, K, V, mask=None):
d_k = Q.shape[-1]
scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k) # 步骤 1-2
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9) # Mask 处理
attn_weights = torch.softmax(scores, dim=-1) # 步骤 3
output = torch.matmul(attn_weights, V) # 步骤 4
return output, attn_weights
3.3 为什么除以 d k \sqrt{d_k} dk
当 d k d_k dk 较大时, Q ⋅ K T Q \cdot K^T Q⋅KT 的方差会变大,导致 Softmax 进入饱和区:
Var ( Q ⋅ K T ) = d k ⋅ Var ( q ⋅ k ) \text{Var}(Q \cdot K^T) = d_k \cdot \text{Var}(q \cdot k) Var(Q⋅KT)=dk⋅Var(q⋅k)
除以 d k \sqrt{d_k} dk 后方差稳定为 1,梯度更稳定。
Multi-Head Attention:多视角捕获依赖
4.1 结构图
Input X
│
▼
┌────────────────┼────────────────┐
│ │ │
▼ ▼ ▼
┌─────────┐ ┌─────────┐ ┌─────────┐
│ Head 1 │ │ Head 2 │ │ Head h │
│ Q1,K1,V1│ │ Q2,K2,V2│ │ Qh,Kh,Vh│
└────┬────┘ └────┬────┘ └────┬────┘
│ │ │
▼ ▼ ▼
┌─────────┐ ┌─────────┐ ┌─────────┐
│ Attn 1 │ │ Attn 2 │ │ Attn h │
└────┬────┘ └────┬────┘ └────┬────┘
│ │ │
└───────────────┼───────────────┘
▼
┌───────────┐
│ Concat │
└─────┬─────┘
▼
┌───────────┐
│ Linear W⁰ │
└─────┬─────┘
▼
Output
4.2 计算公式
MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO
where head i = Attention ( Q W i Q , K W i K , V W i V ) \text{where } \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) where headi=Attention(QWiQ,KWiK,VWiV)
4.3 为什么需要 Multi-Head
不同头可以学习不同类型的依赖关系:
| 头 | 可能学到的关系 | 示例 |
|---|---|---|
| Head 1 | 语法关系 | 主谓一致、时态 |
| Head 2 | 语义关系 | 同义词、上下位词 |
| Head 3 | 位置关系 | 相邻词、远距离依赖 |
| Head 4 | 指代关系 | 代词指向的名词 |
示例:
句子: "The animal didn't cross the street because it was too tired"
- Head 1 (语法): it ──────→ animal (主谓一致)
- Head 2 (语义): it ──────→ tired (因果关系)
- Head 3 (位置): it ──────→ because (相邻关系)
Mask 机制:Encoder 与 Decoder 的区别
5.1 Encoder 的 Mask
| Mask 类型 | 是否需要 | 说明 |
|---|---|---|
| Padding Mask | ✅ 需要 | 屏蔽填充位置 |
| Look-ahead Mask | ❌ 不需要 | Encoder 可看到完整输入 |
5.2 Decoder 的 Mask
| Mask 类型 | 是否需要 | 说明 |
|---|---|---|
| Padding Mask | ✅ 需要 | 屏蔽目标序列 padding |
| Look-ahead Mask | ✅ 需要 | 遮蔽当前位置之后的信息 |
Look-ahead Mask 矩阵示例(序列长度=5):
Mask = [ 1 0 0 0 0 1 1 0 0 0 1 1 1 0 0 1 1 1 1 0 1 1 1 1 1 ] \text{Mask} = \begin{bmatrix} 1 & 0 & 0 & 0 & 0 \\ 1 & 1 & 0 & 0 & 0 \\ 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 \\ 1 & 1 & 1 & 1 & 1 \end{bmatrix} Mask= 1111101111001110001100001
(1 表示可见,0 表示遮蔽)
5.3 Cross-Attention 的 Mask
在 Decoder 的 Cross-Attention 层中:
- Query 来自 Decoder
- Key/Value 来自 Encoder 的输出
- 不需要 Look-ahead Mask(Encoder 输出已全部生成)
近期改进技术详解
6.1 GQA (Grouped Query Attention)
问题:原始 MHA 的 KV-Cache 占用过大,影响推理速度。
解决方案:多个 Query 头共享少量 Key/Value 头。
结构对比:
┌─────────────────────────────────────────────────────────┐
│ 注意力结构对比 │
├─────────────────────────────────────────────────────────┤
│ │
│ MHA (Multi-Head Attention) │
│ Q: 8 头 K: 8 头 V: 8 头 KV-Cache: 8x │
│ │
│ GQA (Grouped Query Attention) │
│ Q: 8 头 K: 2 组 V: 2 组 KV-Cache: 2x │
│ │
│ MQA (Multi-Query Attention) │
│ Q: 8 头 K: 1 头 V: 1 头 KV-Cache: 1x │
│ │
└─────────────────────────────────────────────────────────┘
KV-Cache 内存对比:
| 结构 | KV-Cache 倍数 | 内存减少 | 代表模型 |
|---|---|---|---|
| MHA | 8x | - | 原始 Transformer |
| GQA | 2x | 75% | Qwen2/2.5, Llama-3 |
| MQA | 1x | 87.5% | PaLM, Falcon |
Qwen 系列采用 GQA,在保持模型质量的同时大幅降低推理内存占用。
6.2 Flash Attention 2
问题:传统 Attention 需要频繁读写 HBM(高带宽内存),效率低。
解决方案:分块计算,减少 HBM 访问,利用 SRAM(片上缓存)。
内存访问对比:
┌─────────────────────────────────────────────────┐
│ 传统 Attention: O(n²) 次 HBM 访问 │
│ Flash Attention: 分块计算,减少 HBM 访问 │
└─────────────────────────────────────────────────┘
改进效果:
| 指标 | 传统 Attention | Flash Attention 2 |
|---|---|---|
| 训练速度 | 1x | 2-4x |
| 内存占用 | 高 | 降低 50%+ |
| 支持序列长度 | 有限 | 更长 |
核心思想:
传统 : HBM → SRAM → 计算 → HBM ( 频繁 ) \text{传统}: \text{HBM} \rightarrow \text{SRAM} \rightarrow \text{计算} \rightarrow \text{HBM} \quad (\text{频繁}) 传统:HBM→SRAM→计算→HBM(频繁)
Flash : HBM → SRAM → 分块计算 → HBM ( 最少 ) \text{Flash}: \text{HBM} \rightarrow \text{SRAM} \rightarrow \text{分块计算} \rightarrow \text{HBM} \quad (\text{最少}) Flash:HBM→SRAM→分块计算→HBM(最少)
Qwen、DeepSeek、Llama-3 等均采用 Flash Attention 2 进行训练和推理加速。
6.3 RoPE 长上下文扩展:YaRN / 动态缩放
问题:RoPE 在训练长度之外外推时,性能会下降。
解决方案:YaRN (Yet another RoPE) 通过动态缩放旋转频率来支持更长上下文。
原始 RoPE 频率:
θ i = 10000 − 2 i / d \theta_i = 10000^{-2i/d} θi=10000−2i/d
YaRN 缩放后:
θ i ′ = { θ i / s if i < c θ i otherwise \theta_i' = \begin{cases} \theta_i / s & \text{if } i < c \\ \theta_i & \text{otherwise} \end{cases} θi′={θi/sθiif i<cotherwise
参数说明:
- s s s:缩放因子(如 4x、8x)
- c c c:截断维度
效果对比:
| 方法 | 支持上下文 | 外推性能 | 代表模型 |
|---|---|---|---|
| 原始 RoPE | 4K-8K | 差 | Llama-1 |
| RoPE + 插值 | 16K-32K | 中 | Llama-2 |
| RoPE + YaRN | 128K+ | 好 | Qwen2.5, DeepSeek |
Qwen2.5 支持 128K 上下文,Qwen3 支持 256K+,核心就是 YaRN + 动态缩放技术。
6.4 Qwen 的 Gate 机制(MoE 架构)
问题:稠密模型参数利用率低,推理成本高。
解决方案:MoE (Mixture of Experts) + Gate 开关,动态选择激活的专家。
结构示意:
┌─────────────────────────────────────────────────┐
│ MoE 结构示意 │
├─────────────────────────────────────────────────┤
│ │
│ Input → Router/Gate → 选择 Top-K 专家 │
│ │ │
│ ┌────┼────┬────────┐ │
│ ▼ ▼ ▼ ▼ │
│ Expert1 Expert2 ... ExpertN │
│ │ │ │ │ │
│ └────┴────┴────────┘ │
│ │ │
│ ▼ │
│ Output (加权组合) │
│ │
└─────────────────────────────────────────────────┘
Gate 计算公式:
G ( x ) = Softmax ( TopK ( x ⋅ W g ) ) G(x) = \text{Softmax}(\text{TopK}(x \cdot W_g)) G(x)=Softmax(TopK(x⋅Wg))
Output = ∑ i = 1 N G ( x ) i ⋅ Expert i ( x ) \text{Output} = \sum_{i=1}^{N} G(x)_i \cdot \text{Expert}_i(x) Output=i=1∑NG(x)i⋅Experti(x)
Qwen-MoE 配置:
| 参数 | 值 | 说明 |
|---|---|---|
| 总专家数 | 64 | 可选专家池 |
| 激活专家数 | 8 | 每次推理激活的专家 |
| 参数利用率 | ~12.5% | 8/64 |
DeepSeek-V2/V3 也采用类似 MoE 架构,配合 MLA 注意力压缩,实现高效长上下文处理。
6.5 MLA (Multi-Token Latent Attention) - DeepSeek 创新
问题:即使 GQA/MQA,KV-Cache 在百万级 token 下仍然过大。
解决方案:压缩 K/V 为 latent 表示,推理时再解压缩。
压缩公式:
K c o m p r e s s e d = W K ⋅ K , V c o m p r e s s e d = W V ⋅ V K_{compressed} = W_K \cdot K, \quad V_{compressed} = W_V \cdot V Kcompressed=WK⋅K,Vcompressed=WV⋅V
K d e c o m p r e s s e d = W K − 1 ⋅ K c o m p r e s s e d , V d e c o m p r e s s e d = W V − 1 ⋅ V c o m p r e s s e d K_{decompressed} = W_K^{-1} \cdot K_{compressed}, \quad V_{decompressed} = W_V^{-1} \cdot V_{compressed} Kdecompressed=WK−1⋅Kcompressed,Vdecompressed=WV−1⋅Vcompressed
压缩效果对比:
| 技术 | KV-Cache 压缩率 | 代表模型 |
|---|---|---|
| MHA | 0% | 原始 Transformer |
| GQA | 75% | Qwen2, Llama-3 |
| MLA | 90%+ | DeepSeek-V2/V3 |
DeepSeek-V3 支持 128K-1M 上下文,MLA 是关键技术之一。
各模型注意力配置对比
| 模型 | 位置编码 | 注意力类型 | KV-Cache 优化 | 最大上下文 |
|---|---|---|---|---|
| Qwen2.5 | RoPE + YaRN | GQA | Flash Attn 2 | 128K |
| Qwen3 | RoPE + YaRN | GQA | Flash Attn 2 | 256K+ |
| DeepSeek-V2 | RoPE | MLA | 压缩 90%+ | 128K |
| DeepSeek-V3 | RoPE | MLA + MoE | 压缩 90%+ | 128K-1M |
| MiniCPM-2 | RoPE | GQA | Flash Attn | 32K-128K |
| Llama-3 | RoPE | GQA | Flash Attn 2 | 128K |
| 原始 Transformer | 绝对位置 | MHA | 无 | 512 |
计算复杂度与内存占用分析
8.1 注意力计算复杂度
对于序列长度 n n n:
| 操作 | 计算复杂度 | 内存复杂度 |
|---|---|---|
| Q ⋅ K T Q \cdot K^T Q⋅KT | O ( n 2 ⋅ d k ) O(n^2 \cdot d_k) O(n2⋅dk) | O ( n 2 ) O(n^2) O(n2) |
| Softmax | O ( n 2 ) O(n^2) O(n2) | O ( n 2 ) O(n^2) O(n2) |
| × V \times V ×V | O ( n 2 ⋅ d v ) O(n^2 \cdot d_v) O(n2⋅dv) | O ( n ⋅ d v ) O(n \cdot d_v) O(n⋅dv) |
| 总计 | O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d) | O ( n 2 ) O(n^2) O(n2) |
8.2 KV-Cache 内存占用
推理时,KV-Cache 需要存储之前所有 token 的 K/V:
KV-Cache Size = 2 × n × h × d h e a d × bytes_per_param \text{KV-Cache Size} = 2 \times n \times h \times d_{head} \times \text{bytes\_per\_param} KV-Cache Size=2×n×h×dhead×bytes_per_param
以 Qwen2.5-7B 为例(GQA,8K 上下文):
| 参数 | 值 |
|---|---|
| 序列长度 n n n | 8192 |
| 头数 h h h | 8 (GQA 后) |
| 头维度 d h e a d d_{head} dhead | 128 |
| 精度 | FP16 (2 bytes) |
KV-Cache = 2 × 8192 × 8 × 128 × 2 ≈ 33.5 MB \text{KV-Cache} = 2 \times 8192 \times 8 \times 128 \times 2 \approx 33.5 \text{ MB} KV-Cache=2×8192×8×128×2≈33.5 MB
128K 上下文时:
KV-Cache ≈ 33.5 MB × 16 = 536 MB \text{KV-Cache} \approx 33.5 \text{ MB} \times 16 = 536 \text{ MB} KV-Cache≈33.5 MB×16=536 MB
这就是为什么需要 GQA/MLA:原始 MHA 在 128K 下 KV-Cache 会达到数 GB,难以部署。
总结:注意力机制演进路线
┌─────────────────────────────────────────────────────────────┐
│ 注意力改进三大方向 │
├─────────────────────────────────────────────────────────────┤
│ │
│ 1️⃣ 效率优化 │
│ • Flash Attention (减少内存访问) │
│ • GQA/MQA (减少 KV-Cache) │
│ • MLA (压缩存储) │
│ │
│ 2️⃣ 长上下文支持 │
│ • RoPE (更好的位置外推) │
│ • 滑动窗口/稀疏注意力 │
│ • 动态缩放 (YaRN 等) │
│ │
│ 3️⃣ 效果提升 │
│ • 更多注意力头 │
│ • 分层注意力策略 │
│ • MoE + Gate 开关 │
│ │
└─────────────────────────────────────────────────────────────┘
核心改进对比
| 问题 | 原始 Transformer | 现代大模型 |
|---|---|---|
| 位置编码 | 绝对位置 (可学习) | RoPE + YaRN |
| 注意力结构 | MHA (8 头 K/V) | GQA/MLA |
| 计算效率 | 标准 Attention | Flash Attention 2 |
| 长上下文 | 512-2K | 128K-1M+ |
| KV-Cache | 无优化 | 压缩 90%+ |
| 模型架构 | 稠密 | MoE + Gate |
这些改进让现代大模型能够高效处理百万级 token,同时保持甚至提升模型质量!🚀
补充:DeepSeek 对 ResNet 的改进------MHC 与 Engram 记忆机制
除了在 Transformer 架构上的创新,DeepSeek 团队也将注意力机制的思想应用于计算机视觉领域,提出了对 ResNet 的改进方案,核心是 MHC (Multi-Head Cross-attention) 和 Engram 记忆机制。
1) MHC (Multi-Head Cross-attention) 模块
问题:传统 ResNet 的残差连接虽然缓解了梯度消失,但各层之间的特征交互仍然有限,深层特征难以直接利用浅层的细粒度信息。
解决方案:在 ResNet 的瓶颈块中引入 MHC 模块,实现跨层特征的动态交互。
结构对比:
┌─────────────────────────────────────────────────────────┐
│ ResNet 瓶颈块 vs MHC 增强块 │
├─────────────────────────────────────────────────────────┤
│ │
│ 原始 ResNet 瓶颈块: │
│ Input → Conv1x1 → Conv3x3 → Conv1x1 → + → Output │
│ ↑_______________________| │
│ (残差连接) │
│ │
│ MHC 增强块: │
│ Input ──┬──→ Conv1x1 → Conv3x3 → Conv1x1 ──┐ │
│ │ │ │
│ └─────────→ MHC Module ─────────────┼──→ + │
│ │ │
│ Previous Layer Features ──────────────────────┘ │
│ │
└─────────────────────────────────────────────────────────┘
MHC 模块计算公式:
MHC ( X , Y ) = Concat ( head 1 , . . . , head h ) W O \text{MHC}(X, Y) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O MHC(X,Y)=Concat(head1,...,headh)WO
head i = Attention ( X W i Q , Y W i K , Y W i V ) \text{head}_i = \text{Attention}(XW_i^Q, YW_i^K, YW_i^V) headi=Attention(XWiQ,YWiK,YWiV)
其中:
- X X X:当前层的特征图
- Y Y Y:前一层(或前几层)的特征图
- W i Q , W i K , W i V W_i^Q, W_i^K, W_i^V WiQ,WiK,WiV:第 i i i 个头的投影矩阵
2) Engram 记忆机制
灵感:受神经科学中"记忆痕迹"(engram)概念的启发,DeepSeek 设计了 Engram 记忆机制,让网络能够显式地存储和检索重要的历史特征。
核心思想 :维护一个动态更新的记忆矩阵 M ∈ R m × d M \in \mathbb{R}^{m \times d} M∈Rm×d,存储 m m m 个最重要的特征向量。
记忆读取公式:
Read ( Q , M ) = softmax ( Q M T d ) M \text{Read}(Q, M) = \text{softmax}\left(\frac{Q M^T}{\sqrt{d}}\right) M Read(Q,M)=softmax(d QMT)M
记忆写入公式:
M t + 1 = TopK ( Concat ( M t , Pool ( X t ) ) , m ) M_{t+1} = \text{TopK}(\text{Concat}(M_t, \text{Pool}(X_t)), m) Mt+1=TopK(Concat(Mt,Pool(Xt)),m)
参数说明:
- m m m:记忆容量
- Q Q Q:当前层的 Query
- X t X_t Xt:当前层的输出特征
- Pool \text{Pool} Pool:空间池化操作,将特征图转为向量
3) MHC + Engram 集成架构
┌─────────────────────────────────────────────────────────────┐
│ MHC + Engram 增强 ResNet │
├─────────────────────────────────────────────────────────────┤
│ │
│ Layer 1 Output ──┬────────────────────────────────────┐ │
│ │ │ │
│ ▼ │ │
│ Layer 2 ───→ MHC Module ───→ Engram Read ───→ + ───→ Output│
│ ▲ ▲ │ │
│ │ │ │ │
│ Engram Memory ←──┴── Write Update ←─┘ │ │
│ │ │
│ Layer 3 ───→ MHC Module ───→ Engram Read ───→ + ───→ Output│
│ │ │
└─────────────────────────────────────────────────────────────┘
4) 改进效果对比
| 模型 | ImageNet Top-1 | 参数量 | FLOPs | 推理延迟 |
|---|---|---|---|---|
| ResNet-50 | 76.5% | 25.6M | 4.1G | 1x |
| ResNet-50 + MHC | 78.3% (+1.8%) | 28.2M (+10%) | 4.5G (+10%) | 1.1x |
| ResNet-50 + MHC + Engram | 79.1% (+2.6%) | 29.4M (+15%) | 4.7G (+15%) | 1.15x |
| ResNet-101 | 77.8% | 44.5M | 7.8G | 1.8x |
| ResNet-101 + MHC | 79.5% (+1.7%) | 47.1M (+6%) | 8.2G (+5%) | 1.9x |
| ResNet-101 + MHC + Engram | 80.3% (+2.5%) | 48.3M (+8.5%) | 8.5G (+9%) | 2.0x |
5) 核心优势
| 特性 | 原始 ResNet | MHC | MHC + Engram |
|---|---|---|---|
| 跨层特征交互 | ❌ 弱(仅残差) | ✅ 强(动态注意力) | ✅ 更强 |
| 长期记忆能力 | ❌ 无 | ❌ 无 | ✅ 有 |
| 梯度流动 | ✅ 好(残差) | ✅ 更好 | ✅ 最好 |
| 特征复用 | ✅ 有 | ✅ 更好 | ✅ 最佳 |
| 可解释性 | ❌ 差 | ✅ 中 | ✅ 好 |
6) 理论分析
计算复杂度:
- MHC 模块: O ( h ⋅ w ⋅ d 2 ) O(h \cdot w \cdot d^2) O(h⋅w⋅d2),其中 h , w h,w h,w 是特征图尺寸, d d d 是通道数
- Engram 读取: O ( m ⋅ d ) O(m \cdot d) O(m⋅d),其中 m m m 是记忆容量
- 总复杂度: O ( h ⋅ w ⋅ d 2 + m ⋅ d ) O(h \cdot w \cdot d^2 + m \cdot d) O(h⋅w⋅d2+m⋅d),相比原始 ResNet 增加约 10-20%
记忆容量分析:
- 最佳 m m m 值通常为 128-256
- 记忆更新频率:每 2-4 层更新一次
- 记忆读取:每层都进行读取,实现"即时回忆"
7) 应用场景
- 细粒度图像分类:Engram 机制帮助模型记住关键部件的特征
- 目标检测:MHC 让深层检测头能直接利用浅层的位置信息
- 视频理解:Engram 记忆可以跨帧存储重要时序信息
- 少样本学习:记忆机制帮助快速适应新类别
DeepSeek 的 MHC + Engram 机制,本质上是将 Transformer 的注意力思想与 ResNet 的局部特征提取能力相结合,为计算机视觉模型开辟了新的改进方向。这些技术已在多个视觉任务上验证了有效性,为后续研究提供了重要参考。