从基础 RoPE 到 YaRN:源码学习路线揭秘

豆包生成。

从基础 RoPE 到 YaRN:源码学习路线揭秘

我给你一套零基础、由浅入深的学习路线先拆核心功能,再屏蔽复杂代码,最后啃进阶优化,保证你能彻底学会这段源码!

先记住核心结论(学习前必背)

这两个函数是大模型的旋转位置编码(RoPE) + 超长文本优化(YaRN)

  1. 作用 :给大模型的文字加「位置信息」,让模型知道词的先后顺序(比如 "我吃饭" 和 "吃我饭" 不一样);

  2. 分工

    • precompute\_freqs\_cis提前算好所有位置的旋转角度(cos/sin 值)(预计算,提速);

    • apply\_rotary\_pos\_emb拿着算好的角度,给 Q/K 矩阵加位置信息(核心应用);

  3. 学习顺序先学基础 RoPE → 再学 YaRN 进阶优化(直接看全文会懵!)


最优学习步骤(3 步走,无压力)

第一步:屏蔽复杂代码,只看【基础 RoPE】

if rope\_scaling is not None: 这段YaRN 超长文本代码全部删掉,只看最纯净的基础版,这是 90% 的核心逻辑!

纯净版基础代码(先学这个)

python 复制代码
import torch
import math

# 1. 预计算cos/sin(核心:算位置旋转角度)
def precompute_freqs_cis(dim: int, end: int = 32768, rope_base: float = 1e6):
    # 🔥 第一步:生成基础频率(数学公式,不用背,知道是算角度就行)
    freqs = 1.0 / (rope_base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # 🔥 第二步:生成位置序列 [0,1,2,...,end-1]
    t = torch.arange(end, device=freqs.device)
    # 🔥 第三步:外积 → 位置 × 频率 = 每个位置的旋转角度
    freqs = torch.outer(t, freqs).float()
    # 🔥 第四步:算cos/sin,拼接维度(适配模型特征维度)
    freqs_cos = torch.cat([torch.cos(freqs), torch.cos(freqs)], dim=-1)
    freqs_sin = torch.cat([torch.sin(freqs), torch.sin(freqs)], dim=-1)
    return freqs_cos, freqs_sin

# 2. 应用位置编码(核心:给Q/K旋转,加位置信息)
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # 🔥 旋转函数:把特征分成两半,后半部分取反(RoPE固定操作)
    def rotate_half(x): 
        return torch.cat((-x[..., x.shape[-1] // 2:], x[..., : x.shape[-1] // 2]), dim=-1)
    
    # 🔥 核心公式:给Q/K做旋转编码(RoPE标准计算)
    q_embed = (q * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(q) * sin.unsqueeze(unsqueeze_dim))
    k_embed = (k * cos.unsqueeze(unsqueeze_dim)) + (rotate_half(k) * sin.unsqueeze(unsqueeze_dim))
    return q_embed.to(q.dtype), k_embed.to(k.dtype)

逐行大白话解释(不用懂数学!)

函数 1:precompute_freqs_cis(预计算)
代码行 大白话作用
dim 模型特征维度(比如 768)
end 最大文本长度(比如 32768)
freqs = 1\.0 / \.\.\. 旋转频率(不同维度,旋转速度不同)
t = torch\.arange\(end\) 生成位置编号(第 1 个词 = 0,第 2 个 = 1...)
torch\.outer\(t, freqs\) 位置 × 频率 = 每个词的旋转角度
torch\.cos/sin 把角度转成 cos/sin 值(模型计算用)
函数 2:apply_rotary_pos_emb(应用编码)
代码行 大白话作用
rotate\_half RoPE 固定操作:特征分两半,后半部分取反
q \* cos \+ rotate\_half\(q\) \* sin 给 Q 矩阵旋转,注入位置信息
k \* cos \+ rotate\_half\(k\) \* sin 给 K 矩阵旋转,注入位置信息

第二步:看懂【YaRN 进阶优化】(原代码的复杂部分)

基础 RoPE 只能处理短文本,YaRN 是让模型能看懂超长文本(比如 10 万字) 的优化算法,就是原代码里的:

python 复制代码
if rope_scaling is not None:
    # 一堆YaRN逻辑...
    freqs = freqs * (1 - ramp + ramp / factor)

极简理解(不用啃数学公式)

  1. 问题:基础 RoPE 超过最大长度(比如 2048)就失效;

  2. YaRN 解决 :给高频维度(细粒度位置) 缩放频率,低频维度(粗粒度位置) 不缩放;

  3. 代码逻辑

    • ramp:生成一个 0~1 的线性斜坡(低维度 = 0 不缩放,高维度 = 1 缩放);

    • freqs = freqs \* \.\.\.:按斜坡缩放频率,适配超长文本;

  4. 结论 :这是工程优化,不是核心原理,看懂基础 RoPE 即可!


第三步:跑一个小例子(直观验证)

用最小的维度跑一遍,马上懂效果:

python 复制代码
# 1. 配置参数
dim = 4    # 特征维度(极小值)
max_len = 5 # 文本长度(5个词)

# 2. 预计算cos/sin
cos, sin = precompute_freqs_cis(dim=dim, end=max_len)
print("cos形状:", cos.shape) # (5,4) → 5个位置,每个位置4维cos值

# 3. 模拟Q/K矩阵 (批次=1,头数=1,长度=5,维度=4)
q = torch.randn(1, 1, 5, 4)
k = torch.randn(1, 1, 5, 4)

# 4. 应用RoPE
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print("旋转后Q形状:", q_rot.shape) # 形状不变,数值被旋转(加了位置信息)

学习源码的终极技巧(针对这类复杂代码)

  1. 先删复杂分支 :把 if rope\_scaling 删掉,只看核心逻辑;

  2. 先懂功能,再懂数学:RoPE 就是「给特征旋转加位置」,数学公式不用深究;

  3. 分函数理解

    • 一个预计算(算好角度存起来);

    • 一个应用(拿角度旋转 Q/K);

  4. 小例子验证:用极小维度跑通,比看 10 遍代码都管用。


总结

  1. 核心功能 :RoPE 给大模型加位置信息 ,YaRN 优化超长文本

  2. 两个函数分工:预计算 cos/sin → 应用到 Q/K 矩阵;

  3. 学习关键:先学纯净版 RoPE,再看 YaRN 优化,别一上来硬啃全代码!

按照这个路线,你 10 分钟就能彻底吃透这段源码~

(注:文档部分内容可能由 AI 生成)

相关推荐
Gigavision18 小时前
rPPGMamba:面向 PURE-UBFC-MMPD 跨被试远程生理感知的 Mamba 时序建模方案
python·深度学习·rppg
星夜夏空9919 小时前
STM32单片机学习(15) —— PC串口通信实验
stm32·单片机·学习
网络工程小王19 小时前
【大模型vLLM 使用】学习笔记
笔记·学习·llama
初心未改HD19 小时前
深度学习之优化器详解
人工智能·深度学习
星夜夏空9919 小时前
STM32单片机学习(14) —— STM32的串口外设
stm32·单片机·学习
栉甜19 小时前
APIs学习
前端·javascript·css·学习·html
吃好睡好便好19 小时前
说说梳头的保健作用
学习
love530love19 小时前
ComfyUI:为什么说它是 AIGC 应用层面的集大成者?
人工智能·pytorch·windows·aigc·devops·comfyui·extensions
wuxinyan12319 小时前
工业级大模型学习之路013:RAG零基础入门教程(第九篇):RAG幻觉治理
人工智能·学习·rag
AI技术控20 小时前
Transformer 的 Encoder 和 Decoder 模块介绍:从结构原理到大模型应用实践
人工智能·python·深度学习·自然语言处理·transformer