pytorch张量运算的广播机制

PyTorch 的广播机制(broadcasting)是指在进行张量运算时,自动扩展较小张量的形状以匹配较大张量的形状,使它们能够进行逐元素运算。广播机制避免了手动扩展张量的繁琐过程,并且在不增加内存开销的情况下进行高效计算。

广播规则

  1. 比较张量的形状:从后向前比较两个张量的每个维度(即从最右边的维度开始)。
  2. 维度匹配
    • 如果两个维度相等,则可以进行运算。
    • 如果一个张量在该维度上为 1,另一个张量为任意数值,则形状为 1 的张量会沿着该维度扩展,以匹配另一个张量的形状。
    • 如果两个张量在某个维度上不相等且没有一个是 1,则无法进行广播,运算会抛出错误。

广播机制的示例代码

1. 标量与张量的运算
import torch

# 标量和张量相加
scalar = torch.tensor(3)
tensor = torch.tensor([1, 2, 3])

result = scalar + tensor
print(result)  # 输出: tensor([4, 5, 6])
2. 不同形状的张量运算
import torch

# 创建两个形状不同的张量
A = torch.tensor([[1, 2, 3], [4, 5, 6]])  # 形状为 [2, 3]
B = torch.tensor([1, 2, 3])              # 形状为 [3]

# B 张量的形状会沿着第一个维度自动扩展为 [2, 3]
result = A + B 
print(result)   # 输出: tensor([[2, 4, 6], [5, 7, 9]])
3. 高维张量的广播
import torch

# 形状为 [2, 1, 3]
C = torch.tensor([[[1, 2, 3]], [[4, 5, 6]]])

# 形状为 [3]
D = torch.tensor([1, 2, 3])

# D 张量的形状会沿着第一个和第二个维度扩展为 [2, 1, 3]
result = C + D  
print(result)   # 输出: tensor([[[ 2,  4,  6]], [[ 5,  7,  9]]])
4. 不兼容的张量运算
import torch

# 形状为 [2, 3]
E = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 形状为 [2]
F = torch.tensor([1, 2])

# 尝试进行运算将抛出错误,因为 E 和 F 在最后一个维度上不匹配
try:
    result = E + F  # 形状不兼容,会抛出错误
except RuntimeError as e:
    print("Error:", e)

总结

广播机制极大简化了张量运算的代码编写,特别是在处理不同行数和列数的张量时。理解广播机制能够帮助你编写更高效、简洁的代码,并充分利用 PyTorch 的计算能力。

相关推荐
沐雪架构师8 分钟前
AI大模型开发原理篇-2:语言模型雏形之词袋模型
人工智能·语言模型·自然语言处理
python算法(魔法师版)1 小时前
深度学习深度解析:从基础到前沿
人工智能·深度学习
小王子10241 小时前
设计模式Python版 组合模式
python·设计模式·组合模式
kakaZhui1 小时前
【llm对话系统】大模型源码分析之 LLaMA 位置编码 RoPE
人工智能·深度学习·chatgpt·aigc·llama
struggle20252 小时前
一个开源 GenBI AI 本地代理(确保本地数据安全),使数据驱动型团队能够与其数据进行互动,生成文本到 SQL、图表、电子表格、报告和 BI
人工智能·深度学习·目标检测·语言模型·自然语言处理·数据挖掘·集成学习
佛州小李哥2 小时前
通过亚马逊云科技Bedrock打造自定义AI智能体Agent(上)
人工智能·科技·ai·语言模型·云计算·aws·亚马逊云科技
Mason Lin3 小时前
2025年1月22日(网络编程 udp)
网络·python·udp
清弦墨客3 小时前
【蓝桥杯】43697.机器人塔
python·蓝桥杯·程序算法
云空3 小时前
《DeepSeek 网页/API 性能异常(DeepSeek Web/API Degraded Performance):网络安全日志》
运维·人工智能·web安全·网络安全·开源·网络攻击模型·安全威胁分析
AIGC大时代3 小时前
对比DeepSeek、ChatGPT和Kimi的学术写作关键词提取能力
论文阅读·人工智能·chatgpt·数据分析·prompt