PyTorch 中张量运算广播

TLDR

右对齐,空补一,从左往右依维运算
[m] + [x, y] = [m +x, m + y]

正文

以如下 a b 两个 tensor 计算为例

python 复制代码
a = torch.tensor([
    [1],
    [2],
    [3],
])
b = torch.tensor([
    [
        [1, 2, 3],
    ],
    [
        [4, 5, 6],
    ],
    [
        [7, 8, 9],
    ],
])
# a.shape = (3, 1)
# b.shape = (3, 1, 3)

首先将两个 tensor 的 shape 右对齐
a( , 3, 1)
b(3, 1, 3)

判断两个 tensor 是否满足广播规则

  • tensor 至少有一个维度(比如 torch.tensor((0,)) 便不符合本要求)
  • 检查上一步对齐的 tensor shape,要求两个 tensor 对应维度的大小:要么相同;要么其中一个为 1;要么其中一个为空
  • 如果满足上述规则,则继续,否则报错

将对齐后空缺的维度设置为 1
a(1, 3, 1)
b(3, 1, 3)

其实就是对 a 进行了扩维,此时两个 tensor 为:

python 复制代码
a = torch.tensor([
    [
        [1],
        [2],
        [3],
    ],
])
b = torch.tensor([
    [
        [1, 2, 3],
    ],
    [
        [4, 5, 6],
    ],
    [
        [7, 8, 9],
    ],
])
# a.shape = (1, 3, 1)
# b.shape = (3, 1, 3)

从左往右对两个 tensor 的每一个维度进行运算,按照以下规则

  • 如果大小相同,则直接进行运算即可(一一对应)
  • 如果其中一个大小为 1,则使用这个元素与另一个 tensor 当前维度下的每个元素进行运算(本质是一个递归操作)

例如计算 a + b (这两个 tensor 已经经过上述步骤处理,即维度已经相同)

python 复制代码
# 1. 因为 a.shape[0] == 1,所以将 a[0] 分别与 b[0]、b[1]、b[2] 相加
[
	a[0] + b[0],
	a[0] + b[1],
	a[0] + b[2],
]

# 2. 接下来继续往后计算,以 a[0] + b[0] 为例
#    因为 a[0].shape[0] = 3, b[0].shape[0] = 1,
#    所以将 b[0][0] 分别与 a[0][0]、a[0][1]、a[0][2] 相加
[
	[	# a[0] + b[0]
		a[0][0] + b[0][0],
		a[0][1] + b[0][0],
		a[0][2] + b[0][0],
	],
	[	# a[0] + b[1]
		a[0][0] + b[1][0],
		a[0][1] + b[1][0],
		a[0][2] + b[1][0],
	],
	[	# a[0] + b[2]
		a[0][0] + b[2][0],
		a[0][1] + b[2][0],
		a[0][2] + b[2][0],
	],
]

# 3. 继续往后计算,以 a[0][0] + b[0][0] 为例
#    因为 a[0][0].shape[0] == 1,
#    所以将 a[0][0][0] 分别与 b[0][0][0]、b[0][0][1]、b[0][0][2] 相加
[
	[	# a[0] + b[0]
		[ 	# a[0][0] + b[0][0]
			a[0][0][0] + b[0][0][0],
			a[0][0][0] + b[0][0][1],
			a[0][0][0] + b[0][0][2],
		],
		[ 	# a[0][1] + b[0][0]
			a[0][1][0] + b[0][0][0],
			a[0][1][0] + b[0][0][1],
			a[0][1][0] + b[0][0][2],
		],
		[ 	# a[0][2] + b[0][0]
			a[0][2][0] + b[0][0][0],
			a[0][2][0] + b[0][0][1],
			a[0][2][0] + b[0][0][2],
		],
	],
	[	# a[0] + b[1]
		[ 	# a[0][0] + b[1][0]
			a[0][0][0] + b[1][0][0],
			a[0][0][0] + b[1][0][1],
			a[0][0][0] + b[1][0][2],
		],
		[ 	# a[0][1] + b[1][0]
			a[0][1][0] + b[1][0][0],
			a[0][1][0] + b[1][0][1],
			a[0][1][0] + b[1][0][2],
		],
		[ 	# a[0][2] + b[1][0]
			a[0][2][0] + b[1][0][0],
			a[0][2][0] + b[1][0][1],
			a[0][2][0] + b[1][0][2],
		],
	],
	[	# a[0] + b[2]
		[ 	# a[0][0] + b[2][0]
			a[0][0][0] + b[2][0][0],
			a[0][0][0] + b[2][0][1],
			a[0][0][0] + b[2][0][2],
		],
		[ 	# a[0][1] + b[2][0]
			a[0][1][0] + b[2][0][0],
			a[0][1][0] + b[2][0][1],
			a[0][1][0] + b[2][0][2],
		],
		[ 	# a[0][2] + b[2][0]
			a[0][2][0] + b[2][0][0],
			a[0][2][0] + b[2][0][1],
			a[0][2][0] + b[2][0][2],
		],
	],
]

总结

右对齐空补一 ,从左往右依维递归 )运算。

一个 tensor 的某个维度大小为 1 时的计算规则:[1] + [2, 3, 4] = [1 + 2, 1 + 3, 1 + 4]

《PyTorch 官方文档:BROADCASTING SEMANTICS》

相关推荐
东方佑7 分钟前
当人眼遇见神经网络:用残差结构模拟视觉调焦的奇妙类比
人工智能·深度学习·神经网络
烟锁池塘柳012 分钟前
【已解决,亲测有效】解决使用Python Matplotlib库绘制图表中出现中文乱码(中文显示为框)的问题的方法
开发语言·python·matplotlib
周小码14 分钟前
llama-stack实战:Python构建Llama应用的可组合开发框架(8k星)
开发语言·python·llama
智驱力人工智能14 分钟前
深度学习在离岗检测中的应用
人工智能·深度学习·安全·视觉检测·离岗检测
hjs_deeplearning19 分钟前
认知篇#12:基于非深度学习方法的图像特征提取
人工智能·深度学习·目标检测
IT学长编程19 分钟前
计算机毕业设计 基于Hadoop的南昌房价数据分析系统的设计与实现 Python 大数据毕业设计 Hadoop毕业设计选题【附源码+文档报告+安装调试
大数据·hadoop·python·毕业设计·课程设计·毕业论文·豆瓣电影数据可视化分析
Tony Bai22 分钟前
【AI应用开发第一课】11 实战串讲:用 Go 构建一个 AI 驱动的 GitHub Issue 助手
人工智能·issue
阿杜杜不是阿木木30 分钟前
开始 ComfyUI 的 AI 绘图之旅-Flux.1 ControlNet (十)
人工智能·深度学习·ai·ai作画·lora
郑洁文31 分钟前
豆瓣网影视数据分析与应用
大数据·python·数据挖掘·数据分析
格林威36 分钟前
Linux使用-MySQL的使用
linux·运维·人工智能·数码相机·mysql·计算机视觉·视觉检测