【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE

在自然语言处理(NLP)领域,Transformer 模型已经成为主流。然而,Transformer 本身并不具备处理序列顺序的能力。为了让模型理解文本中词语的相对位置,我们需要引入位置编码(Positional Encoding)。本文将深入探讨 LLaMA 模型中使用的 Rotary Embedding(旋转式嵌入)位置编码方法,并对比传统的 Transformer 位置编码方案,分析其设计与实现的优势。

1. 传统 Transformer 的位置编码

1.1 正弦余弦编码

在原始的 Transformer 模型中,使用了基于正弦和余弦函数的位置编码。这种编码方式的公式如下:

复制代码
PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

其中:

  • pos 代表词语在序列中的位置。
  • i 代表编码向量的维度索引。
  • d_model 是模型的维度大小。

这种编码方式的主要特点是:

  • 绝对位置编码: 为每个位置生成唯一的向量。
  • 易于泛化到更长的序列: 可以外推到训练期间未见过的序列长度。
  • 维度变化: 编码向量的每个维度上的频率都不同。

1.2 代码示例 (PyTorch)

python 复制代码
import torch
import math

def positional_encoding(pos, d_model):
    pe = torch.zeros(1, d_model)
    for i in range(0, d_model, 2):
        pe[0, i] = math.sin(pos / (10000 ** (i / d_model)))
        pe[0, i + 1] = math.cos(pos / (10000 ** (i / d_model)))
    return pe

# 示例
d_model = 512
max_len = 10
pos_encodings = torch.stack([positional_encoding(i, d_model) for i in range(max_len)])

print("Position Encodings Shape:", pos_encodings.shape) # 输出: torch.Size([10, 1, 512])
print("First 3 position encodings:\n", pos_encodings[:3])

1.3 缺点

传统的正弦余弦位置编码虽然有效,但也有其局限性:

  • 缺乏相对位置信息: 尽管编码能提供绝对位置,但难以直接捕捉词语之间的相对距离关系。
  • 位置编码与输入向量独立: 位置编码是直接加到输入词向量上的,没有与词向量进行交互,信息损失比较明显。

2. LLaMA 的 Rotary Embedding (RoPE)

LLaMA 模型采用了 Rotary Embedding(RoPE),一种相对位置编码方法,它通过旋转的方式将位置信息嵌入到词向量中。RoPE 的核心思想是将位置信息编码为旋转矩阵,然后将词向量进行旋转,从而引入位置信息。

2.1 RoPE 的核心公式

RoPE 的核心公式如下:

复制代码
RoPE(q, k, pos) = rotate(q, pos, Θ)

其中:

  • qk 分别代表查询向量和键向量。
  • pos 是两个向量之间的相对位置。
  • Θ 是一个旋转矩阵,根据 pos 和预定义的频率生成。
  • rotate(q, pos, Θ) 表示将 q 旋转 Θ 角度后的结果。

更具体来说,对于维度为 d 的向量 q,RoPE 将其分为 d/2 对 (q0, q1), (q2, q3) ..., (qd-2, qd-1)。每个维度对应用不同的旋转角度。旋转矩阵 R 的定义是:

复制代码
R(pos) =  [[cos(pos * θ_0), -sin(pos * θ_0)],
          [sin(pos * θ_0),  cos(pos * θ_0)]]  
          [[cos(pos * θ_1), -sin(pos * θ_1)],
          [sin(pos * θ_1),  cos(pos * θ_1)]]
          ...
          [[cos(pos * θ_d/2-1), -sin(pos * θ_d/2-1)],
          [sin(pos * θ_d/2-1),  cos(pos * θ_d/2-1)]]

其中 θ_i = 10000^(-2i/d) ,每个维度对的旋转角度不同。

将旋转矩阵应用于向量 q ,就是:
q_rotated = R(pos) * q

2.2 LLaMA 源码实现

下面是 LLaMA 中 RoPE 的核心代码(简化版,使用 PyTorch):

python 复制代码
import torch
import math

def precompute_freqs(dim, end, theta=10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(end)
    freqs = torch.outer(t, freqs)
    return torch.cat((freqs, freqs), dim=1)
    
def apply_rotary_emb(xq, xk, freqs):
    xq_complex = torch.complex(xq.float(), torch.roll(xq.float(), shifts=-xq.shape[-1]//2, dims=-1))
    xk_complex = torch.complex(xk.float(), torch.roll(xk.float(), shifts=-xk.shape[-1]//2, dims=-1))
    
    freqs_complex = torch.complex(torch.cos(freqs), torch.sin(freqs))
    
    xq_rotated = xq_complex * freqs_complex
    xk_rotated = xk_complex * freqs_complex

    return xq_rotated.real.type_as(xq), xk_rotated.real.type_as(xk)

# 示例
batch_size = 2
seq_len = 5
d_model = 512
head_dim = d_model//8
xq = torch.randn(batch_size, seq_len, 8, head_dim) # 输入查询向量
xk = torch.randn(batch_size, seq_len, 8, head_dim) # 输入键向量

freqs = precompute_freqs(head_dim, seq_len)
xq_rotated, xk_rotated  = apply_rotary_emb(xq, xk, freqs)
print("Rotated Query Shape:", xq_rotated.shape)
print("Rotated Key Shape:", xk_rotated.shape)

代码解释

  1. precompute_freqs(dim, end, theta) :
    • 此函数用于预计算旋转矩阵中使用的频率。
    • dim: 表示词向量维度。
    • end: 表示最大序列长度。
    • 返回包含所有位置的频率列表。
  2. apply_rotary_emb(xq, xk, freqs) :
    • 函数将旋转操作应用于查询向量 xq 和键向量 xk
    • 通过 complex 表示实数向量的旋转,并使用复数乘法完成旋转操作。
    • 使用 torch.roll() 函数将 xq 分成实部和虚部,使用complex类型可以更快的完成旋转计算,避免了循环遍历,提高计算速度。
    • 使用复数乘法完成旋转,通过 .real 属性取出旋转后的实部,并将类型转换回原始类型

2.3 RoPE 的优势

与传统的正弦余弦位置编码相比,RoPE 具有以下优势:

  1. 相对位置编码: RoPE 专注于编码词语之间的相对位置信息,而不仅仅是绝对位置。通过向量旋转,使得向量之间的相对位置信息更直观。
  2. 高效计算: 通过使用复数乘法,RoPE 可以在GPU上进行高效的并行计算。
  3. 良好的外推能力: RoPE 可以比较容易地推广到训练期间未见过的序列长度,并且性能保持稳定。
  4. 可解释性: RoPE 的旋转操作使其相对位置信息具有更强的可解释性,有助于理解模型的行为。

3. 总结

本文详细介绍了 LLaMA 模型中使用的 Rotary Embedding 位置编码方法。通过源码分析和对比传统的位置编码,我们了解了 RoPE 的核心原理和优势。RoPE 通过旋转操作高效地编码相对位置信息,为 LLaMA 模型的强大性能提供了重要的基础。希望本文能帮助你更深入地理解 Transformer 模型中的位置编码机制。

4. 参考资料

相关推荐
珊珊而川33 分钟前
3.1监督微调
人工智能
我是小伍同学37 分钟前
基于卷积神经网络和Pyqt5的猫狗识别小程序
人工智能·python·神经网络·qt·小程序·cnn
界面开发小八哥3 小时前
界面控件DevExpress WinForms v25.1新功能预览 - 功能区组件全新升级
人工智能·.net·界面控件·winform·devexpress
zhz52143 小时前
开源数字人框架 AWESOME-DIGITAL-HUMAN 技术解析与应用指南
人工智能·ai·机器人·开源·ai编程·ai数字人·智能体
1296004523 小时前
pytorch基础的学习
人工智能·pytorch·学习
沉默媛4 小时前
RuntimeError: expected scalar type ComplexDouble but found Float
人工智能·pytorch·深度学习
契合qht53_shine4 小时前
NLP基础
人工智能·自然语言处理
闭月之泪舞4 小时前
YOLO目标检测算法
人工智能·yolo·目标检测
埃菲尔铁塔_CV算法4 小时前
POSE识别 神经网络
人工智能·深度学习·神经网络
大G哥4 小时前
加速LLM大模型推理,KV缓存技术详解与PyTorch实现
人工智能·pytorch·python·深度学习·缓存