大模型-位置编码RoPE的具体实现

python 复制代码
def precompute_pos_cis(dim: int, end: int = int(32*1024), theta: float =1e6):
    """位置编码预处理

    Args:
        dim (int): 输入的维度
        end (int, optional): 最大输出Token数. Defaults to int(32*1024).
        theta (float, optional): 控制频率衰减的参数. Defaults to 1e6.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) # freq=1/e^freq,频率的多尺度和dim有关
    t = torch.arange(end ,device = freqs.device) # (end,) 表示输出序列最大长度
    freqs = torch.outer(t, freqs).float() # (end, dim//2),外积,表示每一个位置带有不同频率的旋转
    pos_cis = torch.polar(torch.ones_like(freqs), freqs) # polar(单位阵, 频率阵), 生成复数矩阵
    return pos_cis

解释:

python 复制代码
def apply_rotary_emb(xq,xk, pos_cis):
    """RoPE位置编码

    Args:
        xq (_type_): Q矩阵 (batchsize, seqlen, heads, dim//2)
        xk (_type_): K矩阵 (batchsize, seqlen, heads, dim//2)
        pos_cis (_type_): 预处理后的旋转频率矩阵 (end, dim//2)
    """
    def unite_shape(pos_cis, x):
        """对预处理后的旋转频率矩阵进行广播

        Args:
            pos_cis (_type_): 旋转频率矩阵
            x (_type_): 输入
        """
        ndim = x.ndim # 维度,4 (batchsize, seqlen, heads, dim//2)
        assert 0 <= 1 < ndim
        assert pos_cis.shape == (x.shape[1], x.shape[-1]) # (x.shape[1], x.shape[-1]) = (seqlen, dim//2)
        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # reshape x = (1, seqlen, 1, dim//2)
        return pos_cis.view(*shape) # reshape pos_cis (1, 1, end. dim//2)
    
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1 ,2))
    pos_cis = unite_shape(*(pos_cis, xq_))
    xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
    
    return xq_out.type_as(xq), xk_out.type_as(xk)      

解释:

相关推荐
RWKV元始智能4 小时前
RWKV超并发项目教程,RWKV-LM训练提速40%
人工智能·rnn·深度学习·自然语言处理·开源
AI技术增长6 小时前
Pytorch图像去噪实战(六):CBDNet真实噪声去噪实战,解决合成噪声模型落地效果差的问题
pytorch·深度学习·机器学习
小糖学代码8 小时前
LLM系列:2.pytorch入门:8.神经网络的损失函数(criterion)
人工智能·深度学习·神经网络
Jmayday8 小时前
Pytorch:RNN理论基础
pytorch·rnn·深度学习
AI周红伟10 小时前
周红伟:GPT-Image-2深度解析:从技术原理到实战教程,为什么它能让整个AI圈炸锅?
人工智能·gpt·深度学习·机器学习·语言模型·openclaw
端平入洛11 小时前
梯度是什么:PyTorch 自动求导详解
人工智能·深度学习
Uopiasd1234oo11 小时前
上下文引导模块改进YOLOv26局部与全局特征融合能力双重提升
深度学习·yolo·机器学习
动物园猫13 小时前
工业织物缺陷目标检测数据集分享(适用于YOLO系列深度学习分类检测任务)
深度学习·yolo·目标检测
ACCELERATOR_LLC13 小时前
【DataWhale组队学习】DIY-LLM Task6 评估与基准测试
人工智能·深度学习·大模型·模型评估
狮子座明仔13 小时前
ThinkTwice: 让模型学会“做完题再检查一遍“,推理+自纠错联合训练只加3%开销
大数据·人工智能·深度学习