PyTorch中的广播机制(Broadcasting Mechanism)是一种强大的功能,它允许不同形状的张量在进行算术运算时自动扩展其维度,从而使得这些操作成为可能,而无需显式地复制数据。这种机制极大地简化了代码,并提高了效率。
广播规则
广播机制遵循以下几条基本规则:
- 每个张量至少有一个维度。
- 从后往前比较张量的各个维度(即从最后一个维度到第一个维度)。两个张量的对应维度要么相等,要么其中一个为1,或者一个张量在此维度上没有尺寸(即此维度不存在)。
- 如果某个维度上的大小是1,则该维度会被重复使用以匹配另一个张量的相应维度大小。
- 最终结果的形状由各输入张量中每个维度的最大值决定。
如果满足上述条件,那么这两个张量就是"广播兼容"的,可以执行元素级的操作如加法、减法等。
import torch
A = torch.tensor([[1, 2, 3], [4, 5, 6]]) # shape: (2, 3)
B = torch.tensor([10, 20, 30]) # shape: (3)
result = A + B
# 结果:
# tensor([[11, 22, 33],
# [14, 25, 36]])
C = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]) # shape: (2, 2, 2)
D = torch.tensor([[10, 20], [30, 40]]) # shape: (2, 2)
result = C + D
# 结果:
# tensor([[[11, 22],
# [33, 44]],
# [[15, 26],
# [37, 48]]])
接下来进一步进行演示。
import torch
import numpy as np
# 创建NumPy数组
A = np.arange(0, 40, 10).reshape(4, 1) # 形状为 (4, 1)
B = np.arange(0, 3) # 形状为 (3,)
# 将NumPy数组转换为PyTorch Tensor
A1 = torch.from_numpy(A) # 形状为 (4, 1)
B1 = torch.from_numpy(B) # 形状为 (3,)
# 使用广播机制自动扩展
C = A1 + B1
print("Using broadcasting:")
print(C)
# 手动实现广播
# 根据规则1,B1需要向A1看齐,把B变为(1, 3)
B2 = B1.unsqueeze(0) # 形状变为 (1, 3)
# 使用expand函数重复数组,分别得到4x3的矩阵
A2 = A1.expand(4, 3) # 形状变为 (4, 3)
B3 = B2.expand(4, 3) # 形状变为 (4, 3)
# 然后进行相加,C1与C结果一致
C1 = A2 + B3
print("Manual broadcasting:")
print(C1)
无论你是通过自动广播还是手动模拟广播机制,最终的结果都是相同的:
Using broadcasting:
tensor([[ 0, 1, 2],
[10, 11, 12],
[20, 21, 22],
[30, 31, 32]], dtype=torch.int32)
Manual broadcasting:
tensor([[ 0, 1, 2],
[10, 11, 12],
[20, 21, 22],
[30, 31, 32]], dtype=torch.int32)
通过上述代码和解析,我们了解到:
- 广播机制允许不同形状的张量进行元素级的操作,而无需显式地复制数据。
unsqueeze
函数可以在指定位置插入一个新的维度,这对于准备广播非常有用。expand
方法可以将张量扩展到目标形状,但它不会分配新的内存,而是返回一个视图,除非必要时才会复制数据。