从基础 RoPE 到 YaRN:源码学习路线揭秘

豆包生成。

从基础 RoPE 到 YaRN:源码学习路线揭秘

我给你一套零基础、由浅入深的学习路线先拆核心功能,再屏蔽复杂代码,最后啃进阶优化,保证你能彻底学会这段源码!

先记住核心结论(学习前必背)

这两个函数是大模型的旋转位置编码(RoPE) + 超长文本优化(YaRN)

  1. 作用 :给大模型的文字加「位置信息」,让模型知道词的先后顺序(比如 "我吃饭" 和 "吃我饭" 不一样);

  2. 分工

    • precompute\_freqs\_cis提前算好所有位置的旋转角度(cos/sin 值)(预计算,提速);

    • apply\_rotary\_pos\_emb拿着算好的角度,给 Q/K 矩阵加位置信息(核心应用);

  3. 学习顺序先学基础 RoPE → 再学 YaRN 进阶优化(直接看全文会懵!)


最优学习步骤(3 步走,无压力)

第一步:屏蔽复杂代码,只看【基础 RoPE】

if rope\_scaling is not None: 这段YaRN 超长文本代码全部删掉,只看最纯净的基础版,这是 90% 的核心逻辑!

纯净版基础代码(先学这个)

python 复制代码
import torch
import math

# 1. 预计算cos/sin(核心:算位置旋转角度)
def precompute_freqs_cis(dim: int, end: int = 32768, rope_base: float = 1e6):
    # 🔥 第一步:生成基础频率(数学公式,不用背,知道是算角度就行)
    freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # 🔥 第二步:生成位置序列 [0,1,2,...,end-1]
    t = torch.arange(end, device=freqs.device)
    # 🔥 第三步:外积 → 位置 × 频率 = 每个位置的旋转角度
    freqs = torch.outer(t, freqs).float()
    # 🔥 第四步:算cos/sin,拼接维度(适配模型特征维度)
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
    return freqs_cos, freqs_sin

# 2. 应用位置编码(核心:给Q/K旋转,加位置信息)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # 🔥 旋转函数:把特征分成两半,后半部分取反(RoPE固定操作)
    def rotate_half(x): 
        return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
    
    # 🔥 核心公式:给Q/K做旋转编码(RoPE标准计算)
    q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
    k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))
    return q_embed.to(q.dtype), k_embed.to(k.dtype)

逐行大白话解释(不用懂数学!)

函数 1:precompute_freqs_cis(预计算)
代码行 大白话作用
dim 模型特征维度(比如 768)
end 最大文本长度(比如 32768)
freqs = 1\.0 / \.\.\. 旋转频率(不同维度,旋转速度不同)
t = torch\.arange\(end\) 生成位置编号(第 1 个词 = 0,第 2 个 = 1...)
torch\.outer\(t, freqs\) 位置 × 频率 = 每个词的旋转角度
torch\.cos/sin 把角度转成 cos/sin 值(模型计算用)
函数 2:apply_rotary_pos_emb(应用编码)
代码行 大白话作用
rotate\_half RoPE 固定操作:特征分两半,后半部分取反
q \* cos \+ rotate\_half\(q\) \* sin 给 Q 矩阵旋转,注入位置信息
k \* cos \+ rotate\_half\(k\) \* sin 给 K 矩阵旋转,注入位置信息

第二步:看懂【YaRN 进阶优化】(原代码的复杂部分)

基础 RoPE 只能处理短文本,YaRN 是让模型能看懂超长文本(比如 10 万字) 的优化算法,就是原代码里的:

python 复制代码
if rope_scaling is not None:
    # 一堆YaRN逻辑...
    freqs = freqs * (1 - ramp + ramp / factor)

极简理解(不用啃数学公式)

  1. 问题:基础 RoPE 超过最大长度(比如 2048)就失效;

  2. YaRN 解决 :给高频维度(细粒度位置) 缩放频率,低频维度(粗粒度位置) 不缩放;

  3. 代码逻辑

    • ramp:生成一个 0~1 的线性斜坡(低维度 = 0 不缩放,高维度 = 1 缩放);

    • freqs = freqs \* \.\.\.:按斜坡缩放频率,适配超长文本;

  4. 结论 :这是工程优化,不是核心原理,看懂基础 RoPE 即可!


第三步:跑一个小例子(直观验证)

用最小的维度跑一遍,马上懂效果:

python 复制代码
# 1. 配置参数
dim = 4    # 特征维度(极小值)
max_len = 5 # 文本长度(5个词)

# 2. 预计算cos/sin
cos, sin = precompute_freqs_cis(dim=dim, end=max_len)
print("cos形状:", cos.shape) # (5,4) → 5个位置,每个位置4维cos值

# 3. 模拟Q/K矩阵 (批次=1,头数=1,长度=5,维度=4)
q = torch.randn(1, 1, 5, 4)
k = torch.randn(1, 1, 5, 4)

# 4. 应用RoPE
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print("旋转后Q形状:", q_rot.shape) # 形状不变,数值被旋转(加了位置信息)

学习源码的终极技巧(针对这类复杂代码)

  1. 先删复杂分支 :把 if rope\_scaling 删掉,只看核心逻辑;

  2. 先懂功能,再懂数学:RoPE 就是「给特征旋转加位置」,数学公式不用深究;

  3. 分函数理解

    • 一个预计算(算好角度存起来);

    • 一个应用(拿角度旋转 Q/K);

  4. 小例子验证:用极小维度跑通,比看 10 遍代码都管用。


总结

  1. 核心功能 :RoPE 给大模型加位置信息 ,YaRN 优化超长文本

  2. 两个函数分工:预计算 cos/sin → 应用到 Q/K 矩阵;

  3. 学习关键:先学纯净版 RoPE,再看 YaRN 优化,别一上来硬啃全代码!

按照这个路线,你 10 分钟就能彻底吃透这段源码~

(注:文档部分内容可能由 AI 生成)

相关推荐
xuhaoyu_cpp_java3 小时前
MyBatis学习(五)
经验分享·笔记·学习·mybatis
ECT-OS-JiuHuaShan3 小时前
整体论体系定理,全球开放,无法绕过
人工智能·科技·学习·算法·生活
淘矿人4 小时前
2026年4月-DeepSeek V4 vs GPT-5.5深度对比测评:weelinking一键切换实测
服务器·数据库·人工智能·python·gpt·学习·php
一只机电自动化菜鸟4 小时前
一建机电备考笔记(27)测量技术—仪器(含考频+题型)
经验分享·笔记·学习·职场和发展·求职招聘·课程设计
xiaoxiaoxiaolll4 小时前
《Light: Science & Applications》SSH模型能带首次在光子芯片上直接读出:混合频率架构赋能拓扑量子模拟
学习
Be for thing4 小时前
Android Studio 常用快捷键总结
android·学习
HackTorjan4 小时前
深度解析雪花算法及其高性能优化策略
人工智能·深度学习·算法·性能优化·dreamweaver
茜子.Java4 小时前
postman 进阶使用教程
学习
爱上好庆祝4 小时前
学习js的第四天
前端·css·学习·html·css3·js