llama源码学习·model.py[3]ROPE旋转位置编码(3)源码中的广播机制

一.源码注释

python 复制代码
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    '''
       这个函数的目的是为了确保freqs_cis可以根据广播规则与x进行元素级别的运算,特别是在x的维度数量大于2时。
       '''
    # 获取x的维度数量
    ndim = x.ndim
    
    # 确保x至少有两个维度
    assert ndim > 1
    
    # freqs_cis的形状与x的第二和最后一个维度相匹配
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    
    # 遍历x的每个维度,并为第二和最后一个维度保留其原始大小,而为所有其他维度赋值1。
    # 这是为了确保广播时,除了这两个特定维度外,其他所有维度都能自动扩展。
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    
    # 使用view函数来重塑freqs_cis的形状以匹配新的形状
    return freqs_cis.view(*shape)

二、举例说明

python 复制代码
freqs_cis = torch.randn(3,4)
print(freqs_cis.shape)

out: torch.Size([3, 4])

python 复制代码
x = torch.randn(2, 3, 4)
print(x.shape)

out: torch.Size([2, 3, 4])

python 复制代码
# 调用广播函数
reshaped_freqs_cis = reshape_for_broadcast(freqs_cis, x)
print(reshaped_freqs_cis.shape)

out: torch.Size([1, 3, 4])

python 复制代码
# 求和
s = reshaped_freqs_cis + x
print(s.shape)

out: torch.Size([2, 3, 4])

相关推荐
辣大辣条8 小时前
LLAMA-Factory Qwen3-1.7b模型微调
llama
我狸才不是赔钱货13 小时前
AI大模型“战国策”:主流LLM平台简单介绍
c++·人工智能·程序人生·github·llama
临街的小孩3 天前
Docker 容器访问宿主机 Ollama 服务配置教程
llama·argflow
鸿蒙小白龙3 天前
OpenHarmony平台大语言模型本地推理:llama深度适配与部署技术详解
人工智能·语言模型·harmonyos·鸿蒙·鸿蒙系统·llama·open harmony
AI大模型5 天前
轻松搞定百个大模型微调!LLaMA-Factory:你的AI模型量产神器
程序员·llm·llama
fly五行9 天前
大模型基础入门与 RAG 实战:从理论到 llama-index 项目搭建(有具体代码示例)
python·ai·llama·llamaindex
德育处主任Pro13 天前
前端玩转大模型,DeepSeek-R1 蒸馏 Llama 模型的 Bedrock 部署
前端·llama
relis14 天前
AVX-512深度实现分析:从原理到LLaMA.cpp的性能优化艺术
性能优化·llama
relis15 天前
llama.cpp RMSNorm CUDA 优化分析报告
算法·llama
云雾J视界15 天前
开源革命下的研发突围:Meta Llama系列模型的知识整合实践与启示
meta·开源·llama·知识管理·知识整合·知识迭代·知识共享