llama源码学习·model.py[3]ROPE旋转位置编码(3)源码中的广播机制

一.源码注释

python 复制代码
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    '''
       这个函数的目的是为了确保freqs_cis可以根据广播规则与x进行元素级别的运算,特别是在x的维度数量大于2时。
       '''
    # 获取x的维度数量
    ndim = x.ndim
    
    # 确保x至少有两个维度
    assert ndim > 1
    
    # freqs_cis的形状与x的第二和最后一个维度相匹配
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    
    # 遍历x的每个维度,并为第二和最后一个维度保留其原始大小,而为所有其他维度赋值1。
    # 这是为了确保广播时,除了这两个特定维度外,其他所有维度都能自动扩展。
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    
    # 使用view函数来重塑freqs_cis的形状以匹配新的形状
    return freqs_cis.view(*shape)

二、举例说明

python 复制代码
freqs_cis = torch.randn(3,4)
print(freqs_cis.shape)

out: torch.Size([3, 4])

python 复制代码
x = torch.randn(2, 3, 4)
print(x.shape)

out: torch.Size([2, 3, 4])

python 复制代码
# 调用广播函数
reshaped_freqs_cis = reshape_for_broadcast(freqs_cis, x)
print(reshaped_freqs_cis.shape)

out: torch.Size([1, 3, 4])

python 复制代码
# 求和
s = reshaped_freqs_cis + x
print(s.shape)

out: torch.Size([2, 3, 4])

相关推荐
Mr_sst15 小时前
infra-ai模块宏观设计解析:业务与模型之间的中间层核心架构
大数据·人工智能·ai·llama
微软技术分享21 小时前
Windows平台下CUDA安装及llama.cpp使用教程
windows·llama
小wu学cv2 天前
llama.cpp调用GPU推理Qwen3.5-0.8b模型
llama
zhangfeng11332 天前
LLaMA-Factory 保存 checkpoint 时崩溃解决办法 OOM 内存溢出(不是显存)
运维·服务器·人工智能·深度学习·llama
老唐7773 天前
30分钟手搓 Agent:LLM + Tools + Loop + Memory 跑通最小闭环
人工智能·ai·语言模型·agent·llama·智能体
高兴就好(石3 天前
Mac使用llama.cpp
macos·llama
zhangfeng11334 天前
No space left on device (28) llamafactory微调训练的时候 报错,需要调节 dataloader_num_workers
人工智能·语言模型·llama
阿珊和她的猫4 天前
大模型在客服场景:落地路径 + 效果评估
ai·agent·llama·cli·mcp
谷子熟了5 天前
电商智能客服系统本地搭建
经验分享·docker·typescript·ai编程·llama
YXHPY5 天前
开源 AI 工作流底座正在加速:从 llama.cpp、Ollama 到 vLLM 与 Agent 编排
人工智能·开源·llama