基于PyTorch的深度学习2——广播

PyTorch中的广播机制(Broadcasting Mechanism)是一种强大的功能,它允许不同形状的张量在进行算术运算时自动扩展其维度,从而使得这些操作成为可能,而无需显式地复制数据。这种机制极大地简化了代码,并提高了效率。

广播规则

广播机制遵循以下几条基本规则:

  1. 每个张量至少有一个维度
  2. 从后往前比较张量的各个维度(即从最后一个维度到第一个维度)。两个张量的对应维度要么相等,要么其中一个为1,或者一个张量在此维度上没有尺寸(即此维度不存在)。
  3. 如果某个维度上的大小是1,则该维度会被重复使用以匹配另一个张量的相应维度大小。
  4. 最终结果的形状由各输入张量中每个维度的最大值决定。

如果满足上述条件,那么这两个张量就是"广播兼容"的,可以执行元素级的操作如加法、减法等。

复制代码
import torch

A = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: (2, 3)
B = torch.tensor([10, 20, 30])            # shape: (3)

result = A + B
# 结果:
# tensor([[11, 22, 33],
#         [14, 25, 36]])

C = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])  # shape: (2, 2, 2)
D = torch.tensor([[10, 20], [30, 40]])                   # shape: (2, 2)
result = C + D
# 结果:
# tensor([[[11, 22],
#          [33, 44]],
#         [[15, 26],
#          [37, 48]]])

接下来进一步进行演示。

复制代码
import torch
import numpy as np

# 创建NumPy数组
A = np.arange(0, 40, 10).reshape(4, 1)  # 形状为 (4, 1)
B = np.arange(0, 3)                     # 形状为 (3,)

# 将NumPy数组转换为PyTorch Tensor
A1 = torch.from_numpy(A)  # 形状为 (4, 1)
B1 = torch.from_numpy(B)  # 形状为 (3,)

# 使用广播机制自动扩展
C = A1 + B1
print("Using broadcasting:")
print(C)

# 手动实现广播
# 根据规则1,B1需要向A1看齐,把B变为(1, 3)
B2 = B1.unsqueeze(0)  # 形状变为 (1, 3)

# 使用expand函数重复数组,分别得到4x3的矩阵
A2 = A1.expand(4, 3)  # 形状变为 (4, 3)
B3 = B2.expand(4, 3)  # 形状变为 (4, 3)

# 然后进行相加,C1与C结果一致
C1 = A2 + B3
print("Manual broadcasting:")
print(C1)

无论你是通过自动广播还是手动模拟广播机制,最终的结果都是相同的:

复制代码
Using broadcasting:
tensor([[ 0,  1,  2],
        [10, 11, 12],
        [20, 21, 22],
        [30, 31, 32]], dtype=torch.int32)

Manual broadcasting:
tensor([[ 0,  1,  2],
        [10, 11, 12],
        [20, 21, 22],
        [30, 31, 32]], dtype=torch.int32)

通过上述代码和解析,我们了解到:

  • 广播机制允许不同形状的张量进行元素级的操作,而无需显式地复制数据。
  • unsqueeze 函数可以在指定位置插入一个新的维度,这对于准备广播非常有用。
  • expand 方法可以将张量扩展到目标形状,但它不会分配新的内存,而是返回一个视图,除非必要时才会复制数据。
相关推荐
胡耀超4 分钟前
标签体系设计与管理:从理论基础到智能化实践的综合指南
人工智能·python·深度学习·数据挖掘·大模型·用户画像·语义分析
开-悟8 分钟前
嵌入式编程-使用AI查找BUG的启发
c语言·人工智能·嵌入式硬件·bug
大咖分享课29 分钟前
开源模型与商用模型协同开发机制设计
人工智能·开源·ai模型
你不知道我是谁?37 分钟前
AI 应用于进攻性安全
人工智能·安全
reddingtons1 小时前
Adobe高阶技巧与设计师创意思维的进阶指南
人工智能·adobe·illustrator·设计师·photoshop·创意设计·aftereffects
机器之心1 小时前
刚刚,Grok4跑分曝光:「人类最后考试」拿下45%,是Gemini 2.5两倍,但网友不信
人工智能
蹦蹦跳跳真可爱5891 小时前
Python----大模型(使用api接口调用大模型)
人工智能·python·microsoft·语言模型
小爷毛毛_卓寿杰1 小时前
突破政务文档理解瓶颈:基于多模态大模型的智能解析系统详解
人工智能·llm
Mr.Winter`1 小时前
障碍感知 | 基于3D激光雷达的三维膨胀栅格地图构建(附ROS C++仿真)
人工智能·机器人·自动驾驶·ros·具身智能·环境感知
好开心啊没烦恼2 小时前
Python 数据分析:numpy,抽提,整数数组索引与基本索引扩展(元组传参)。听故事学知识点怎么这么容易?
开发语言·人工智能·python·数据挖掘·数据分析·numpy·pandas