llama源码学习·model.py[3]ROPE旋转位置编码(3)源码中的广播机制

一.源码注释

python 复制代码
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    '''
       这个函数的目的是为了确保freqs_cis可以根据广播规则与x进行元素级别的运算,特别是在x的维度数量大于2时。
       '''
    # 获取x的维度数量
    ndim = x.ndim
    
    # 确保x至少有两个维度
    assert ndim > 1
    
    # freqs_cis的形状与x的第二和最后一个维度相匹配
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    
    # 遍历x的每个维度,并为第二和最后一个维度保留其原始大小,而为所有其他维度赋值1。
    # 这是为了确保广播时,除了这两个特定维度外,其他所有维度都能自动扩展。
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    
    # 使用view函数来重塑freqs_cis的形状以匹配新的形状
    return freqs_cis.view(*shape)

二、举例说明

python 复制代码
freqs_cis = torch.randn(3,4)
print(freqs_cis.shape)

out: torch.Size([3, 4])

python 复制代码
x = torch.randn(2, 3, 4)
print(x.shape)

out: torch.Size([2, 3, 4])

python 复制代码
# 调用广播函数
reshaped_freqs_cis = reshape_for_broadcast(freqs_cis, x)
print(reshaped_freqs_cis.shape)

out: torch.Size([1, 3, 4])

python 复制代码
# 求和
s = reshaped_freqs_cis + x
print(s.shape)

out: torch.Size([2, 3, 4])

相关推荐
风筝超冷1 天前
LLaMA-Factory - 批量推理(inference)的脚本
llama
bluebonnet273 天前
【agent开发】部署LLM(一)
python·llama
阿牛大牛中3 天前
LLaDa——基于 Diffusion 的大语言模型 打平 LLama 3
人工智能·语言模型·llama
Lilith的AI学习日记4 天前
【AI面试秘籍】| 第25期:RAG的关键痛点及解决方案深度解析
人工智能·深度学习·机器学习·chatgpt·aigc·llama
LChuck6 天前
【大模型微调】魔搭社区GPU进行LLaMA-Factory微调大模型自我认知
人工智能·语言模型·自然语言处理·nlp·llama·魔搭社区·modelscope
燕双嘤6 天前
Fine-tuning:微调技术,训练方式,LLaMA-Factory,ms-swift
llama
装不满的克莱因瓶9 天前
【小白AI教程】大模型知识扫盲通识
人工智能·数学建模·ai·大模型·llm·llama·rag
TGITCIC11 天前
英伟达破局1000 Token/秒!Llama 4以光速重塑AI推理边界
人工智能·大模型·llama·英伟达·大模型速度·ai赛道·大模型基座
天天爱吃肉821812 天前
【 大模型技术驱动智能网联汽车革命:关键技术解析与未来趋势】
语言模型·汽车·llama
Lilith的AI学习日记15 天前
【AI面试秘籍】| 第17期:MoE并行策略面试全攻略:从理论到调参的降维打击指南
人工智能·python·面试·职场和发展·llama