豆包生成。
从基础 RoPE 到 YaRN:源码学习路线揭秘
我给你一套零基础、由浅入深的学习路线 ,先拆核心功能,再屏蔽复杂代码,最后啃进阶优化,保证你能彻底学会这段源码!
先记住核心结论(学习前必背)
这两个函数是大模型的旋转位置编码(RoPE) + 超长文本优化(YaRN):
-
作用 :给大模型的文字加「位置信息」,让模型知道词的先后顺序(比如 "我吃饭" 和 "吃我饭" 不一样);
-
分工:
-
precompute\_freqs\_cis:提前算好所有位置的旋转角度(cos/sin 值)(预计算,提速); -
apply\_rotary\_pos\_emb:拿着算好的角度,给 Q/K 矩阵加位置信息(核心应用);
-
-
学习顺序 :先学基础 RoPE → 再学 YaRN 进阶优化(直接看全文会懵!)
最优学习步骤(3 步走,无压力)
第一步:屏蔽复杂代码,只看【基础 RoPE】
把 if rope\_scaling is not None: 这段YaRN 超长文本代码全部删掉,只看最纯净的基础版,这是 90% 的核心逻辑!
纯净版基础代码(先学这个)
python
import torch
import math
# 1. 预计算cos/sin(核心:算位置旋转角度)
def precompute_freqs_cis(dim: int, end: int = 32768, rope_base: float = 1e6):
# 🔥 第一步:生成基础频率(数学公式,不用背,知道是算角度就行)
freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
# 🔥 第二步:生成位置序列 [0,1,2,...,end-1]
t = torch.arange(end, device=freqs.device)
# 🔥 第三步:外积 → 位置 × 频率 = 每个位置的旋转角度
freqs = torch.outer(t, freqs).float()
# 🔥 第四步:算cos/sin,拼接维度(适配模型特征维度)
freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
return freqs_cos, freqs_sin
# 2. 应用位置编码(核心:给Q/K旋转,加位置信息)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
# 🔥 旋转函数:把特征分成两半,后半部分取反(RoPE固定操作)
def rotate_half(x):
return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
# 🔥 核心公式:给Q/K做旋转编码(RoPE标准计算)
q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))
return q_embed.to(q.dtype), k_embed.to(k.dtype)
逐行大白话解释(不用懂数学!)
函数 1:precompute_freqs_cis(预计算)
| 代码行 | 大白话作用 |
|---|---|
dim |
模型特征维度(比如 768) |
end |
最大文本长度(比如 32768) |
freqs = 1\.0 / \.\.\. |
算旋转频率(不同维度,旋转速度不同) |
t = torch\.arange\(end\) |
生成位置编号(第 1 个词 = 0,第 2 个 = 1...) |
torch\.outer\(t, freqs\) |
位置 × 频率 = 每个词的旋转角度 |
torch\.cos/sin |
把角度转成 cos/sin 值(模型计算用) |
函数 2:apply_rotary_pos_emb(应用编码)
| 代码行 | 大白话作用 |
|---|---|
rotate\_half |
RoPE 固定操作:特征分两半,后半部分取反 |
q \* cos \+ rotate\_half\(q\) \* sin |
给 Q 矩阵旋转,注入位置信息 |
k \* cos \+ rotate\_half\(k\) \* sin |
给 K 矩阵旋转,注入位置信息 |
第二步:看懂【YaRN 进阶优化】(原代码的复杂部分)
基础 RoPE 只能处理短文本,YaRN 是让模型能看懂超长文本(比如 10 万字) 的优化算法,就是原代码里的:
python
if rope_scaling is not None:
# 一堆YaRN逻辑...
freqs = freqs * (1 - ramp + ramp / factor)
极简理解(不用啃数学公式)
-
问题:基础 RoPE 超过最大长度(比如 2048)就失效;
-
YaRN 解决 :给高频维度(细粒度位置) 缩放频率,低频维度(粗粒度位置) 不缩放;
-
代码逻辑:
-
ramp:生成一个 0~1 的线性斜坡(低维度 = 0 不缩放,高维度 = 1 缩放); -
freqs = freqs \* \.\.\.:按斜坡缩放频率,适配超长文本;
-
-
结论 :这是工程优化,不是核心原理,看懂基础 RoPE 即可!
第三步:跑一个小例子(直观验证)
用最小的维度跑一遍,马上懂效果:
python
# 1. 配置参数
dim = 4 # 特征维度(极小值)
max_len = 5 # 文本长度(5个词)
# 2. 预计算cos/sin
cos, sin = precompute_freqs_cis(dim=dim, end=max_len)
print("cos形状:", cos.shape) # (5,4) → 5个位置,每个位置4维cos值
# 3. 模拟Q/K矩阵 (批次=1,头数=1,长度=5,维度=4)
q = torch.randn(1, 1, 5, 4)
k = torch.randn(1, 1, 5, 4)
# 4. 应用RoPE
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print("旋转后Q形状:", q_rot.shape) # 形状不变,数值被旋转(加了位置信息)
学习源码的终极技巧(针对这类复杂代码)
-
先删复杂分支 :把
if rope\_scaling删掉,只看核心逻辑; -
先懂功能,再懂数学:RoPE 就是「给特征旋转加位置」,数学公式不用深究;
-
分函数理解:
-
一个预计算(算好角度存起来);
-
一个应用(拿角度旋转 Q/K);
-
-
小例子验证:用极小维度跑通,比看 10 遍代码都管用。
总结
-
核心功能 :RoPE 给大模型加位置信息 ,YaRN 优化超长文本;
-
两个函数分工:预计算 cos/sin → 应用到 Q/K 矩阵;
-
学习关键:先学纯净版 RoPE,再看 YaRN 优化,别一上来硬啃全代码!
按照这个路线,你 10 分钟就能彻底吃透这段源码~
(注:文档部分内容可能由 AI 生成)