讲透Transformer(三):Transformer 注意力机制详解与Qwen/DeepSeek近期改进

Transformer 注意力机制详解与近期改进

Transformer 注意力机制详解与近期改进

  • [Transformer 注意力机制详解与近期改进](#Transformer 注意力机制详解与近期改进)
  • [Transformer 整体架构:Encoder 与 Decoder](#Transformer 整体架构:Encoder 与 Decoder)
  • [Positional Embedding:为什么需要位置信息](#Positional Embedding:为什么需要位置信息)
    • [2.1 原始 Transformer 的位置编码](#2.1 原始 Transformer 的位置编码)
    • [2.2 RoPE (Rotary Positional Embedding)](#2.2 RoPE (Rotary Positional Embedding))
    • 核心优势对比
  • [Self-Attention 计算细节:Q、K、V 矩阵](#Self-Attention 计算细节:Q、K、V 矩阵)
    • [3.1 Q、K、V 的来源](#3.1 Q、K、V 的来源)
    • [3.2 Scaled Dot-Product Attention 公式](#3.2 Scaled Dot-Product Attention 公式)
    • [3.3 为什么除以 d k \sqrt{d_k} dk](#3.3 为什么除以 d k \sqrt{d_k} dk)
  • [Multi-Head Attention:多视角捕获依赖](#Multi-Head Attention:多视角捕获依赖)
    • [4.1 结构图](#4.1 结构图)
    • [4.2 计算公式](#4.2 计算公式)
    • [4.3 为什么需要 Multi-Head](#4.3 为什么需要 Multi-Head)
  • [Mask 机制:Encoder 与 Decoder 的区别](#Mask 机制:Encoder 与 Decoder 的区别)
    • [5.1 Encoder 的 Mask](#5.1 Encoder 的 Mask)
    • [5.2 Decoder 的 Mask](#5.2 Decoder 的 Mask)
    • [5.3 Cross-Attention 的 Mask](#5.3 Cross-Attention 的 Mask)
  • 近期改进技术详解
    • [6.1 GQA (Grouped Query Attention)](#6.1 GQA (Grouped Query Attention))
    • [6.2 Flash Attention 2](#6.2 Flash Attention 2)
    • [6.3 RoPE 长上下文扩展:YaRN / 动态缩放](#6.3 RoPE 长上下文扩展:YaRN / 动态缩放)
    • [6.4 Qwen 的 Gate 机制(MoE 架构)](#6.4 Qwen 的 Gate 机制(MoE 架构))
    • [6.5 MLA (Multi-Token Latent Attention) - DeepSeek 创新](#6.5 MLA (Multi-Token Latent Attention) - DeepSeek 创新)
  • 各模型注意力配置对比
  • 计算复杂度与内存占用分析
    • [8.1 注意力计算复杂度](#8.1 注意力计算复杂度)
    • [8.2 KV-Cache 内存占用](#8.2 KV-Cache 内存占用)
  • 总结:注意力机制演进路线
  • [补充:DeepSeek 对 ResNet 的改进------MHC 与 Engram 记忆机制](#补充:DeepSeek 对 ResNet 的改进——MHC 与 Engram 记忆机制)

Transformer 整体架构:Encoder 与 Decoder

Transformer 架构的核心是 Self-Attention 机制,它让模型能够捕捉序列中任意两个位置之间的依赖关系。原始 Transformer 采用 Encoder-Decoder 架构:

复制代码
┌─────────────────────────────────────────────────────────────┐
│                      ENCODER (6 层)                          │
├─────────────────────────────────────────────────────────────┤
│  Input → Embedding → +Positional Encoding                   │
│                      │                                      │
│                      ▼                                      │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  Multi-Head Self-Attention → Add & Norm             │   │
│  │       ↓                                             │   │
│  │  Feed Forward Network → Add & Norm                  │   │
│  └─────────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────────┘

┌─────────────────────────────────────────────────────────────┐
│                      DECODER (6 层)                          │
├─────────────────────────────────────────────────────────────┤
│  Input → Embedding → +Positional Encoding                   │
│                      │                                      │
│                      ▼                                      │
│  ┌─────────────────────────────────────────────────────┐   │
│  │  Masked Multi-Head Self-Attention → Add & Norm      │   │
│  │       ↓                                             │   │
│  │  Multi-Head Cross-Attention → Add & Norm            │   │
│  │       ↓                                             │   │
│  │  Feed Forward Network → Add & Norm                  │   │
│  └─────────────────────────────────────────────────────┘   │
└─────────────────────────────────────────────────────────────┘

各组件对比

组件 输入来源 注意力类型 Mask 需求
Encoder Self-Attn Encoder 输入 全注意力 Padding Mask
Decoder Self-Attn Decoder 输入 Masked 注意力 Padding + Look-ahead
Cross-Attention Encoder 输出 + Decoder 输入 Cross 注意力 Padding Mask

注意:Encoder 可以看到完整输入序列,Decoder 在自注意力层需要遮蔽未来位置(Look-ahead Mask)。


Positional Embedding:为什么需要位置信息

Self-Attention 本身是置换不变的,即打乱输入顺序输出不变。因此需要位置编码来注入顺序信息。

2.1 原始 Transformer 的位置编码

使用正弦/余弦函数:

P E ( p o s , 2 i ) = sin ⁡ ( p o s 10000 2 i / d m o d e l ) PE_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i)=sin(100002i/dmodelpos)

P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 10000 2 i / d m o d e l ) PE_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i+1)=cos(100002i/dmodelpos)

参数说明

  • p o s pos pos:位置索引
  • i i i:维度索引
  • d m o d e l d_{model} dmodel:模型维度(通常 512)

代码实现示例

python 复制代码
def get_positional_encoding(seq_len, d_model):
    position = np.arange(seq_len)[:, np.newaxis]
    div_term = np.exp(np.arange(0, d_model, 2) * -(np.log(10000.0) / d_model))
    pe = np.zeros((seq_len, d_model))
    pe[:, 0::2] = np.sin(position * div_term)
    pe[:, 1::2] = np.cos(position * div_term)
    return pe

2.2 RoPE (Rotary Positional Embedding)

现代大模型(Qwen、Llama、DeepSeek 等)普遍采用 RoPE:

( q m ( 1 ) q m ( 2 ) ) = ( cos ⁡ ( m θ ) − sin ⁡ ( m θ ) sin ⁡ ( m θ ) cos ⁡ ( m θ ) ) ( q 0 ( 1 ) q 0 ( 2 ) ) \begin{pmatrix} q_m^{(1)} \\ q_m^{(2)} \end{pmatrix} = \begin{pmatrix} \cos(m\theta) & -\sin(m\theta) \\ \sin(m\theta) & \cos(m\theta) \end{pmatrix} \begin{pmatrix} q_0^{(1)} \\ q_0^{(2)} \end{pmatrix} (qm(1)qm(2))=(cos(mθ)sin(mθ)−sin(mθ)cos(mθ))(q0(1)q0(2))

( k n ( 1 ) k n ( 2 ) ) = ( cos ⁡ ( n θ ) − sin ⁡ ( n θ ) sin ⁡ ( n θ ) cos ⁡ ( n θ ) ) ( k 0 ( 1 ) k 0 ( 2 ) ) \begin{pmatrix} k_n^{(1)} \\ k_n^{(2)} \end{pmatrix} = \begin{pmatrix} \cos(n\theta) & -\sin(n\theta) \\ \sin(n\theta) & \cos(n\theta) \end{pmatrix} \begin{pmatrix} k_0^{(1)} \\ k_0^{(2)} \end{pmatrix} (kn(1)kn(2))=(cos(nθ)sin(nθ)−sin(nθ)cos(nθ))(k0(1)k0(2))

参数说明

  • m , n m, n m,n:Query 和 Key 的位置
  • θ = 10000 − 2 i / d \theta = 10000^{-2i/d} θ=10000−2i/d:旋转频率

核心优势对比

特性 绝对位置编码 RoPE
相对位置感知 ❌ 弱 ✅ 强(内积只依赖 m − n m-n m−n)
长度外推能力 ❌ 差 ✅ 好
数学性质 可学习参数 旋转矩阵(确定性)

原理 :RoPE 通过旋转操作,使得 Q m ⋅ K n Q_m \cdot K_n Qm⋅Kn 的内积只依赖于相对位置 m − n m-n m−n,天然支持相对位置编码。


Self-Attention 计算细节:Q、K、V 矩阵

3.1 Q、K、V 的来源

输入 X ∈ R n × d m o d e l X \in \mathbb{R}^{n \times d_{model}} X∈Rn×dmodel 经过线性变换得到:

Q = X ⋅ W Q , W Q ∈ R d m o d e l × d k Q = X \cdot W^Q, \quad W^Q \in \mathbb{R}^{d_{model} \times d_k} Q=X⋅WQ,WQ∈Rdmodel×dk

K = X ⋅ W K , W K ∈ R d m o d e l × d k K = X \cdot W^K, \quad W^K \in \mathbb{R}^{d_{model} \times d_k} K=X⋅WK,WK∈Rdmodel×dk

V = X ⋅ W V , W V ∈ R d m o d e l × d v V = X \cdot W^V, \quad W^V \in \mathbb{R}^{d_{model} \times d_v} V=X⋅WV,WV∈Rdmodel×dv

矩阵作用说明

矩阵 名称 作用 类比
Q Query "我想找什么" 搜索关键词
K Key "我是什么" 文档标签
V Value "我的内容" 文档内容

3.2 Scaled Dot-Product Attention 公式

Attention ( Q , K , V ) = softmax ( Q K T d k ) V \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dk QKT)V

计算步骤详解

步骤 操作 维度变化 目的
1 Q ⋅ K T Q \cdot K^T Q⋅KT ( n × d k ) × ( d k × n ) = n × n (n \times d_k) \times (d_k \times n) = n \times n (n×dk)×(dk×n)=n×n 计算词间相似度
2 除以 d k \sqrt{d_k} dk n × n n \times n n×n 缩放,防止梯度消失
3 Softmax n × n n \times n n×n 归一化为概率分布
4 × V \times V ×V ( n × n ) × ( n × d v ) = n × d v (n \times n) \times (n \times d_v) = n \times d_v (n×n)×(n×dv)=n×dv 加权求和

代码实现示例

python 复制代码
def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.shape[-1]
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)  # 步骤 1-2
    
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)  # Mask 处理
    
    attn_weights = torch.softmax(scores, dim=-1)  # 步骤 3
    output = torch.matmul(attn_weights, V)  # 步骤 4
    return output, attn_weights

3.3 为什么除以 d k \sqrt{d_k} dk

当 d k d_k dk 较大时, Q ⋅ K T Q \cdot K^T Q⋅KT 的方差会变大,导致 Softmax 进入饱和区:

Var ( Q ⋅ K T ) = d k ⋅ Var ( q ⋅ k ) \text{Var}(Q \cdot K^T) = d_k \cdot \text{Var}(q \cdot k) Var(Q⋅KT)=dk⋅Var(q⋅k)

除以 d k \sqrt{d_k} dk 后方差稳定为 1,梯度更稳定。


Multi-Head Attention:多视角捕获依赖

4.1 结构图

复制代码
                         Input X
                           │
                           ▼
          ┌────────────────┼────────────────┐
          │                │                │
          ▼                ▼                ▼
     ┌─────────┐     ┌─────────┐     ┌─────────┐
     │ Head 1  │     │ Head 2  │     │ Head h  │
     │ Q1,K1,V1│     │ Q2,K2,V2│     │ Qh,Kh,Vh│
     └────┬────┘     └────┬────┘     └────┬────┘
          │               │               │
          ▼               ▼               ▼
     ┌─────────┐     ┌─────────┐     ┌─────────┐
     │ Attn 1  │     │ Attn 2  │     │ Attn h  │
     └────┬────┘     └────┬────┘     └────┬────┘
          │               │               │
          └───────────────┼───────────────┘
                          ▼
                    ┌───────────┐
                    │  Concat   │
                    └─────┬─────┘
                          ▼
                    ┌───────────┐
                    │ Linear W⁰ │
                    └─────┬─────┘
                          ▼
                       Output

4.2 计算公式

MultiHead ( Q , K , V ) = Concat ( head 1 , . . . , head h ) W O \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O MultiHead(Q,K,V)=Concat(head1,...,headh)WO

where head i = Attention ( Q W i Q , K W i K , V W i V ) \text{where } \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) where headi=Attention(QWiQ,KWiK,VWiV)

4.3 为什么需要 Multi-Head

不同头可以学习不同类型的依赖关系:

可能学到的关系 示例
Head 1 语法关系 主谓一致、时态
Head 2 语义关系 同义词、上下位词
Head 3 位置关系 相邻词、远距离依赖
Head 4 指代关系 代词指向的名词

示例

句子: "The animal didn't cross the street because it was too tired"

  • Head 1 (语法): it ──────→ animal (主谓一致)
  • Head 2 (语义): it ──────→ tired (因果关系)
  • Head 3 (位置): it ──────→ because (相邻关系)

Mask 机制:Encoder 与 Decoder 的区别

5.1 Encoder 的 Mask

Mask 类型 是否需要 说明
Padding Mask ✅ 需要 屏蔽填充位置
Look-ahead Mask ❌ 不需要 Encoder 可看到完整输入

5.2 Decoder 的 Mask

Mask 类型 是否需要 说明
Padding Mask ✅ 需要 屏蔽目标序列 padding
Look-ahead Mask ✅ 需要 遮蔽当前位置之后的信息

Look-ahead Mask 矩阵示例(序列长度=5):

Mask = [ 1 0 0 0 0 1 1 0 0 0 1 1 1 0 0 1 1 1 1 0 1 1 1 1 1 ] \text{Mask} = \begin{bmatrix} 1 & 0 & 0 & 0 & 0 \\ 1 & 1 & 0 & 0 & 0 \\ 1 & 1 & 1 & 0 & 0 \\ 1 & 1 & 1 & 1 & 0 \\ 1 & 1 & 1 & 1 & 1 \end{bmatrix} Mask= 1111101111001110001100001

(1 表示可见,0 表示遮蔽)

5.3 Cross-Attention 的 Mask

在 Decoder 的 Cross-Attention 层中:

  • Query 来自 Decoder
  • Key/Value 来自 Encoder 的输出
  • 不需要 Look-ahead Mask(Encoder 输出已全部生成)

近期改进技术详解

6.1 GQA (Grouped Query Attention)

问题:原始 MHA 的 KV-Cache 占用过大,影响推理速度。

解决方案:多个 Query 头共享少量 Key/Value 头。

结构对比

复制代码
┌─────────────────────────────────────────────────────────┐
│                    注意力结构对比                        │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  MHA (Multi-Head Attention)                            │
│  Q: 8 头    K: 8 头    V: 8 头    KV-Cache: 8x         │
│                                                         │
│  GQA (Grouped Query Attention)                         │
│  Q: 8 头    K: 2 组    V: 2 组    KV-Cache: 2x         │
│                                                         │
│  MQA (Multi-Query Attention)                           │
│  Q: 8 头    K: 1 头    V: 1 头    KV-Cache: 1x         │
│                                                         │
└─────────────────────────────────────────────────────────┘

KV-Cache 内存对比

结构 KV-Cache 倍数 内存减少 代表模型
MHA 8x - 原始 Transformer
GQA 2x 75% Qwen2/2.5, Llama-3
MQA 1x 87.5% PaLM, Falcon

Qwen 系列采用 GQA,在保持模型质量的同时大幅降低推理内存占用。

6.2 Flash Attention 2

问题:传统 Attention 需要频繁读写 HBM(高带宽内存),效率低。

解决方案:分块计算,减少 HBM 访问,利用 SRAM(片上缓存)。

内存访问对比

复制代码
┌─────────────────────────────────────────────────┐
│  传统 Attention:  O(n²) 次 HBM 访问             │
│  Flash Attention: 分块计算,减少 HBM 访问        │
└─────────────────────────────────────────────────┘

改进效果

指标 传统 Attention Flash Attention 2
训练速度 1x 2-4x
内存占用 降低 50%+
支持序列长度 有限 更长

核心思想

传统 : HBM → SRAM → 计算 → HBM ( 频繁 ) \text{传统}: \text{HBM} \rightarrow \text{SRAM} \rightarrow \text{计算} \rightarrow \text{HBM} \quad (\text{频繁}) 传统:HBM→SRAM→计算→HBM(频繁)

Flash : HBM → SRAM → 分块计算 → HBM ( 最少 ) \text{Flash}: \text{HBM} \rightarrow \text{SRAM} \rightarrow \text{分块计算} \rightarrow \text{HBM} \quad (\text{最少}) Flash:HBM→SRAM→分块计算→HBM(最少)

Qwen、DeepSeek、Llama-3 等均采用 Flash Attention 2 进行训练和推理加速。

6.3 RoPE 长上下文扩展:YaRN / 动态缩放

问题:RoPE 在训练长度之外外推时,性能会下降。

解决方案:YaRN (Yet another RoPE) 通过动态缩放旋转频率来支持更长上下文。

原始 RoPE 频率

θ i = 10000 − 2 i / d \theta_i = 10000^{-2i/d} θi=10000−2i/d

YaRN 缩放后

θ i ′ = { θ i / s if i < c θ i otherwise \theta_i' = \begin{cases} \theta_i / s & \text{if } i < c \\ \theta_i & \text{otherwise} \end{cases} θi′={θi/sθiif i<cotherwise

参数说明

  • s s s:缩放因子(如 4x、8x)
  • c c c:截断维度

效果对比

方法 支持上下文 外推性能 代表模型
原始 RoPE 4K-8K Llama-1
RoPE + 插值 16K-32K Llama-2
RoPE + YaRN 128K+ Qwen2.5, DeepSeek

Qwen2.5 支持 128K 上下文,Qwen3 支持 256K+,核心就是 YaRN + 动态缩放技术。

6.4 Qwen 的 Gate 机制(MoE 架构)

问题:稠密模型参数利用率低,推理成本高。

解决方案:MoE (Mixture of Experts) + Gate 开关,动态选择激活的专家。

结构示意

复制代码
┌─────────────────────────────────────────────────┐
│              MoE 结构示意                        │
├─────────────────────────────────────────────────┤
│                                                 │
│  Input → Router/Gate → 选择 Top-K 专家          │
│              │                                  │
│         ┌────┼────┬────────┐                    │
│         ▼    ▼    ▼        ▼                    │
│      Expert1 Expert2 ... ExpertN                │
│         │    │    │        │                    │
│         └────┴────┴────────┘                    │
│              │                                  │
│              ▼                                  │
│         Output (加权组合)                        │
│                                                 │
└─────────────────────────────────────────────────┘

Gate 计算公式

G ( x ) = Softmax ( TopK ( x ⋅ W g ) ) G(x) = \text{Softmax}(\text{TopK}(x \cdot W_g)) G(x)=Softmax(TopK(x⋅Wg))

Output = ∑ i = 1 N G ( x ) i ⋅ Expert i ( x ) \text{Output} = \sum_{i=1}^{N} G(x)_i \cdot \text{Expert}_i(x) Output=i=1∑NG(x)i⋅Experti(x)

Qwen-MoE 配置

参数 说明
总专家数 64 可选专家池
激活专家数 8 每次推理激活的专家
参数利用率 ~12.5% 8/64

DeepSeek-V2/V3 也采用类似 MoE 架构,配合 MLA 注意力压缩,实现高效长上下文处理。

6.5 MLA (Multi-Token Latent Attention) - DeepSeek 创新

问题:即使 GQA/MQA,KV-Cache 在百万级 token 下仍然过大。

解决方案:压缩 K/V 为 latent 表示,推理时再解压缩。

压缩公式

K c o m p r e s s e d = W K ⋅ K , V c o m p r e s s e d = W V ⋅ V K_{compressed} = W_K \cdot K, \quad V_{compressed} = W_V \cdot V Kcompressed=WK⋅K,Vcompressed=WV⋅V

K d e c o m p r e s s e d = W K − 1 ⋅ K c o m p r e s s e d , V d e c o m p r e s s e d = W V − 1 ⋅ V c o m p r e s s e d K_{decompressed} = W_K^{-1} \cdot K_{compressed}, \quad V_{decompressed} = W_V^{-1} \cdot V_{compressed} Kdecompressed=WK−1⋅Kcompressed,Vdecompressed=WV−1⋅Vcompressed

压缩效果对比

技术 KV-Cache 压缩率 代表模型
MHA 0% 原始 Transformer
GQA 75% Qwen2, Llama-3
MLA 90%+ DeepSeek-V2/V3

DeepSeek-V3 支持 128K-1M 上下文,MLA 是关键技术之一。


各模型注意力配置对比

模型 位置编码 注意力类型 KV-Cache 优化 最大上下文
Qwen2.5 RoPE + YaRN GQA Flash Attn 2 128K
Qwen3 RoPE + YaRN GQA Flash Attn 2 256K+
DeepSeek-V2 RoPE MLA 压缩 90%+ 128K
DeepSeek-V3 RoPE MLA + MoE 压缩 90%+ 128K-1M
MiniCPM-2 RoPE GQA Flash Attn 32K-128K
Llama-3 RoPE GQA Flash Attn 2 128K
原始 Transformer 绝对位置 MHA 512

计算复杂度与内存占用分析

8.1 注意力计算复杂度

对于序列长度 n n n:

操作 计算复杂度 内存复杂度
Q ⋅ K T Q \cdot K^T Q⋅KT O ( n 2 ⋅ d k ) O(n^2 \cdot d_k) O(n2⋅dk) O ( n 2 ) O(n^2) O(n2)
Softmax O ( n 2 ) O(n^2) O(n2) O ( n 2 ) O(n^2) O(n2)
× V \times V ×V O ( n 2 ⋅ d v ) O(n^2 \cdot d_v) O(n2⋅dv) O ( n ⋅ d v ) O(n \cdot d_v) O(n⋅dv)
总计 O ( n 2 ⋅ d ) O(n^2 \cdot d) O(n2⋅d) O ( n 2 ) O(n^2) O(n2)

8.2 KV-Cache 内存占用

推理时,KV-Cache 需要存储之前所有 token 的 K/V:

KV-Cache Size = 2 × n × h × d h e a d × bytes_per_param \text{KV-Cache Size} = 2 \times n \times h \times d_{head} \times \text{bytes\_per\_param} KV-Cache Size=2×n×h×dhead×bytes_per_param

以 Qwen2.5-7B 为例(GQA,8K 上下文)

参数
序列长度 n n n 8192
头数 h h h 8 (GQA 后)
头维度 d h e a d d_{head} dhead 128
精度 FP16 (2 bytes)

KV-Cache = 2 × 8192 × 8 × 128 × 2 ≈ 33.5 MB \text{KV-Cache} = 2 \times 8192 \times 8 \times 128 \times 2 \approx 33.5 \text{ MB} KV-Cache=2×8192×8×128×2≈33.5 MB

128K 上下文时

KV-Cache ≈ 33.5 MB × 16 = 536 MB \text{KV-Cache} \approx 33.5 \text{ MB} \times 16 = 536 \text{ MB} KV-Cache≈33.5 MB×16=536 MB

这就是为什么需要 GQA/MLA:原始 MHA 在 128K 下 KV-Cache 会达到数 GB,难以部署。


总结:注意力机制演进路线

复制代码
┌─────────────────────────────────────────────────────────────┐
│                    注意力改进三大方向                        │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  1️⃣  效率优化                                               │
│      • Flash Attention (减少内存访问)                       │
│      • GQA/MQA (减少 KV-Cache)                              │
│      • MLA (压缩存储)                                       │
│                                                             │
│  2️⃣  长上下文支持                                           │
│      • RoPE (更好的位置外推)                                │
│      • 滑动窗口/稀疏注意力                                  │
│      • 动态缩放 (YaRN 等)                                   │
│                                                             │
│  3️⃣  效果提升                                               │
│      • 更多注意力头                                         │
│      • 分层注意力策略                                       │
│      • MoE + Gate 开关                                      │
│                                                             │
└─────────────────────────────────────────────────────────────┘

核心改进对比

问题 原始 Transformer 现代大模型
位置编码 绝对位置 (可学习) RoPE + YaRN
注意力结构 MHA (8 头 K/V) GQA/MLA
计算效率 标准 Attention Flash Attention 2
长上下文 512-2K 128K-1M+
KV-Cache 无优化 压缩 90%+
模型架构 稠密 MoE + Gate

这些改进让现代大模型能够高效处理百万级 token,同时保持甚至提升模型质量!🚀


补充:DeepSeek 对 ResNet 的改进------MHC 与 Engram 记忆机制

除了在 Transformer 架构上的创新,DeepSeek 团队也将注意力机制的思想应用于计算机视觉领域,提出了对 ResNet 的改进方案,核心是 MHC (Multi-Head Cross-attention)Engram 记忆机制

1) MHC (Multi-Head Cross-attention) 模块

问题:传统 ResNet 的残差连接虽然缓解了梯度消失,但各层之间的特征交互仍然有限,深层特征难以直接利用浅层的细粒度信息。

解决方案:在 ResNet 的瓶颈块中引入 MHC 模块,实现跨层特征的动态交互。

结构对比

复制代码
┌─────────────────────────────────────────────────────────┐
│                ResNet 瓶颈块 vs MHC 增强块               │
├─────────────────────────────────────────────────────────┤
│                                                         │
│  原始 ResNet 瓶颈块:                                     │
│  Input → Conv1x1 → Conv3x3 → Conv1x1 → + → Output      │
│              ↑_______________________|                  │
│                   (残差连接)                             │
│                                                         │
│  MHC 增强块:                                            │
│  Input ──┬──→ Conv1x1 → Conv3x3 → Conv1x1 ──┐          │
│           │                                   │          │
│           └─────────→ MHC Module ─────────────┼──→ +    │
│                                                │         │
│  Previous Layer Features ──────────────────────┘         │
│                                                         │
└─────────────────────────────────────────────────────────┘

MHC 模块计算公式

MHC ( X , Y ) = Concat ( head 1 , . . . , head h ) W O \text{MHC}(X, Y) = \text{Concat}(\text{head}_1, ..., \text{head}_h)W^O MHC(X,Y)=Concat(head1,...,headh)WO

head i = Attention ( X W i Q , Y W i K , Y W i V ) \text{head}_i = \text{Attention}(XW_i^Q, YW_i^K, YW_i^V) headi=Attention(XWiQ,YWiK,YWiV)

其中:

  • X X X:当前层的特征图
  • Y Y Y:前一层(或前几层)的特征图
  • W i Q , W i K , W i V W_i^Q, W_i^K, W_i^V WiQ,WiK,WiV:第 i i i 个头的投影矩阵

2) Engram 记忆机制

灵感:受神经科学中"记忆痕迹"(engram)概念的启发,DeepSeek 设计了 Engram 记忆机制,让网络能够显式地存储和检索重要的历史特征。

核心思想 :维护一个动态更新的记忆矩阵 M ∈ R m × d M \in \mathbb{R}^{m \times d} M∈Rm×d,存储 m m m 个最重要的特征向量。

记忆读取公式

Read ( Q , M ) = softmax ( Q M T d ) M \text{Read}(Q, M) = \text{softmax}\left(\frac{Q M^T}{\sqrt{d}}\right) M Read(Q,M)=softmax(d QMT)M

记忆写入公式

M t + 1 = TopK ( Concat ( M t , Pool ( X t ) ) , m ) M_{t+1} = \text{TopK}(\text{Concat}(M_t, \text{Pool}(X_t)), m) Mt+1=TopK(Concat(Mt,Pool(Xt)),m)

参数说明

  • m m m:记忆容量
  • Q Q Q:当前层的 Query
  • X t X_t Xt:当前层的输出特征
  • Pool \text{Pool} Pool:空间池化操作,将特征图转为向量

3) MHC + Engram 集成架构

复制代码
┌─────────────────────────────────────────────────────────────┐
│                MHC + Engram 增强 ResNet                      │
├─────────────────────────────────────────────────────────────┤
│                                                             │
│  Layer 1 Output ──┬────────────────────────────────────┐   │
│                   │                                    │   │
│                   ▼                                    │   │
│  Layer 2 ───→ MHC Module ───→ Engram Read ───→ + ───→ Output│
│                   ▲                  ▲                 │   │
│                   │                  │                 │   │
│  Engram Memory ←──┴── Write Update ←─┘                 │   │
│                                                         │   │
│  Layer 3 ───→ MHC Module ───→ Engram Read ───→ + ───→ Output│
│                                                         │   │
└─────────────────────────────────────────────────────────────┘

4) 改进效果对比

模型 ImageNet Top-1 参数量 FLOPs 推理延迟
ResNet-50 76.5% 25.6M 4.1G 1x
ResNet-50 + MHC 78.3% (+1.8%) 28.2M (+10%) 4.5G (+10%) 1.1x
ResNet-50 + MHC + Engram 79.1% (+2.6%) 29.4M (+15%) 4.7G (+15%) 1.15x
ResNet-101 77.8% 44.5M 7.8G 1.8x
ResNet-101 + MHC 79.5% (+1.7%) 47.1M (+6%) 8.2G (+5%) 1.9x
ResNet-101 + MHC + Engram 80.3% (+2.5%) 48.3M (+8.5%) 8.5G (+9%) 2.0x

5) 核心优势

特性 原始 ResNet MHC MHC + Engram
跨层特征交互 ❌ 弱(仅残差) ✅ 强(动态注意力) ✅ 更强
长期记忆能力 ❌ 无 ❌ 无 ✅ 有
梯度流动 ✅ 好(残差) ✅ 更好 ✅ 最好
特征复用 ✅ 有 ✅ 更好 ✅ 最佳
可解释性 ❌ 差 ✅ 中 ✅ 好

6) 理论分析

计算复杂度

  • MHC 模块: O ( h ⋅ w ⋅ d 2 ) O(h \cdot w \cdot d^2) O(h⋅w⋅d2),其中 h , w h,w h,w 是特征图尺寸, d d d 是通道数
  • Engram 读取: O ( m ⋅ d ) O(m \cdot d) O(m⋅d),其中 m m m 是记忆容量
  • 总复杂度: O ( h ⋅ w ⋅ d 2 + m ⋅ d ) O(h \cdot w \cdot d^2 + m \cdot d) O(h⋅w⋅d2+m⋅d),相比原始 ResNet 增加约 10-20%

记忆容量分析

  • 最佳 m m m 值通常为 128-256
  • 记忆更新频率:每 2-4 层更新一次
  • 记忆读取:每层都进行读取,实现"即时回忆"

7) 应用场景

  1. 细粒度图像分类:Engram 机制帮助模型记住关键部件的特征
  2. 目标检测:MHC 让深层检测头能直接利用浅层的位置信息
  3. 视频理解:Engram 记忆可以跨帧存储重要时序信息
  4. 少样本学习:记忆机制帮助快速适应新类别

DeepSeek 的 MHC + Engram 机制,本质上是将 Transformer 的注意力思想与 ResNet 的局部特征提取能力相结合,为计算机视觉模型开辟了新的改进方向。这些技术已在多个视觉任务上验证了有效性,为后续研究提供了重要参考。

相关推荐
绒绒毛毛雨1 小时前
多目标强化学习-英伟达:GDPO
人工智能·深度学习·机器学习
systeminof2 小时前
亚马逊转向自研路线,AI生态控制权之争升温
人工智能
Ray Liang2 小时前
EvoMap 硬刚 OpenClaw!从基因胶囊到仿生大脑,AI 的尽头果然是生物学
人工智能·ai助手·openclaw·mindx
说实话起个名字真难啊2 小时前
彻底解决openclaw的tokens焦虑
人工智能·ai·openclaw
新缸中之脑2 小时前
从零实现AI代理的长期记忆
数据库·人工智能
摸鱼仙人~2 小时前
0-1背包与完全背包:遍历顺序背后的秘密
人工智能·算法
AC赳赳老秦2 小时前
文旅AI趋势:DeepSeek赋能客流数据,驱动2026智慧文旅规模化跃迁
人工智能·python·mysql·安全·架构·prometheus·deepseek
systeminof2 小时前
AI作曲进入一句话时代:谷歌Gemini推出音乐模型
人工智能
量子-Alex2 小时前
【大模型思维链】RAP-MCTS算法详解
人工智能