RoPE 解构:从复数平面到 Transformer 的位置魔法

⚡ RoPE 解构:从复数平面到 Transformer 的位置魔法

Rotary Position Embedding --- From Mathematical Foundations to Engineering Implementation

旋转位置编码 · 从数学基础到工程实现 · 完整技术详解


📑 目录

章节 主题 章节 主题
[01](#章节 主题 章节 主题 01 核心概念概述 09 注意力热力图 02 位置编码演变史 10 高效实现 03 复数域视角 11 计算流程 04 数学推导 12 方法对比 05 2D 旋转 13 RoPE 变体 06 多维旋转 14 代码实现 07 矩阵结构 15 总结 08 频率谱 — — "#01-%E6%A0%B8%E5%BF%83%E6%A6%82%E5%BF%B5%E6%A6%82%E8%BF%B0") 核心概念概述 [09](#章节 主题 章节 主题 01 核心概念概述 09 注意力热力图 02 位置编码演变史 10 高效实现 03 复数域视角 11 计算流程 04 数学推导 12 方法对比 05 2D 旋转 13 RoPE 变体 06 多维旋转 14 代码实现 07 矩阵结构 15 总结 08 频率谱 — — "#09-%E6%B3%A8%E6%84%8F%E5%8A%9B%E7%83%AD%E5%8A%9B%E5%9B%BE") 注意力热力图
[02](#章节 主题 章节 主题 01 核心概念概述 09 注意力热力图 02 位置编码演变史 10 高效实现 03 复数域视角 11 计算流程 04 数学推导 12 方法对比 05 2D 旋转 13 RoPE 变体 06 多维旋转 14 代码实现 07 矩阵结构 15 总结 08 频率谱 — — "#02-%E4%BD%8D%E7%BD%AE%E7%BC%96%E7%A0%81%E6%BC%94%E5%8F%98%E5%8F%B2") 位置编码演变史 [10](#章节 主题 章节 主题 01 核心概念概述 09 注意力热力图 02 位置编码演变史 10 高效实现 03 复数域视角 11 计算流程 04 数学推导 12 方法对比 05 2D 旋转 13 RoPE 变体 06 多维旋转 14 代码实现 07 矩阵结构 15 总结 08 频率谱 — — "#10-%E9%AB%98%E6%95%88%E5%AE%9E%E7%8E%B0") 高效实现
[03](#章节 主题 章节 主题 01 核心概念概述 09 注意力热力图 02 位置编码演变史 10 高效实现 03 复数域视角 11 计算流程 04 数学推导 12 方法对比 05 2D 旋转 13 RoPE 变体 06 多维旋转 14 代码实现 07 矩阵结构 15 总结 08 频率谱 — — "#03-%E5%A4%8D%E6%95%B0%E5%9F%9F%E8%A7%86%E8%A7%92%E4%B8%8E%E6%AC%A7%E6%8B%89%E5%85%AC%E5%BC%8F") 复数域视角 [11](#章节 主题 章节 主题 01 核心概念概述 09 注意力热力图 02 位置编码演变史 10 高效实现 03 复数域视角 11 计算流程 04 数学推导 12 方法对比 05 2D 旋转 13 RoPE 变体 06 多维旋转 14 代码实现 07 矩阵结构 15 总结 08 频率谱 — — "#11-%E8%AE%A1%E7%AE%97%E6%B5%81%E7%A8%8B") 计算流程
[04](#章节 主题 章节 主题 01 核心概念概述 09 注意力热力图 02 位置编码演变史 10 高效实现 03 复数域视角 11 计算流程 04 数学推导 12 方法对比 05 2D 旋转 13 RoPE 变体 06 多维旋转 14 代码实现 07 矩阵结构 15 总结 08 频率谱 — — "#04-%E6%95%B0%E5%AD%A6%E6%8E%A8%E5%AF%BC") 数学推导 [12](#章节 主题 章节 主题 01 核心概念概述 09 注意力热力图 02 位置编码演变史 10 高效实现 03 复数域视角 11 计算流程 04 数学推导 12 方法对比 05 2D 旋转 13 RoPE 变体 06 多维旋转 14 代码实现 07 矩阵结构 15 总结 08 频率谱 — — "#12-%E6%96%B9%E6%B3%95%E5%AF%B9%E6%AF%94") 方法对比
[05](#章节 主题 章节 主题 01 核心概念概述 09 注意力热力图 02 位置编码演变史 10 高效实现 03 复数域视角 11 计算流程 04 数学推导 12 方法对比 05 2D 旋转 13 RoPE 变体 06 多维旋转 14 代码实现 07 矩阵结构 15 总结 08 频率谱 — — "#05-2d%E6%97%8B%E8%BD%AC") 2D 旋转 [13](#章节 主题 章节 主题 01 核心概念概述 09 注意力热力图 02 位置编码演变史 10 高效实现 03 复数域视角 11 计算流程 04 数学推导 12 方法对比 05 2D 旋转 13 RoPE 变体 06 多维旋转 14 代码实现 07 矩阵结构 15 总结 08 频率谱 — — "#13-rope%E5%8F%98%E4%BD%93%E4%B8%8E%E6%89%A9%E5%B1%95") RoPE 变体
[06](#章节 主题 章节 主题 01 核心概念概述 09 注意力热力图 02 位置编码演变史 10 高效实现 03 复数域视角 11 计算流程 04 数学推导 12 方法对比 05 2D 旋转 13 RoPE 变体 06 多维旋转 14 代码实现 07 矩阵结构 15 总结 08 频率谱 — — "#06-%E5%A4%9A%E7%BB%B4%E6%97%8B%E8%BD%AC") 多维旋转 [14](#章节 主题 章节 主题 01 核心概念概述 09 注意力热力图 02 位置编码演变史 10 高效实现 03 复数域视角 11 计算流程 04 数学推导 12 方法对比 05 2D 旋转 13 RoPE 变体 06 多维旋转 14 代码实现 07 矩阵结构 15 总结 08 频率谱 — — "#14-%E4%BB%A3%E7%A0%81%E5%AE%9E%E7%8E%B0") 代码实现
[07](#章节 主题 章节 主题 01 核心概念概述 09 注意力热力图 02 位置编码演变史 10 高效实现 03 复数域视角 11 计算流程 04 数学推导 12 方法对比 05 2D 旋转 13 RoPE 变体 06 多维旋转 14 代码实现 07 矩阵结构 15 总结 08 频率谱 — — "#07-%E7%9F%A9%E9%98%B5%E7%BB%93%E6%9E%84") 矩阵结构 [15](#章节 主题 章节 主题 01 核心概念概述 09 注意力热力图 02 位置编码演变史 10 高效实现 03 复数域视角 11 计算流程 04 数学推导 12 方法对比 05 2D 旋转 13 RoPE 变体 06 多维旋转 14 代码实现 07 矩阵结构 15 总结 08 频率谱 — — "#15-%E6%80%BB%E7%BB%93") 总结
[08](#章节 主题 章节 主题 01 核心概念概述 09 注意力热力图 02 位置编码演变史 10 高效实现 03 复数域视角 11 计算流程 04 数学推导 12 方法对比 05 2D 旋转 13 RoPE 变体 06 多维旋转 14 代码实现 07 矩阵结构 15 总结 08 频率谱 — — "#08-%E9%A2%91%E7%8E%87%E8%B0%B1") 频率谱 --- ---

01 核心概念概述

RoPE(Rotary Position Embedding)是一种基于复数旋转的位置编码方法 ,由苏剑林等人在 2021 年论文《RoFormer》中提出。它通过旋转矩阵将位置信息注入到词向量中,使模型能够捕捉序列中 token 的相对位置关系

✦ 核心特性

特性 说明
💡 核心思想 利用二维平面旋转的几何性质,将位置编码转化为向量旋转角度。位置 m 的向量相当于位置 0 的向量旋转了 m·θ 角度
🔄 旋转矩阵 每个位置 m 对应一个旋转矩阵 R(m),通过 cossin 函数构造正交变换,保持向量长度不变
📐 相对位置 两个位置的注意力得分仅依赖于相对位置差 (m-n),而非绝对位置值
🌊 频率设计 使用几何级数递减的频率 θᵢ = 10000^(-2i/d),不同维度捕捉不同尺度的位置信息
计算高效 无需构造完整的 d×d 旋转矩阵,通过 Hadamard 积和 rotate_half 操作即可高效实现
🚀 长度外推 由于相对位置性质,RoPE 具有天然的长度外推能力,训练时未见过的序列长度也能正确处理

✦ 关键指标

scss 复制代码
┌──────────────┬──────────────┬──────────────┬──────────────┬──────────────┐
│  提出年份    │  计算复杂度  │  可学习参数  │  旋转子空间  │  频率基数    │
├──────────────┼──────────────┼──────────────┼──────────────┼──────────────┤
│    2021      │    O(d)      │      0       │    d/2       │   10000      │
└──────────────┴──────────────┴──────────────┴──────────────┴──────────────┘

02 位置编码演变史

位置编码的发展经历了从 绝对位置相对位置旋转位置 的演变过程。

timeline title 位置编码技术演进 2017 : Transformer 正弦位置编码 : 不同频率的正弦/余弦函数 2018 : BERT 可学习位置编码 : 通过反向传播学习位置嵌入 2020 : T5 相对位置编码 : 注意力计算中加入相对位置偏置 2021 : RoPE 旋转位置编码 ⭐ : 基于复数旋转的位置编码 : 兼具绝对和相对位置编码优点

① 绝对位置编码(加法)

最直接的方法:将位置向量直接加到词向量上。无法表达相对位置关系

math 复制代码
x̂ₘ = xₘ + pₘ

② 正弦位置编码(Sinusoidal)

Transformer 原论文方法:用不同频率的正弦余弦函数构造位置向量。

math 复制代码
PE(pos, 2i) = sin(pos / 10000^(2i/d))

③ 可学习位置编码(Learned)

BERT 使用的方法:将位置嵌入作为可学习参数。无法外推

math 复制代码
pₘ = Embedding(m)  ← 通过反向传播学习

④ 相对位置编码(Relative)

在注意力计算中直接使用相对位置偏置。能捕捉相对关系但增加开销

math 复制代码
Attention = softmax((QKᵀ + aᵢⱼ) / √d)

⑤ 旋转位置编码(RoPE)⭐

通过旋转变换将位置信息融入 Q/K 向量,注意力得分自然仅依赖相对位置

math 复制代码
q̂ₘ = R(m)·q     k̂ₙ = R(n)·k     ⟨q̂ₘ, k̂ₙ⟩ = f(q, k, m-n)

03 复数域视角与欧拉公式

RoPE 的数学基础建立在复数旋转之上。理解欧拉公式和复数乘法的几何意义,是深入理解 RoPE 的关键。

graph LR A[二维向量 x₁, x₂] --> B[映射为复数 z = x₁ + i·x₂] B --> C[乘以 e^(i·m·θ)] C --> D[复数旋转 m·θ 角度] D --> E[映射回二维向量 x̂₁, x̂₂] style A fill:#0ea5e9,color:#fff style B fill:#8b5cf6,color:#fff style C fill:#ec4899,color:#fff style D fill:#10b981,color:#fff style E fill:#f59e0b,color:#fff

📌 欧拉公式

math 复制代码
e^(iθ) = cos(θ) + i·sin(θ)

📌 复数乘法 = 旋转

math 复制代码
z · e^(iθ) = |z| · e^(i(φ+θ))

💡 乘以 e^(iθ) 等价于将复数 z 旋转 θ 角度

📌 向量到复数的映射

math 复制代码
(x₁, x₂) ↔ x₁ + i·x₂

📌 RoPE 的复数表示

math 复制代码
f(q, m) = q · e^(i·m·θ)

💡 位置 m 的编码 = 原始向量旋转 m·θ 角度


04 数学推导

注意力得分的相对位置性质

RoPE 的核心目标是使注意力得分仅依赖于相对位置 (m-n),而非绝对位置 mn

graph TB A[输入向量 q, k] --> B[应用旋转 R(m), R(n)] B --> C[q̂ₘ = R(m)·q] B --> D[k̂ₙ = R(n)·k] C --> E[计算内积 ⟨q̂ₘ, k̂ₙ⟩] D --> E E --> F[qᵀ · R(m)ᵀ · R(n) · k] F --> G[R(m)ᵀ · R(n) = R(n-m)] G --> H[注意力得分仅依赖 n-m ✅] style A fill:#0ea5e9,color:#fff style H fill:#10b981,color:#fff

推导步骤

Step 1 --- 定义旋转后的 Q 和 K:

math 复制代码
q̂ₘ = R(m)·q
k̂ₙ = R(n)·k

Step 2 --- 计算注意力得分:

math 复制代码
⟨q̂ₘ, k̂ₙ⟩ = (R(m)·q)ᵀ · (R(n)·k)
          = qᵀ · R(m)ᵀ · R(n) · k
          = qᵀ · R(n-m) · k

Step 3 --- 关键性质:

math 复制代码
R(m)ᵀ · R(n) = R(n-m)

Step 4 --- 结论:

✅ 注意力得分仅依赖于相对位置 (n-m)

旋转矩阵的性质

性质 公式 说明
正交性 R(θ)ᵀ · R(θ) = I 旋转矩阵是正交矩阵
可加性 R(θ₁) · R(θ₂) = R(θ₁ + θ₂) 旋转角度可叠加
逆矩阵 R(θ)⁻¹ = R(-θ) = R(θ)ᵀ 逆旋转等于反向旋转

05 2D 旋转

二维旋转矩阵

对于二维向量 (x₁, x₂),位置 m 的旋转矩阵为:

graph LR A["向量 (x₁, x₂)"] --> B["旋转矩阵 R(m)"] B --> C["旋转角度 m·θ"] C --> D["新向量 (x̂₁, x̂₂)"] subgraph 旋转矩阵 B end subgraph 旋转操作 C end style A fill:#0ea5e9,color:#fff style B fill:#8b5cf6,color:#fff style C fill:#ec4899,color:#fff style D fill:#10b981,color:#fff
math 复制代码
        [ cos(mθ)   -sin(mθ) ]
R(m) =  [ sin(mθ)    cos(mθ) ]

旋转后的向量

math 复制代码
[ x̂₁ ]   [ cos(mθ)   -sin(mθ) ] [ x₁ ]
[ x̂₂ ] = [ sin(mθ)    cos(mθ) ] [ x₂ ]

展开后:

math 复制代码
x̂₁ = x₁·cos(mθ) - x₂·sin(mθ)
x̂₂ = x₁·sin(mθ) + x₂·cos(mθ)

几何意义

  • 📍 原始向量 (x₁, x₂) 在二维平面上旋转了 m·θ 角度
  • 📏 旋转后向量的长度保持不变:|x̂| = |x|
  • 🔄 旋转角度与位置 m 成正比

06 多维旋转

多维向量的旋转分解

对于 d 维向量,RoPE 将其分解为 d/2 个二维旋转子空间:

graph TB A["d 维向量"] --> B["分解为 d/2 个二维子空间"] B --> C1["子空间 1: (x₁, x₂) → θ₀"] B --> C2["子空间 2: (x₃, x₄) → θ₁"] B --> C3["子空间 3: (x₅, x₆) → θ₂"] B --> C4["..."] B --> C5["子空间 d/2: (x_{d-1}, x_d) → θ_{d/2-1}"] C1 --> D1["旋转 m·θ₀"] C2 --> D2["旋转 m·θ₁"] C3 --> D3["旋转 m·θ₂"] C5 --> D5["旋转 m·θ_{d/2-1}"] D1 --> E["拼接回 d 维向量"] D2 --> E D3 --> E D5 --> E style A fill:#0ea5e9,color:#fff style B fill:#8b5cf6,color:#fff style E fill:#10b981,color:#fff
math 复制代码
x = (x₁, x₂, x₃, x₄, ..., x_{d-1}, x_d)
  → [(x₁, x₂), (x₃, x₄), ..., (x_{d-1}, x_d)]

每个子空间的频率

math 复制代码
θᵢ = base^(-2i/d)    i = 0, 1, ..., d/2 - 1

其中 base = 10000(默认值)

多维旋转矩阵

math 复制代码
R(m) = diag(R₁(m), R₂(m), ..., R_{d/2}(m))

其中每个 Rᵢ(m)2×2 的旋转矩阵:

math 复制代码
         [ cos(mθᵢ)   -sin(mθᵢ) ]
Rᵢ(m) =  [ sin(mθᵢ)    cos(mθᵢ) ]

频率分布特点

维度类型 索引 频率特征 捕捉能力
低频维度 i θᵢ 大,旋转角度变化快 短距离位置关系
高频维度 i θᵢ 小,旋转角度变化慢 长距离位置关系

07 矩阵结构

完整旋转矩阵

对于 d 维向量,完整的旋转矩阵是 d×d 的分块对角矩阵:

graph TB A["完整 d×d 旋转矩阵"] --> B["分块对角结构"] B --> C1["R₁(m) 2×2"] B --> C2["R₂(m) 2×2"] B --> C3["R₃(m) 2×2"] B --> C4["..."] B --> C5["R_{d/2}(m) 2×2"] C1 --> D["稀疏矩阵优化"] C2 --> D C3 --> D C5 --> D D --> E["Hadamard 积 + rotate_half"] E --> F["O(d) 复杂度"] style A fill:#0ea5e9,color:#fff style D fill:#f59e0b,color:#fff style F fill:#10b981,color:#fff
math 复制代码
        [ R₁(m)   0       0      ...   0      ]
        [ 0       R₂(m)   0      ...   0      ]
R(m) =  [ 0       0       R₃(m)  ...   0      ]
        [ ...     ...     ...    ...   ...    ]
        [ 0       0       0      ...   R_{d/2}(m) ]

稀疏性利用

实际实现中,不需要构造完整的 d×d 矩阵,而是利用其稀疏性:

math 复制代码
x̂ = x ⊙ cos(mθ) + rotate_half(x) ⊙ sin(mθ)
  • 表示 Hadamard 积(逐元素乘法)
  • rotate_half(x) 将向量后半部分取负后与前半部分拼接

rotate_half 操作

math 复制代码
rotate_half(x) = [-x_{d/2+1}, -x_{d/2+2}, ..., -x_d, x₁, x₂, ..., x_{d/2}]

08 频率谱

频率计算公式

graph LR A["base = 10000"] --> B["θᵢ = base^(-2i/d)"] B --> C1["i=0: θ₀ = 1.0"] B --> C2["i=1: θ₁ = 0.1"] B --> C3["i=2: θ₂ = 0.01"] B --> C4["i=3: θ₃ = 0.001"] C1 --> D["几何级数递减"] C2 --> D C3 --> D C4 --> D D --> E["多尺度位置信息捕捉"] style A fill:#0ea5e9,color:#fff style B fill:#8b5cf6,color:#fff style E fill:#10b981,color:#fff
math 复制代码
θᵢ = base^(-2i/d)    i = 0, 1, ..., d/2 - 1

频率分布表(d=8, base=10000)

维度对索引 i 频率 θᵢ 旋转角度 (m=1)
0 1.0000 1.0000 rad
1 0.1000 0.1000 rad
2 0.0100 0.0100 rad
3 0.0010 0.0010 rad

频率设计原理

  • 📉 几何级数递减 :频率按 base^(-2/d) 的比例递减
  • 🎯 多尺度捕捉:不同频率捕捉不同尺度的位置信息
  • 🔭 外推能力:高频维度保证长距离位置关系的可区分性

频率可视化

频率谱呈现指数衰减特性:

  • 横轴 :维度对索引 i
  • 纵轴 :频率 θᵢ(对数尺度)
  • 曲线θᵢ = base^(-2i/d)

09 注意力热力图

注意力得分计算

应用 RoPE 后,注意力得分矩阵的计算:

math 复制代码
Attention = softmax(Q̂K̂ᵀ / √d)

其中 是应用旋转后的 Q 和 K。

相对位置特性

注意力得分矩阵具有以下特性:

math 复制代码
A[m,n] = f(q_m, k_n, m-n)

即得分仅依赖于相对位置 (m-n),而非绝对位置 mn

热力图特征

区域 特征
对角线 相同位置 (m=n) 的注意力得分
平行于对角线 相同相对位置的得分相似
远离对角线 注意力得分通常衰减

10 高效实现

核心优化策略

  1. 避免完整矩阵乘法 :使用 Hadamard 积替代 d×d 矩阵乘法
  2. 预计算缓存cossin 值可预计算并缓存
  3. 批量处理:支持动态序列长度的高效批量计算

rotate_half 实现

python 复制代码
def rotate_half(x):
    """将向量后半部分取负后与前半部分拼接"""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat([-x2, x1], dim=-1)

应用 RoPE

python 复制代码
def apply_rope(x, cos, sin):
    """应用旋转位置编码"""
    return x * cos + rotate_half(x) * sin

复杂度分析

操作 复杂度 说明
完整矩阵乘法 O(d²) 构造 d×d 旋转矩阵
Hadamard 积 O(d) 逐元素乘法
rotate_half O(d) 向量重排
总复杂度 O(d) ✅ 高效实现

11 计算流程

完整计算步骤

flowchart TD A[&#34;输入 x [batch, seq_len, dim]&#34;] --> B[&#34;Step 1: 预计算频率<br/>inv_freq[i] = base^(-2i/d)&#34;] B --> C[&#34;Step 2: 计算位置频率<br/>freq[m, i] = m × inv_freq[i]&#34;] C --> D[&#34;Step 3: 计算 cos 和 sin<br/>cos[m, i] = cos(freq[m, i])<br/>sin[m, i] = sin(freq[m, i])&#34;] D --> E[&#34;Step 4: 扩展到完整维度<br/>cos = cat([cos, cos], dim=-1)<br/>sin = cat([sin, sin], dim=-1)&#34;] E --> F[&#34;Step 5: 应用旋转<br/>x̂ = x * cos + rotate_half(x) * sin&#34;] F --> G[&#34;输出 x̂ [batch, seq_len, dim]&#34;] style A fill:#0ea5e9,color:#fff style B fill:#8b5cf6,color:#fff style C fill:#8b5cf6,color:#fff style D fill:#8b5cf6,color:#fff style E fill:#8b5cf6,color:#fff style F fill:#ec4899,color:#fff style G fill:#10b981,color:#fff

数据流图

r 复制代码
输入 x [batch, seq_len, dim]
         │
         ▼
预计算 cos, sin [seq_len, dim]
         │
         ▼
应用 rotate_half
         │
         ▼
Hadamard 积 + 加法
         │
         ▼
输出 x̂ [batch, seq_len, dim]

12 方法对比

位置编码方法对比

方法 相对位置 外推能力 可学习参数 复杂度 代表模型
绝对位置(加法) 0 O(d) ---
正弦位置编码 0 O(d) Transformer
可学习位置编码 O(N×d) O(d) BERT
相对位置编码 ⚠️ O(N²) O(N²) T5
RoPE 0 O(d) LLaMA, PaLM

RoPE 的优势

优势 说明
🎯 相对位置感知 注意力得分仅依赖相对位置
📏 长度外推 可处理训练时未见过的序列长度
🔒 零参数 无需学习,固定计算
计算高效 O(d) 复杂度,支持预计算
📐 数学优雅 基于复数旋转的清晰理论框架

13 RoPE 变体与扩展

graph TB A[&#34;RoPE 基础版本&#34;] --> B[&#34;PI 位置插值&#34;] A --> C[&#34;NTK-aware 插值&#34;] A --> D[&#34;YaRN 扩展&#34;] A --> E[&#34;Dynamic NTK&#34;] A --> F[&#34;LongRoPE&#34;] B --> B1[&#34;线性缩放位置索引<br/>简单有效&#34;] C --> C1[&#34;调整 base 基数<br/>高频保持,低频插值&#34;] D --> D1[&#34;NTK + 温度调整<br/>高频/低频分组处理&#34;] E --> E1[&#34;推理时动态调整<br/>无需微调&#34;] F --> F1[&#34;进化搜索最优方案<br/>支持 128K+ 上下文&#34;] style A fill:#0ea5e9,color:#fff style B fill:#8b5cf6,color:#fff style C fill:#8b5cf6,color:#fff style D fill:#8b5cf6,color:#fff style E fill:#8b5cf6,color:#fff style F fill:#8b5cf6,color:#fff

📌 PI --- Position Interpolation

通过线性插值将位置索引映射到训练长度范围内,而非外推。简单有效,是早期长度扩展的主流方法。

math 复制代码
θ'ᵢ = θᵢ × (L_train / L_target)
位置映射:m' = m × (L_train / L_target)

📌 NTK-aware 插值

通过调整基数 base 来改变频率分布,而非简单缩放位置。高频维度保持不变,低频维度被插值,效果优于 PI。

math 复制代码
base' = base × (scale_factor)^(d/(d-2))
θ'ᵢ = base'^(-2i/d)

📌 YaRN --- Yet another RoPE extensioN

结合 NTK-aware 插值和注意力温度调整。将维度分为高频/低频两组,分别处理,并引入注意力温度修正。

math 复制代码
高频维度:保持原始频率
低频维度:应用 NTK 插值
注意力温度:t = 0.1 × ln(scale) + 1

📌 Dynamic NTK --- 动态 NTK

在推理时动态调整基数 base,当序列长度超过训练长度时自动缩放频率。无需微调即可实现长度扩展。

math 复制代码
base' = base × (max_seq_len / L_train)^(d/(d-2))
推理时按当前长度动态计算

📌 LongRoPE --- 长上下文 RoPE

通过进化搜索寻找最优的位置插值方案,结合非均匀位置插值和维度无关的缩放因子。

math 复制代码
搜索最优缩放因子 λᵢ(每维度独立)
支持 128K+ 上下文长度

14 代码实现

① 基础实现

python 复制代码
import torch
import torch.nn as nn
import math


class RoPEEncoding(nn.Module):
    """基础 RoPE 旋转位置编码实现"""

    def __init__(self, dim: int, max_seq_len: int = 2048, base: float = 10000):
        super().__init__()
        # 计算频率 θᵢ = base^(-2i/d)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)

        # 预计算位置编码
        t = torch.arange(max_seq_len).float()
        freqs = torch.outer(t, inv_freq)          # [seq_len, dim/2]
        emb = torch.cat([freqs, freqs], dim=-1)   # [seq_len, dim]
        self.register_buffer('cos_cached', emb.cos())
        self.register_buffer('sin_cached', emb.sin())

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """将向量后半部分取负后与前半部分拼接"""
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat([-x2, x1], dim=-1)

    def forward(self, x: torch.Tensor, seq_len: int = None) -> torch.Tensor:
        """
        参数:
            x: [batch, seq_len, dim]
            seq_len: 序列长度,默认使用 x 的实际长度
        返回:
            应用 RoPE 后的张量
        """
        if seq_len is None:
            seq_len = x.shape[1]
        cos = self.cos_cached[:seq_len].unsqueeze(0)
        sin = self.sin_cached[:seq_len].unsqueeze(0)
        return x * cos + self.rotate_half(x) * sin

② 批量高效版

python 复制代码
import torch
import torch.nn as nn


class EfficientRoPE(nn.Module):
    """高效批量 RoPE 实现,支持动态序列长度"""

    def __init__(self, dim: int, base: float = 10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def _update_cache(self, seq_len: int, device=None):
        """按需更新缓存,避免重复计算"""
        if seq_len != self._seq_len_cached:
            self._seq_len_cached = seq_len
            t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq)
            emb = torch.cat([freqs, freqs], dim=-1)
            self._cos_cached = emb.cos()
            self._sin_cached = emb.sin()
        return self._cos_cached, self._sin_cached

    def forward(self, q: torch.Tensor, k: torch.Tensor):
        """对 Q 和 K 分别应用 RoPE"""
        seq_len = q.shape[1]
        cos, sin = self._update_cache(seq_len, q.device)
        cos = cos.unsqueeze(0)  # [1, seq, dim]
        sin = sin.unsqueeze(0)

        def apply(x: torch.Tensor) -> torch.Tensor:
            x1, x2 = x.chunk(2, dim=-1)
            return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

        return apply(q), apply(k)

③ LLaMA 风格

python 复制代码
import torch
import torch.nn as nn
from typing import Optional, Tuple


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0) -> torch.Tensor:
    """
    预计算复数形式的旋转频率(LLaMA 风格)

    参数:
        dim: 向量维度
        end: 最大序列长度
        theta: 基础频率基数,默认 10000
    返回:
        复数形式的频率张量
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, dtype=torch.float32)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # e^(iθ)
    return freqs_cis


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    使用复数乘法应用旋转位置编码

    参数:
        xq: Query 张量 [batch, seq_len, dim]
        xk: Key 张量 [batch, seq_len, dim]
        freqs_cis: 预计算的复数频率
    返回:
        应用 RoPE 后的 (xq_out, xk_out)
    """
    # 将 Q 和 K 重塑为复数形式
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))

    # 复数乘法 = 旋转
    freqs_cis = freqs_cis[None, :xq_.shape[1], None, :]
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)

15 总结

核心贡献

维度 说明
🎯 核心贡献 RoPE 通过旋转矩阵将绝对位置信息编码为向量旋转,使注意力得分仅依赖相对位置,兼具绝对和相对位置编码的优点
📐 数学优雅 基于复数旋转的数学框架,利用欧拉公式 e^(iθ) 实现位置编码,理论清晰,推导优美
工程高效 通过 rotate_half 和 Hadamard 积实现 O(d) 复杂度,无需构造完整旋转矩阵,支持预计算缓存
🚀 生态影响 被 LLaMA、PaLM、GLM、Qwen 等主流大模型采用,催生了 PI、NTK-aware、YaRN 等长度扩展变体

💎 一句话总结

旋转即位置,相对即自然


关键概念回顾

概念 表示 说明
旋转矩阵 R(m) 位置 m 对应的旋转矩阵
复数表示 e^(imθ) 旋转的复数形式
相对位置 m-n 注意力得分依赖的核心
高效实现 O(d) 计算复杂度
频率基数 10000 默认基础频率基数
外推潜力 理论上的长度外推能力
相关推荐
五号厂房1 小时前
🔥 Claude Code 源码解析(二):揭秘对话引擎的核心机制
人工智能
明月照山海-1 小时前
机器学习周报四十八
人工智能·机器学习
KaMeidebaby1 小时前
卡梅德生物技术快报|细胞周期检测抗原流式分析:参数调试、软件拟合与问题排查
网络·人工智能·python·网络协议·tcp/ip·算法·机器学习
明明如月学长1 小时前
AI 会先淘汰这几类?我最近有个越来越强的判断
人工智能
cyyt1 小时前
深度学习周报(6.1~6.7)
人工智能·深度学习
yaoyouzhong1 小时前
2026 年 GPT 与 Gemini 怎么选?AI 工具适配哪些场景?
人工智能·gpt
码农阿强1 小时前
GPT-Image-2 技术原理与实战:开启推理驱动图像生成新时代
人工智能·gpt·ai·aigc·个人开发
Ajie'Blog1 小时前
Claude Opus 4.8 发布:Claude Code 能不能接住复杂项目
服务器·前端·javascript·人工智能·ai编程