PyTorch 中广播机制(Broadcasting)笔记

在 PyTorch 中存在广播(Broadcasting),广播是一种机制,用于自动扩展较小的张量以匹配较大张量的形状,从而使得它们能够进行元素级操作(如加法、减法、乘法等)。广播并不改变张量的实际数据,而是通过虚拟扩展来简化操作。

目录

广播机制的规则

  1. 如果两个张量的维度数量不同,则将较小的那个张量的形状前面补 1,直到两个张量的维度数量相同。

  2. 如果两个张量在某个维度上的大小不一致,但其中一个张量在该维度上的大小是 1,则可以在该维度上进行广播。

  3. 如果两个张量在任何维度上的大小既不相等也不为 1,则无法进行广播。

  4. 广播后的张量形状是每个维度上大小的最大值。

python 复制代码
import torch

# 示例 1: 形状不同的张量相加
a = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = torch.tensor([1, 2, 3])
# b 会被广播成 [[1, 2, 3], [1, 2, 3]]
result = a + b
print(result)
# 输出:
# tensor([[ 2,  4,  6],
#         [ 5,  7,  9]])

# 示例 2: 形状不同的张量相乘
a = torch.tensor([[1, 2], [3, 4], [5, 6]])
b = torch.tensor([1, 2])
# b 会被广播成 [[1, 2], [1, 2], [1, 2]]
result = a * b
print(result)
# 输出:
# tensor([[ 1,  4],
#         [ 3,  8],
#         [ 5, 12]])

# 示例 3: 形状不同的张量相加
a = torch.tensor([1, 2, 3])
b = torch.tensor([[1], [2], [3]])
# a 会被广播成 [[1, 2, 3], [1, 2, 3], [1, 2, 3]]
# b 会被广播成 [[1, 1, 1], [2, 2, 2], [3, 3, 3]]
result = a + b
print(result)
# 输出:
# tensor([[2, 3, 4],
#         [3, 4, 5],
#         [4, 5, 6]])

广播机制在张量乘法中的应用

在进行张量乘法时,广播机制也可以简化操作,尤其是在批次维度不同时:

python 复制代码
import torch

# 张量 A 的形状是 (2, 3, 4)
A = torch.randn(2, 3, 4)

# 张量 B 的形状是 (4, 5)
B = torch.randn(4, 5)

# 使用广播机制进行张量乘法
# B 会被广播成 (2, 4, 5)
result = torch.matmul(A, B)
print(result.shape)
# 输出:
# torch.Size([2, 3, 5])

判断两个张量是否可以进行广播操作

主要遵循以下规则:

  1. 如果两个张量的维度数量不同,则将较小的那个张量的形状前面补 1,直到两个张量的维度数量相同。
  2. 从最后一个维度开始,逐个维度向前检查:
  • 如果两个张量在某个维度上的大小相同,或者其中一个张量在该维度上的大小是 1,则可以在该维度上进行广播。
  • 如果两个张量在任何维度上的大小既不相等也不为 1,则无法进行广播。

具体步骤

假设有两个张量 A 和 B,其形状分别为 shapeA 和 shapeB。

  1. 对齐维度:将较小的形状前面补 1,使得两个形状的长度相同。
  2. 逐维度检查:从最后一个维度开始,逐个维度向前检查:
  • 如果两个维度大小相同,或其中一个维度大小为 1,则该维度可以进行广播。
  • 如果两个维度大小既不相同且都不为 1,则无法进行广播。

以下是一个 Python 函数,用于判断两个张量是否可以进行广播操作:

python 复制代码
def can_broadcast(shapeA, shapeB):
    # 对齐维度
    lenA, lenB = len(shapeA), len(shapeB)
    if lenA < lenB:
        shapeA = (1,) * (lenB - lenA) + shapeA
    elif lenB < lenA:
        shapeB = (1,) * (lenA - lenB) + shapeB

    # 逐维度检查
    for dimA, dimB in zip(shapeA, shapeB):
        if dimA != dimB and dimA != 1 and dimB != 1:
            return False
    return True

# 示例
shapeA = (2, 3, 4)
shapeB = (4, 5)
print(can_broadcast(shapeA, shapeB))  # 输出: False
'''
对齐维度:(2, 3, 4) 和 (1, 4, 5)
第三个维度:4 和 5,不相等且都不为 1,无法广播。
第二个维度:3 和 4,不相等且都不为 1,无法广播。
第一个维度:2 和 1,可以广播。
'''

shapeA = (2, 3, 4)
shapeB = (1, 4, 5)
print(can_broadcast(shapeA, shapeB))  # 输出: False
'''
对齐维度:(2, 3, 4) 和 (1, 4, 5)
第三个维度:4 和 5,无法广播。
第二个维度:3 和 4,无法广播。
第一个维度:2 和 1,可以广播。
'''

shapeA = (3, 4)
shapeB = (2, 1, 4)
print(can_broadcast(shapeA, shapeB))  # 输出: True
'''
对齐维度:(1, 3, 4) 和 (2, 1, 4)
'''
shapeA = (5,)
shapeB = (1, 5)
print(can_broadcast(shapeA, shapeB)) # 输出: True
'''
对齐维度:将较小的形状前面补 1,使得两个形状的长度相同。得到 (1, 5) 和 (1, 5)
'''

广播机制结合张量乘法例子

示例 1: 形状为 (2, 3, 4) 和 (4,) 的张量

python 复制代码
import numpy as np

# 张量 A 的形状为 (2, 3, 4)
A = np.random.rand(2, 3, 4)

# 张量 B 的形状为 (4,)
B = np.random.rand(4)

# 广播机制会将 B 扩展为 (1, 1, 4),然后再扩展为 (2, 3, 4)
result = A * B

print("A.shape:", A.shape) # A.shape: (2, 3, 4)
print("B.shape:", B.shape) # B.shape: (4,)
print("result.shape:", result.shape) # 逐元素乘法操作成功进行,结果的形状为 (2, 3, 4)

示例 2: 形状为 (2, 3, 4) 和 (3, 4) 的张量

python 复制代码
import numpy as np

# 张量 A 的形状为 (2, 3, 4)
A = np.random.rand(2, 3, 4)

# 张量 B 的形状为 (3, 4)
B = np.random.rand(3, 4)

# 广播机制会将 B 扩展为 (1, 3, 4),然后再扩展为 (2, 3, 4)
result = A * B

print("A.shape:", A.shape)
print("B.shape:", B.shape)
print("result.shape:", result.shape)

示例 3: 形状为 (2, 1, 3) 和 (3, 4) 的张量

python 复制代码
import numpy as np

# 张量 A 的形状为 (2, 1, 3)
A = np.random.rand(2, 1, 3)

# 张量 B 的形状为 (3, 4)
B = np.random.rand(3, 4)

# 广播机制会将 A 扩展为 (2, 1, 3),B 扩展为 (1, 3, 4)
# 矩阵乘法会在最后两个维度上进行
result = np.matmul(A, B)

print("A.shape:", A.shape)
print("B.shape:", B.shape)
print("result.shape:", result.shape)

广播机制允许我们在不同形状的张量之间进行逐元素操作或矩阵操作,而无需显式地扩展张量的形状。这大大简化了张量操作的复杂性,并提高了代码的可读性和效率。

相关推荐
哥廷根数学学派3 小时前
基于Maximin的异常检测方法(MATLAB)
开发语言·人工智能·深度学习·机器学习
xrgs_shz3 小时前
人工智能、机器学习、神经网络、深度学习和卷积神经网络的概念和关系
人工智能·深度学习·神经网络·机器学习·卷积神经网络
muren6 小时前
昇思MindSpore学习笔记2-01 LLM原理和实践 --基于 MindSpore 实现 BERT 对话情绪识别
笔记·深度学习·学习
渔舟小调7 小时前
技术浅谈:如何入门一门编程语言
经验分享·笔记
Beast Cheng7 小时前
07-7.1.1 查找的基本概念
数据结构·笔记·考研·算法·学习方法
~山花~8 小时前
Vue3实战笔记(64)—Vue 3自定义指令的艺术:实战中的最佳实践
javascript·vue.js·笔记·自定义指令·vue3实战笔记
鱼仰泳8 小时前
【笔记】解决 CSS:backface-visibility:hidden; 容器翻转 引起的容器内 input不可用
前端·css·笔记
Silver_7778 小时前
WIFI信号状态信息 CSI 深度学习篇之CNN(Matlab)
深度学习·神经网络·机器学习·matlab
HealthScience8 小时前
torch.where()
人工智能·pytorch·深度学习
观鉴词recommend9 小时前
【c++刷题笔记-贪心】day30:56. 合并区间 、 738.单调递增的数字
c++·笔记·算法·leetcode