PyTorch 中广播机制(Broadcasting)笔记

在 PyTorch 中存在广播(Broadcasting),广播是一种机制,用于自动扩展较小的张量以匹配较大张量的形状,从而使得它们能够进行元素级操作(如加法、减法、乘法等)。广播并不改变张量的实际数据,而是通过虚拟扩展来简化操作。

目录

广播机制的规则

  1. 如果两个张量的维度数量不同,则将较小的那个张量的形状前面补 1,直到两个张量的维度数量相同。

  2. 如果两个张量在某个维度上的大小不一致,但其中一个张量在该维度上的大小是 1,则可以在该维度上进行广播。

  3. 如果两个张量在任何维度上的大小既不相等也不为 1,则无法进行广播。

  4. 广播后的张量形状是每个维度上大小的最大值。

python 复制代码
import torch

# 示例 1: 形状不同的张量相加
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([1, 2, 3])
# b 会被广播成 [[1, 2, 3], [1, 2, 3]]
result = a + b
print(result)
# 输出:
# tensor([[ 2,  4,  6],
#         [ 5,  7,  9]])

# 示例 2: 形状不同的张量相乘
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
b = torch.tensor([1, 2])
# b 会被广播成 [[1, 2], [1, 2], [1, 2]]
result = a * b
print(result)
# 输出:
# tensor([[ 1,  4],
#         [ 3,  8],
#         [ 5, 12]])

# 示例 3: 形状不同的张量相加
a = torch.tensor([1, 2, 3])
b = torch.tensor([[1], [2], [3]])
# a 会被广播成 [[1, 2, 3], [1, 2, 3], [1, 2, 3]]
# b 会被广播成 [[1, 1, 1], [2, 2, 2], [3, 3, 3]]
result = a + b
print(result)
# 输出:
# tensor([[2, 3, 4],
#         [3, 4, 5],
#         [4, 5, 6]])

广播机制在张量乘法中的应用

在进行张量乘法时,广播机制也可以简化操作,尤其是在批次维度不同时:

python 复制代码
import torch

# 张量 A 的形状是 (2, 3, 4)
A = torch.randn(2, 3, 4)

# 张量 B 的形状是 (4, 5)
B = torch.randn(4, 5)

# 使用广播机制进行张量乘法
# B 会被广播成 (2, 4, 5)
result = torch.matmul(A, B)
print(result.shape)
# 输出:
# torch.Size([2, 3, 5])

判断两个张量是否可以进行广播操作

主要遵循以下规则:

  1. 如果两个张量的维度数量不同,则将较小的那个张量的形状前面补 1,直到两个张量的维度数量相同。
  2. 从最后一个维度开始,逐个维度向前检查:
  • 如果两个张量在某个维度上的大小相同,或者其中一个张量在该维度上的大小是 1,则可以在该维度上进行广播。
  • 如果两个张量在任何维度上的大小既不相等也不为 1,则无法进行广播。

具体步骤

假设有两个张量 A 和 B,其形状分别为 shapeA 和 shapeB。

  1. 对齐维度:将较小的形状前面补 1,使得两个形状的长度相同。
  2. 逐维度检查:从最后一个维度开始,逐个维度向前检查:
  • 如果两个维度大小相同,或其中一个维度大小为 1,则该维度可以进行广播。
  • 如果两个维度大小既不相同且都不为 1,则无法进行广播。

以下是一个 Python 函数,用于判断两个张量是否可以进行广播操作:

python 复制代码
def can_broadcast(shapeA, shapeB):
    # 对齐维度
    lenA, lenB = len(shapeA), len(shapeB)
    if lenA < lenB:
        shapeA = (1,) * (lenB - lenA) + shapeA
    elif lenB < lenA:
        shapeB = (1,) * (lenA - lenB) + shapeB

    # 逐维度检查
    for dimA, dimB in zip(shapeA, shapeB):
        if dimA != dimB and dimA != 1 and dimB != 1:
            return False
    return True

# 示例
shapeA = (2, 3, 4)
shapeB = (4, 5)
print(can_broadcast(shapeA, shapeB))  # 输出: False
'''
对齐维度:(2, 3, 4) 和 (1, 4, 5)
第三个维度:4 和 5,不相等且都不为 1,无法广播。
第二个维度:3 和 4,不相等且都不为 1,无法广播。
第一个维度:2 和 1,可以广播。
'''

shapeA = (2, 3, 4)
shapeB = (1, 4, 5)
print(can_broadcast(shapeA, shapeB))  # 输出: False
'''
对齐维度:(2, 3, 4) 和 (1, 4, 5)
第三个维度:4 和 5,无法广播。
第二个维度:3 和 4,无法广播。
第一个维度:2 和 1,可以广播。
'''

shapeA = (3, 4)
shapeB = (2, 1, 4)
print(can_broadcast(shapeA, shapeB))  # 输出: True
'''
对齐维度:(1, 3, 4) 和 (2, 1, 4)
'''
shapeA = (5,)
shapeB = (1, 5)
print(can_broadcast(shapeA, shapeB)) # 输出: True
'''
对齐维度:将较小的形状前面补 1,使得两个形状的长度相同。得到 (1, 5) 和 (1, 5)
'''

广播机制结合张量乘法例子

示例 1: 形状为 (2, 3, 4) 和 (4,) 的张量

python 复制代码
import numpy as np

# 张量 A 的形状为 (2, 3, 4)
A = np.random.rand(2, 3, 4)

# 张量 B 的形状为 (4,)
B = np.random.rand(4)

# 广播机制会将 B 扩展为 (1, 1, 4),然后再扩展为 (2, 3, 4)
result = A * B

print("A.shape:", A.shape) # A.shape: (2, 3, 4)
print("B.shape:", B.shape) # B.shape: (4,)
print("result.shape:", result.shape) # 逐元素乘法操作成功进行,结果的形状为 (2, 3, 4)

示例 2: 形状为 (2, 3, 4) 和 (3, 4) 的张量

python 复制代码
import numpy as np

# 张量 A 的形状为 (2, 3, 4)
A = np.random.rand(2, 3, 4)

# 张量 B 的形状为 (3, 4)
B = np.random.rand(3, 4)

# 广播机制会将 B 扩展为 (1, 3, 4),然后再扩展为 (2, 3, 4)
result = A * B

print("A.shape:", A.shape)
print("B.shape:", B.shape)
print("result.shape:", result.shape)

示例 3: 形状为 (2, 1, 3) 和 (3, 4) 的张量

python 复制代码
import numpy as np

# 张量 A 的形状为 (2, 1, 3)
A = np.random.rand(2, 1, 3)

# 张量 B 的形状为 (3, 4)
B = np.random.rand(3, 4)

# 广播机制会将 A 扩展为 (2, 1, 3),B 扩展为 (1, 3, 4)
# 矩阵乘法会在最后两个维度上进行
result = np.matmul(A, B)

print("A.shape:", A.shape)
print("B.shape:", B.shape)
print("result.shape:", result.shape)

广播机制允许我们在不同形状的张量之间进行逐元素操作或矩阵操作,而无需显式地扩展张量的形状。这大大简化了张量操作的复杂性,并提高了代码的可读性和效率。

相关推荐
GOTXX8 分钟前
基于Opencv的图像处理软件
图像处理·人工智能·深度学习·opencv·卷积神经网络
熙曦Sakura37 分钟前
完全竞争市场
笔记
糖豆豆今天也要努力鸭1 小时前
torch.__version__的torch版本和conda list的torch版本不一致
linux·pytorch·python·深度学习·conda·torch
何大春1 小时前
【弱监督语义分割】Self-supervised Image-specific Prototype Exploration for WSSS 论文阅读
论文阅读·人工智能·python·深度学习·论文笔记·原型模式
uncle_ll1 小时前
PyTorch图像预处理:计算均值和方差以实现标准化
图像处理·人工智能·pytorch·均值算法·标准化
dr李四维2 小时前
iOS构建版本以及Hbuilder打iOS的ipa包全流程
前端·笔记·ios·产品运营·产品经理·xcode
Suyuoa2 小时前
附录2-pytorch yolov5目标检测
python·深度学习·yolo
余生H3 小时前
transformer.js(三):底层架构及性能优化指南
javascript·深度学习·架构·transformer
罗小罗同学3 小时前
医工交叉入门书籍分享:Transformer模型在机器学习领域的应用|个人观点·24-11-22
深度学习·机器学习·transformer
孤独且没人爱的纸鹤3 小时前
【深度学习】:从人工神经网络的基础原理到循环神经网络的先进技术,跨越智能算法的关键发展阶段及其未来趋势,探索技术进步与应用挑战
人工智能·python·深度学习·机器学习·ai