llama源码学习·model.py[3]ROPE旋转位置编码(2)旋转角度生成代码

一、源码注释

python 复制代码
def precompute_freqs_cis(dim: int, end: int, theta: float = 1000.0):
    '''预先计算频率和复数的cosine和sine值,用于后续的Positional Encoding
    dim: 维度
    end: 一个序列的最大长度或位置的最大值
    theta: 用于计算频率的超参数,默认值为1000.0
    '''
    # 生成一个等比数列,即频率(frequencies),这种方法是基于 "Attention is All You Need" 论文中描述的位置编码
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    
    # 生成了一个从0到end的序列
    t = torch.arange(end, device=freqs.device)
    
    # 计算两个向量的外积
    # 结果矩阵的形状是(end, dim//2)
    # 这里的freqs 其实是旋转角度 theta
    freqs = torch.outer(t, freqs).float()
    
    # 将极坐标转换为复数形式
    # torch.polar(r, theta): 是一个函数,它接受两个参数:模 r 和相位 theta,然后返回一个复数,
    #                       该复数的实部为 r * cos(theta),虚部为 r * sin(theta)。
    # torch.ones_like(freqs): 生成一个与 freqs 形状相同的张量,但所有元素都是1,这意味着模r为1。
    # freqs: 它表示每个位置的相位或角度。
    # freqs_cis: 是一个形状为(end, dim//2)的复数矩阵,每个元素都是一个复数,用于后续的位置编码。
    
    # 这行代码实际上为每个位置和每个频率生成了一个复数,其模为1,而相位为我们之前计算的频率。
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis

二、源码与公式的对应

第一步:旋转嵌入生成

需要给定一个位置索引 p o s pos pos 和频率向量 f r e q freq freq, 来计算旋转角度 θ = p o s × f r e q \theta = pos \times freq θ=pos×freq

python 复制代码
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))

生成的这个等比数列就是频率向量,这是基于 "Attention is All You Need" 论文中描述的位置编码来实现的

python 复制代码
 t = torch.arange(end, device=freqs.device)

这个长度为 e n d end end 的数列是位置索引 p o s pos pos

python 复制代码
freqs = torch.outer(t, freqs).float()

这一行是在计算两个位置索引 p o s pos pos 和 频率向量 f r e q freq freq 的外积生成旋转角度 θ \theta θ ,不过旋转角度的信息在代码中依旧存储在 f r e q s freqs freqs 这个变量中

python 复制代码
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

freqs 是旋转角度向量,freqs_cis 使用复数表示的旋转矩阵

三、举例说明

1.假设函数参数

dim = 4 位置编码的维度是 4

end = 3 序列的最大长度是 3

2.生成频率向量 freq

python\ 复制代码
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) 

假设 x = torch.arange(0, dim, 2) 也就是从 0 ~ dim 步长为二的等比数列

f r e q s = 1 θ x d i m / / 2 = 1 θ [ 0 , 2 ] 4 / / 2 = 1 θ [ 0 , 0.5 ] = 1 [ 1 , θ ] = [ 1 , 1 θ ] freqs = \frac{1}{\theta^{\frac{x}{dim // 2}}} = \frac{1}{\theta^{\frac{[0, 2]}{4 // 2}}} = \frac{1}{\theta^{[0, 0.5]}} = \frac{1}{[1, \sqrt{\theta}]} = [1, \frac{1}{\sqrt{\theta}}] freqs=θdim//2x1=θ4//2[0,2]1=θ[0,0.5]1=[1,θ ]1=[1,θ 1]

3.生成从 0 到 end 的位置索引

python 复制代码
t = torch.arange(end, device=freqs.device) 

t = [ 0 , 1 , 2 ] t = [0, 1, 2] t=[0,1,2]

4.计算两个向量的外积得到旋转角度 theta

python 复制代码
freqs = torch.outer(t, freqs).float()

5.将极坐标转换为复数形式

python 复制代码
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

这里返回的 freqs_cis 是一个用复数表示的旋转矩阵

相关推荐
fly五行4 天前
大模型基础入门与 RAG 实战:从理论到 llama-index 项目搭建(有具体代码示例)
python·ai·llama·llamaindex
德育处主任Pro7 天前
前端玩转大模型,DeepSeek-R1 蒸馏 Llama 模型的 Bedrock 部署
前端·llama
relis8 天前
AVX-512深度实现分析:从原理到LLaMA.cpp的性能优化艺术
性能优化·llama
relis10 天前
llama.cpp RMSNorm CUDA 优化分析报告
算法·llama
云雾J视界10 天前
开源革命下的研发突围:Meta Llama系列模型的知识整合实践与启示
meta·开源·llama·知识管理·知识整合·知识迭代·知识共享
丁学文武11 天前
大模型原理与实践:第三章-预训练语言模型详解_第3部分-Decoder-Only(GPT、LLama、GLM)
人工智能·gpt·语言模型·自然语言处理·大模型·llama·glm
余衫马11 天前
llama.cpp:本地大模型推理的高性能 C++ 框架
c++·人工智能·llm·llama·大模型部署
LETTER•15 天前
Llama 模型架构解析:从 Pre-RMSNorm 到 GQA 的技术演进
深度学习·语言模型·自然语言处理·llama
拓端研究室15 天前
JupyterLab+PyTorch:LoRA+4-bit量化+SFT微调Llama 4医疗推理应用|附代码数据
llama
之歆17 天前
LangGraph构建多智能体
人工智能·python·llama