pytorch张量运算的广播机制

PyTorch 的广播机制(broadcasting)是指在进行张量运算时,自动扩展较小张量的形状以匹配较大张量的形状,使它们能够进行逐元素运算。广播机制避免了手动扩展张量的繁琐过程,并且在不增加内存开销的情况下进行高效计算。

广播规则

  1. 比较张量的形状:从后向前比较两个张量的每个维度(即从最右边的维度开始)。
  2. 维度匹配
    • 如果两个维度相等,则可以进行运算。
    • 如果一个张量在该维度上为 1,另一个张量为任意数值,则形状为 1 的张量会沿着该维度扩展,以匹配另一个张量的形状。
    • 如果两个张量在某个维度上不相等且没有一个是 1,则无法进行广播,运算会抛出错误。

广播机制的示例代码

1. 标量与张量的运算
复制代码
import torch

# 标量和张量相加
scalar = torch.tensor(3)
tensor = torch.tensor([1, 2, 3])

result = scalar + tensor
print(result)  # 输出: tensor([4, 5, 6])
2. 不同形状的张量运算
复制代码
import torch

# 创建两个形状不同的张量
A = torch.tensor([[1, 2, 3], [4, 5, 6]])  # 形状为 [2, 3]
B = torch.tensor([1, 2, 3])              # 形状为 [3]

# B 张量的形状会沿着第一个维度自动扩展为 [2, 3]
result = A + B 
print(result)   # 输出: tensor([[2, 4, 6], [5, 7, 9]])
3. 高维张量的广播
复制代码
import torch

# 形状为 [2, 1, 3]
C = torch.tensor([[[1, 2, 3]], [[4, 5, 6]]])

# 形状为 [3]
D = torch.tensor([1, 2, 3])

# D 张量的形状会沿着第一个和第二个维度扩展为 [2, 1, 3]
result = C + D  
print(result)   # 输出: tensor([[[ 2,  4,  6]], [[ 5,  7,  9]]])
4. 不兼容的张量运算
复制代码
import torch

# 形状为 [2, 3]
E = torch.tensor([[1, 2, 3], [4, 5, 6]])

# 形状为 [2]
F = torch.tensor([1, 2])

# 尝试进行运算将抛出错误,因为 E 和 F 在最后一个维度上不匹配
try:
    result = E + F  # 形状不兼容,会抛出错误
except RuntimeError as e:
    print("Error:", e)

总结

广播机制极大简化了张量运算的代码编写,特别是在处理不同行数和列数的张量时。理解广播机制能够帮助你编写更高效、简洁的代码,并充分利用 PyTorch 的计算能力。

相关推荐
数科云4 小时前
AI提示词(Prompt)入门:什么是Prompt?为什么要写好Prompt?
人工智能·aigc·ai写作·ai工具集·最新ai资讯
Devlive 开源社区5 小时前
技术日报|Claude Code超级能力库superpowers登顶日增1538星,自主AI循环ralph爆火登榜第二
人工智能
软件供应链安全指南5 小时前
灵脉 IAST 5.4 升级:双轮驱动 AI 漏洞治理与业务逻辑漏洞精准检测
人工智能·安全
lanmengyiyu5 小时前
单塔和双塔的区别和共同点
人工智能·双塔模型·网络结构·单塔模型
微光闪现5 小时前
AI识别宠物焦虑、紧张和晕车行为,是否已经具备实际可行性?
大数据·人工智能·宠物
技术小黑屋_6 小时前
用好Few-shot Prompting,AI 准确率提升100%
人工智能
中草药z6 小时前
【嵌入模型】概念、应用与两大 AI 开源社区(Hugging Face / 魔塔)
人工智能·算法·机器学习·数据集·向量·嵌入模型
web3.08889996 小时前
微店商品详情API实用
python·json·时序数据库
知乎的哥廷根数学学派6 小时前
基于数据驱动的自适应正交小波基优化算法(Python)
开发语言·网络·人工智能·pytorch·python·深度学习·算法
DisonTangor6 小时前
GLM-Image:面向密集知识与高保真图像生成的自回归模型
人工智能·ai作画·数据挖掘·回归·aigc