llama源码学习·model.py[3]ROPE旋转位置编码(3)源码中的广播机制

一.源码注释

python 复制代码
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    '''
       这个函数的目的是为了确保freqs_cis可以根据广播规则与x进行元素级别的运算,特别是在x的维度数量大于2时。
       '''
    # 获取x的维度数量
    ndim = x.ndim
    
    # 确保x至少有两个维度
    assert ndim > 1
    
    # freqs_cis的形状与x的第二和最后一个维度相匹配
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    
    # 遍历x的每个维度,并为第二和最后一个维度保留其原始大小,而为所有其他维度赋值1。
    # 这是为了确保广播时,除了这两个特定维度外,其他所有维度都能自动扩展。
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    
    # 使用view函数来重塑freqs_cis的形状以匹配新的形状
    return freqs_cis.view(*shape)

二、举例说明

python 复制代码
freqs_cis = torch.randn(3,4)
print(freqs_cis.shape)

out: torch.Size([3, 4])

python 复制代码
x = torch.randn(2, 3, 4)
print(x.shape)

out: torch.Size([2, 3, 4])

python 复制代码
# 调用广播函数
reshaped_freqs_cis = reshape_for_broadcast(freqs_cis, x)
print(reshaped_freqs_cis.shape)

out: torch.Size([1, 3, 4])

python 复制代码
# 求和
s = reshaped_freqs_cis + x
print(s.shape)

out: torch.Size([2, 3, 4])

相关推荐
陈奕昆1 小时前
4.2【LLaMA-Factory实战】金融财报分析系统:从数据到部署的全流程实践
人工智能·金融·llama·大模型微调
陈奕昆4 小时前
4.3【LLaMA-Factory实战】教育大模型:个性化学习路径生成系统全解析
人工智能·python·学习·llama·大模型微调
zuozewei8 小时前
7D-AI系列:模型微调之llama-factory
人工智能·llama
weixin_444579308 小时前
基于Llama3的开发应用(一):Llama模型的简单部署
人工智能·深度学习·llama
o0o_-_1 天前
【瞎折腾/mi50 32G/ubuntu】mi50显卡ubuntu运行大模型开坑(二)使用llama.cpp部署Qwen3系列
linux·ubuntu·llama
DisonTangor1 天前
LLaMA-Omni 2:基于 LLM 的自回归流语音合成实时口语聊天机器人
人工智能·开源·aigc·音视频·llama
陈奕昆1 天前
二、【LLaMA-Factory实战】数据工程全流程:从格式规范到高质量数据集构建
前端·人工智能·python·llama·大模型微调
陈奕昆6 天前
【LLaMA-Factory实战】Web UI快速上手:可视化大模型微调全流程
前端·ui·llama·大模型微调实战
OJAC近屿智能6 天前
宇树科技开启“人形机器人格斗盛宴”
人工智能·科技·ui·机器人·aigc·llama·近屿智能
mingo_敏8 天前
Windows系统编译支持GPU的llama.cpp
windows·llama