【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE

在自然语言处理(NLP)领域,Transformer 模型已经成为主流。然而,Transformer 本身并不具备处理序列顺序的能力。为了让模型理解文本中词语的相对位置,我们需要引入位置编码(Positional Encoding)。本文将深入探讨 LLaMA 模型中使用的 Rotary Embedding(旋转式嵌入)位置编码方法,并对比传统的 Transformer 位置编码方案,分析其设计与实现的优势。

1. 传统 Transformer 的位置编码

1.1 正弦余弦编码

在原始的 Transformer 模型中,使用了基于正弦和余弦函数的位置编码。这种编码方式的公式如下:

复制代码
PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

其中:

  • pos 代表词语在序列中的位置。
  • i 代表编码向量的维度索引。
  • d_model 是模型的维度大小。

这种编码方式的主要特点是:

  • 绝对位置编码: 为每个位置生成唯一的向量。
  • 易于泛化到更长的序列: 可以外推到训练期间未见过的序列长度。
  • 维度变化: 编码向量的每个维度上的频率都不同。

1.2 代码示例 (PyTorch)

python 复制代码
import torch
import math

def positional_encoding(pos, d_model):
    pe = torch.zeros(1, d_model)
    for i in range(0, d_model, 2):
        pe[0, i] = math.sin(pos / (10000 ** (i / d_model)))
        pe[0, i + 1] = math.cos(pos / (10000 ** (i / d_model)))
    return pe

# 示例
d_model = 512
max_len = 10
pos_encodings = torch.stack([positional_encoding(i, d_model) for i in range(max_len)])

print("Position Encodings Shape:", pos_encodings.shape) # 输出: torch.Size([10, 1, 512])
print("First 3 position encodings:\n", pos_encodings[:3])

1.3 缺点

传统的正弦余弦位置编码虽然有效,但也有其局限性:

  • 缺乏相对位置信息: 尽管编码能提供绝对位置,但难以直接捕捉词语之间的相对距离关系。
  • 位置编码与输入向量独立: 位置编码是直接加到输入词向量上的,没有与词向量进行交互,信息损失比较明显。

2. LLaMA 的 Rotary Embedding (RoPE)

LLaMA 模型采用了 Rotary Embedding(RoPE),一种相对位置编码方法,它通过旋转的方式将位置信息嵌入到词向量中。RoPE 的核心思想是将位置信息编码为旋转矩阵,然后将词向量进行旋转,从而引入位置信息。

2.1 RoPE 的核心公式

RoPE 的核心公式如下:

复制代码
RoPE(q, k, pos) = rotate(q, pos, Θ)

其中:

  • qk 分别代表查询向量和键向量。
  • pos 是两个向量之间的相对位置。
  • Θ 是一个旋转矩阵,根据 pos 和预定义的频率生成。
  • rotate(q, pos, Θ) 表示将 q 旋转 Θ 角度后的结果。

更具体来说,对于维度为 d 的向量 q,RoPE 将其分为 d/2 对 (q0, q1), (q2, q3) ..., (qd-2, qd-1)。每个维度对应用不同的旋转角度。旋转矩阵 R 的定义是:

复制代码
R(pos) =  [[cos(pos * θ_0), -sin(pos * θ_0)],
          [sin(pos * θ_0),  cos(pos * θ_0)]]  
          [[cos(pos * θ_1), -sin(pos * θ_1)],
          [sin(pos * θ_1),  cos(pos * θ_1)]]
          ...
          [[cos(pos * θ_d/2-1), -sin(pos * θ_d/2-1)],
          [sin(pos * θ_d/2-1),  cos(pos * θ_d/2-1)]]

其中 θ_i = 10000^(-2i/d) ,每个维度对的旋转角度不同。

将旋转矩阵应用于向量 q ,就是:
q_rotated = R(pos) * q

2.2 LLaMA 源码实现

下面是 LLaMA 中 RoPE 的核心代码(简化版,使用 PyTorch):

python 复制代码
import torch
import math

def precompute_freqs(dim, end, theta=10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end)
    freqs = torch.outer(t, freqs)
    return torch.cat((freqs, freqs), dim=1)
    
def apply_rotary_emb(xq, xk, freqs):
    xq_complex = torch.complex(xq.float(), torch.roll(xq.float(), shifts=-xq.shape[-1]//2, dims=-1))
    xk_complex = torch.complex(xk.float(), torch.roll(xk.float(), shifts=-xk.shape[-1]//2, dims=-1))
    
    freqs_complex = torch.complex(torch.cos(freqs), torch.sin(freqs))
    
    xq_rotated = xq_complex * freqs_complex
    xk_rotated = xk_complex * freqs_complex

    return xq_rotated.real.type_as(xq), xk_rotated.real.type_as(xk)

# 示例
batch_size = 2
seq_len = 5
d_model = 512
head_dim = d_model//8
xq = torch.randn(batch_size, seq_len, 8, head_dim) # 输入查询向量
xk = torch.randn(batch_size, seq_len, 8, head_dim) # 输入键向量

freqs = precompute_freqs(head_dim, seq_len)
xq_rotated, xk_rotated  = apply_rotary_emb(xq, xk, freqs)
print("Rotated Query Shape:", xq_rotated.shape)
print("Rotated Key Shape:", xk_rotated.shape)

代码解释

  1. precompute_freqs(dim, end, theta) :
    • 此函数用于预计算旋转矩阵中使用的频率。
    • dim: 表示词向量维度。
    • end: 表示最大序列长度。
    • 返回包含所有位置的频率列表。
  2. apply_rotary_emb(xq, xk, freqs) :
    • 函数将旋转操作应用于查询向量 xq 和键向量 xk
    • 通过 complex 表示实数向量的旋转,并使用复数乘法完成旋转操作。
    • 使用 torch.roll() 函数将 xq 分成实部和虚部,使用complex类型可以更快的完成旋转计算,避免了循环遍历,提高计算速度。
    • 使用复数乘法完成旋转,通过 .real 属性取出旋转后的实部,并将类型转换回原始类型

2.3 RoPE 的优势

与传统的正弦余弦位置编码相比,RoPE 具有以下优势:

  1. 相对位置编码: RoPE 专注于编码词语之间的相对位置信息,而不仅仅是绝对位置。通过向量旋转,使得向量之间的相对位置信息更直观。
  2. 高效计算: 通过使用复数乘法,RoPE 可以在GPU上进行高效的并行计算。
  3. 良好的外推能力: RoPE 可以比较容易地推广到训练期间未见过的序列长度,并且性能保持稳定。
  4. 可解释性: RoPE 的旋转操作使其相对位置信息具有更强的可解释性,有助于理解模型的行为。

3. 总结

本文详细介绍了 LLaMA 模型中使用的 Rotary Embedding 位置编码方法。通过源码分析和对比传统的位置编码,我们了解了 RoPE 的核心原理和优势。RoPE 通过旋转操作高效地编码相对位置信息,为 LLaMA 模型的强大性能提供了重要的基础。希望本文能帮助你更深入地理解 Transformer 模型中的位置编码机制。

4. 参考资料

相关推荐
陈天伟教授1 小时前
基于学习的人工智能(3)机器学习基本框架
人工智能·学习·机器学习·知识图谱
搞科研的小刘选手2 小时前
【厦门大学主办】第六届计算机科学与管理科技国际学术会议(ICCSMT 2025)
人工智能·科技·计算机网络·计算机·云计算·学术会议
fanstuck2 小时前
深入解析 PyPTO Operator:以 DeepSeek‑V3.2‑Exp 模型为例的实战指南
人工智能·语言模型·aigc·gpu算力
萤丰信息2 小时前
智慧园区能源革命:从“耗电黑洞”到零碳样本的蜕变
java·大数据·人工智能·科技·安全·能源·智慧园区
世洋Blog2 小时前
更好的利用ChatGPT进行项目的开发
人工智能·unity·chatgpt
噜~噜~噜~5 小时前
最大熵原理(Principle of Maximum Entropy,MaxEnt)的个人理解
深度学习·最大熵原理
serve the people6 小时前
机器学习(ML)和人工智能(AI)技术在WAF安防中的应用
人工智能·机器学习
0***K8926 小时前
前端机器学习
人工智能·机器学习
陈天伟教授6 小时前
基于学习的人工智能(5)机器学习基本框架
人工智能·学习·机器学习