PyTorch 中广播机制(Broadcasting)笔记

在 PyTorch 中存在广播(Broadcasting),广播是一种机制,用于自动扩展较小的张量以匹配较大张量的形状,从而使得它们能够进行元素级操作(如加法、减法、乘法等)。广播并不改变张量的实际数据,而是通过虚拟扩展来简化操作。

目录

广播机制的规则

  1. 如果两个张量的维度数量不同,则将较小的那个张量的形状前面补 1,直到两个张量的维度数量相同。

  2. 如果两个张量在某个维度上的大小不一致,但其中一个张量在该维度上的大小是 1,则可以在该维度上进行广播。

  3. 如果两个张量在任何维度上的大小既不相等也不为 1,则无法进行广播。

  4. 广播后的张量形状是每个维度上大小的最大值。

python 复制代码
import torch

# 示例 1: 形状不同的张量相加
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([1, 2, 3])
# b 会被广播成 [[1, 2, 3], [1, 2, 3]]
result = a + b
print(result)
# 输出:
# tensor([[ 2,  4,  6],
#         [ 5,  7,  9]])

# 示例 2: 形状不同的张量相乘
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
b = torch.tensor([1, 2])
# b 会被广播成 [[1, 2], [1, 2], [1, 2]]
result = a * b
print(result)
# 输出:
# tensor([[ 1,  4],
#         [ 3,  8],
#         [ 5, 12]])

# 示例 3: 形状不同的张量相加
a = torch.tensor([1, 2, 3])
b = torch.tensor([[1], [2], [3]])
# a 会被广播成 [[1, 2, 3], [1, 2, 3], [1, 2, 3]]
# b 会被广播成 [[1, 1, 1], [2, 2, 2], [3, 3, 3]]
result = a + b
print(result)
# 输出:
# tensor([[2, 3, 4],
#         [3, 4, 5],
#         [4, 5, 6]])

广播机制在张量乘法中的应用

在进行张量乘法时,广播机制也可以简化操作,尤其是在批次维度不同时:

python 复制代码
import torch

# 张量 A 的形状是 (2, 3, 4)
A = torch.randn(2, 3, 4)

# 张量 B 的形状是 (4, 5)
B = torch.randn(4, 5)

# 使用广播机制进行张量乘法
# B 会被广播成 (2, 4, 5)
result = torch.matmul(A, B)
print(result.shape)
# 输出:
# torch.Size([2, 3, 5])

判断两个张量是否可以进行广播操作

主要遵循以下规则:

  1. 如果两个张量的维度数量不同,则将较小的那个张量的形状前面补 1,直到两个张量的维度数量相同。
  2. 从最后一个维度开始,逐个维度向前检查:
  • 如果两个张量在某个维度上的大小相同,或者其中一个张量在该维度上的大小是 1,则可以在该维度上进行广播。
  • 如果两个张量在任何维度上的大小既不相等也不为 1,则无法进行广播。

具体步骤

假设有两个张量 A 和 B,其形状分别为 shapeA 和 shapeB。

  1. 对齐维度:将较小的形状前面补 1,使得两个形状的长度相同。
  2. 逐维度检查:从最后一个维度开始,逐个维度向前检查:
  • 如果两个维度大小相同,或其中一个维度大小为 1,则该维度可以进行广播。
  • 如果两个维度大小既不相同且都不为 1,则无法进行广播。

以下是一个 Python 函数,用于判断两个张量是否可以进行广播操作:

python 复制代码
def can_broadcast(shapeA, shapeB):
    # 对齐维度
    lenA, lenB = len(shapeA), len(shapeB)
    if lenA < lenB:
        shapeA = (1,) * (lenB - lenA) + shapeA
    elif lenB < lenA:
        shapeB = (1,) * (lenA - lenB) + shapeB

    # 逐维度检查
    for dimA, dimB in zip(shapeA, shapeB):
        if dimA != dimB and dimA != 1 and dimB != 1:
            return False
    return True

# 示例
shapeA = (2, 3, 4)
shapeB = (4, 5)
print(can_broadcast(shapeA, shapeB))  # 输出: False
'''
对齐维度:(2, 3, 4) 和 (1, 4, 5)
第三个维度:4 和 5,不相等且都不为 1,无法广播。
第二个维度:3 和 4,不相等且都不为 1,无法广播。
第一个维度:2 和 1,可以广播。
'''

shapeA = (2, 3, 4)
shapeB = (1, 4, 5)
print(can_broadcast(shapeA, shapeB))  # 输出: False
'''
对齐维度:(2, 3, 4) 和 (1, 4, 5)
第三个维度:4 和 5,无法广播。
第二个维度:3 和 4,无法广播。
第一个维度:2 和 1,可以广播。
'''

shapeA = (3, 4)
shapeB = (2, 1, 4)
print(can_broadcast(shapeA, shapeB))  # 输出: True
'''
对齐维度:(1, 3, 4) 和 (2, 1, 4)
'''
shapeA = (5,)
shapeB = (1, 5)
print(can_broadcast(shapeA, shapeB)) # 输出: True
'''
对齐维度:将较小的形状前面补 1,使得两个形状的长度相同。得到 (1, 5) 和 (1, 5)
'''

广播机制结合张量乘法例子

示例 1: 形状为 (2, 3, 4) 和 (4,) 的张量

python 复制代码
import numpy as np

# 张量 A 的形状为 (2, 3, 4)
A = np.random.rand(2, 3, 4)

# 张量 B 的形状为 (4,)
B = np.random.rand(4)

# 广播机制会将 B 扩展为 (1, 1, 4),然后再扩展为 (2, 3, 4)
result = A * B

print("A.shape:", A.shape) # A.shape: (2, 3, 4)
print("B.shape:", B.shape) # B.shape: (4,)
print("result.shape:", result.shape) # 逐元素乘法操作成功进行,结果的形状为 (2, 3, 4)

示例 2: 形状为 (2, 3, 4) 和 (3, 4) 的张量

python 复制代码
import numpy as np

# 张量 A 的形状为 (2, 3, 4)
A = np.random.rand(2, 3, 4)

# 张量 B 的形状为 (3, 4)
B = np.random.rand(3, 4)

# 广播机制会将 B 扩展为 (1, 3, 4),然后再扩展为 (2, 3, 4)
result = A * B

print("A.shape:", A.shape)
print("B.shape:", B.shape)
print("result.shape:", result.shape)

示例 3: 形状为 (2, 1, 3) 和 (3, 4) 的张量

python 复制代码
import numpy as np

# 张量 A 的形状为 (2, 1, 3)
A = np.random.rand(2, 1, 3)

# 张量 B 的形状为 (3, 4)
B = np.random.rand(3, 4)

# 广播机制会将 A 扩展为 (2, 1, 3),B 扩展为 (1, 3, 4)
# 矩阵乘法会在最后两个维度上进行
result = np.matmul(A, B)

print("A.shape:", A.shape)
print("B.shape:", B.shape)
print("result.shape:", result.shape)

广播机制允许我们在不同形状的张量之间进行逐元素操作或矩阵操作,而无需显式地扩展张量的形状。这大大简化了张量操作的复杂性,并提高了代码的可读性和效率。

相关推荐
jimbo_lee8 小时前
yocto 用法(随手笔记,记录以备不时之需)
笔记·yocto
钓了猫的鱼儿10 小时前
基于深度学习+AI的卷心菜目标检测与预警系统(Python源码+数据集+UI可视化界面+YOLOv11训练结果)
人工智能·深度学习·目标检测
汽车仪器仪表相关领域10 小时前
南华 NHA-604/605 汽车排放气体测试仪:国六b全适配高精度便携检测设备
大数据·人工智能·功能测试·深度学习·安全·fpga开发·压力测试
胡图图不糊涂^_^10 小时前
测试用例篇——设计测试用例的方法
笔记·学习·测试用例·判定表法·正交法生成用例测试·等价类·边界值
CV实验室11 小时前
Remote Sensing 29个SITS基准数据集综述:多模态遥感分类的新起点
人工智能·深度学习·计算机视觉·音视频
IT199511 小时前
Dify笔记-知识库创建后设置和召回测试
笔记·dify
飞翔中文网11 小时前
Java学习笔记之抽象类
java·笔记·学习
手写码匠13 小时前
华为云Flexus+DeepSeek征文|基于华为云Flexus X实例 + Dify + DeepSeek 构建企业级智能知识库问答系统实战
人工智能·深度学习·算法·aigc
lqqjuly13 小时前
语音识别:隐马尔可夫模型、深度学习与序列转导
人工智能·深度学习·语音识别