PyTorch 中广播机制(Broadcasting)笔记

在 PyTorch 中存在广播(Broadcasting),广播是一种机制,用于自动扩展较小的张量以匹配较大张量的形状,从而使得它们能够进行元素级操作(如加法、减法、乘法等)。广播并不改变张量的实际数据,而是通过虚拟扩展来简化操作。

目录

广播机制的规则

  1. 如果两个张量的维度数量不同,则将较小的那个张量的形状前面补 1,直到两个张量的维度数量相同。

  2. 如果两个张量在某个维度上的大小不一致,但其中一个张量在该维度上的大小是 1,则可以在该维度上进行广播。

  3. 如果两个张量在任何维度上的大小既不相等也不为 1,则无法进行广播。

  4. 广播后的张量形状是每个维度上大小的最大值。

python 复制代码
import torch

# 示例 1: 形状不同的张量相加
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([1, 2, 3])
# b 会被广播成 [[1, 2, 3], [1, 2, 3]]
result = a + b
print(result)
# 输出:
# tensor([[ 2,  4,  6],
#         [ 5,  7,  9]])

# 示例 2: 形状不同的张量相乘
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
b = torch.tensor([1, 2])
# b 会被广播成 [[1, 2], [1, 2], [1, 2]]
result = a * b
print(result)
# 输出:
# tensor([[ 1,  4],
#         [ 3,  8],
#         [ 5, 12]])

# 示例 3: 形状不同的张量相加
a = torch.tensor([1, 2, 3])
b = torch.tensor([[1], [2], [3]])
# a 会被广播成 [[1, 2, 3], [1, 2, 3], [1, 2, 3]]
# b 会被广播成 [[1, 1, 1], [2, 2, 2], [3, 3, 3]]
result = a + b
print(result)
# 输出:
# tensor([[2, 3, 4],
#         [3, 4, 5],
#         [4, 5, 6]])

广播机制在张量乘法中的应用

在进行张量乘法时,广播机制也可以简化操作,尤其是在批次维度不同时:

python 复制代码
import torch

# 张量 A 的形状是 (2, 3, 4)
A = torch.randn(2, 3, 4)

# 张量 B 的形状是 (4, 5)
B = torch.randn(4, 5)

# 使用广播机制进行张量乘法
# B 会被广播成 (2, 4, 5)
result = torch.matmul(A, B)
print(result.shape)
# 输出:
# torch.Size([2, 3, 5])

判断两个张量是否可以进行广播操作

主要遵循以下规则:

  1. 如果两个张量的维度数量不同,则将较小的那个张量的形状前面补 1,直到两个张量的维度数量相同。
  2. 从最后一个维度开始,逐个维度向前检查:
  • 如果两个张量在某个维度上的大小相同,或者其中一个张量在该维度上的大小是 1,则可以在该维度上进行广播。
  • 如果两个张量在任何维度上的大小既不相等也不为 1,则无法进行广播。

具体步骤

假设有两个张量 A 和 B,其形状分别为 shapeA 和 shapeB。

  1. 对齐维度:将较小的形状前面补 1,使得两个形状的长度相同。
  2. 逐维度检查:从最后一个维度开始,逐个维度向前检查:
  • 如果两个维度大小相同,或其中一个维度大小为 1,则该维度可以进行广播。
  • 如果两个维度大小既不相同且都不为 1,则无法进行广播。

以下是一个 Python 函数,用于判断两个张量是否可以进行广播操作:

python 复制代码
def can_broadcast(shapeA, shapeB):
    # 对齐维度
    lenA, lenB = len(shapeA), len(shapeB)
    if lenA < lenB:
        shapeA = (1,) * (lenB - lenA) + shapeA
    elif lenB < lenA:
        shapeB = (1,) * (lenA - lenB) + shapeB

    # 逐维度检查
    for dimA, dimB in zip(shapeA, shapeB):
        if dimA != dimB and dimA != 1 and dimB != 1:
            return False
    return True

# 示例
shapeA = (2, 3, 4)
shapeB = (4, 5)
print(can_broadcast(shapeA, shapeB))  # 输出: False
'''
对齐维度:(2, 3, 4) 和 (1, 4, 5)
第三个维度:4 和 5,不相等且都不为 1,无法广播。
第二个维度:3 和 4,不相等且都不为 1,无法广播。
第一个维度:2 和 1,可以广播。
'''

shapeA = (2, 3, 4)
shapeB = (1, 4, 5)
print(can_broadcast(shapeA, shapeB))  # 输出: False
'''
对齐维度:(2, 3, 4) 和 (1, 4, 5)
第三个维度:4 和 5,无法广播。
第二个维度:3 和 4,无法广播。
第一个维度:2 和 1,可以广播。
'''

shapeA = (3, 4)
shapeB = (2, 1, 4)
print(can_broadcast(shapeA, shapeB))  # 输出: True
'''
对齐维度:(1, 3, 4) 和 (2, 1, 4)
'''
shapeA = (5,)
shapeB = (1, 5)
print(can_broadcast(shapeA, shapeB)) # 输出: True
'''
对齐维度:将较小的形状前面补 1,使得两个形状的长度相同。得到 (1, 5) 和 (1, 5)
'''

广播机制结合张量乘法例子

示例 1: 形状为 (2, 3, 4) 和 (4,) 的张量

python 复制代码
import numpy as np

# 张量 A 的形状为 (2, 3, 4)
A = np.random.rand(2, 3, 4)

# 张量 B 的形状为 (4,)
B = np.random.rand(4)

# 广播机制会将 B 扩展为 (1, 1, 4),然后再扩展为 (2, 3, 4)
result = A * B

print("A.shape:", A.shape) # A.shape: (2, 3, 4)
print("B.shape:", B.shape) # B.shape: (4,)
print("result.shape:", result.shape) # 逐元素乘法操作成功进行,结果的形状为 (2, 3, 4)

示例 2: 形状为 (2, 3, 4) 和 (3, 4) 的张量

python 复制代码
import numpy as np

# 张量 A 的形状为 (2, 3, 4)
A = np.random.rand(2, 3, 4)

# 张量 B 的形状为 (3, 4)
B = np.random.rand(3, 4)

# 广播机制会将 B 扩展为 (1, 3, 4),然后再扩展为 (2, 3, 4)
result = A * B

print("A.shape:", A.shape)
print("B.shape:", B.shape)
print("result.shape:", result.shape)

示例 3: 形状为 (2, 1, 3) 和 (3, 4) 的张量

python 复制代码
import numpy as np

# 张量 A 的形状为 (2, 1, 3)
A = np.random.rand(2, 1, 3)

# 张量 B 的形状为 (3, 4)
B = np.random.rand(3, 4)

# 广播机制会将 A 扩展为 (2, 1, 3),B 扩展为 (1, 3, 4)
# 矩阵乘法会在最后两个维度上进行
result = np.matmul(A, B)

print("A.shape:", A.shape)
print("B.shape:", B.shape)
print("result.shape:", result.shape)

广播机制允许我们在不同形状的张量之间进行逐元素操作或矩阵操作,而无需显式地扩展张量的形状。这大大简化了张量操作的复杂性,并提高了代码的可读性和效率。

相关推荐
Selina K27 分钟前
shell脚本知识点记录
笔记·shell
39 分钟前
开源竞争-数据驱动成长-11/05-大专生的思考
人工智能·笔记·学习·算法·机器学习
霍格沃兹测试开发学社测试人社区1 小时前
软件测试学习笔记丨Flask操作数据库-数据库和表的管理
软件测试·笔记·测试开发·学习·flask
幸运超级加倍~2 小时前
软件设计师-上午题-16 算法(4-5分)
笔记·算法
王俊山IT2 小时前
C++学习笔记----10、模块、头文件及各种主题(一)---- 模块(5)
开发语言·c++·笔记·学习
好喜欢吃红柚子2 小时前
万字长文解读空间、通道注意力机制机制和超详细代码逐行分析(SE,CBAM,SGE,CA,ECA,TA)
人工智能·pytorch·python·计算机视觉·cnn
羊小猪~~3 小时前
神经网络基础--什么是正向传播??什么是方向传播??
人工智能·pytorch·python·深度学习·神经网络·算法·机器学习
Yawesh_best3 小时前
思源笔记轻松连接本地Ollama大语言模型,开启AI写作新体验!
笔记·语言模型·ai写作
软工菜鸡3 小时前
预训练语言模型BERT——PaddleNLP中的预训练模型
大数据·人工智能·深度学习·算法·语言模型·自然语言处理·bert
哔哩哔哩技术4 小时前
B站S赛直播中的关键事件识别与应用
深度学习