基于PyTorch的深度学习2——广播

PyTorch中的广播机制(Broadcasting Mechanism)是一种强大的功能,它允许不同形状的张量在进行算术运算时自动扩展其维度,从而使得这些操作成为可能,而无需显式地复制数据。这种机制极大地简化了代码,并提高了效率。

广播规则

广播机制遵循以下几条基本规则:

  1. 每个张量至少有一个维度
  2. 从后往前比较张量的各个维度(即从最后一个维度到第一个维度)。两个张量的对应维度要么相等,要么其中一个为1,或者一个张量在此维度上没有尺寸(即此维度不存在)。
  3. 如果某个维度上的大小是1,则该维度会被重复使用以匹配另一个张量的相应维度大小。
  4. 最终结果的形状由各输入张量中每个维度的最大值决定。

如果满足上述条件,那么这两个张量就是"广播兼容"的,可以执行元素级的操作如加法、减法等。

import torch

A = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: (2, 3)
B = torch.tensor([10, 20, 30])            # shape: (3)

result = A + B
# 结果:
# tensor([[11, 22, 33],
#         [14, 25, 36]])

C = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])  # shape: (2, 2, 2)
D = torch.tensor([[10, 20], [30, 40]])                   # shape: (2, 2)
result = C + D
# 结果:
# tensor([[[11, 22],
#          [33, 44]],
#         [[15, 26],
#          [37, 48]]])

接下来进一步进行演示。

import torch
import numpy as np

# 创建NumPy数组
A = np.arange(0, 40, 10).reshape(4, 1)  # 形状为 (4, 1)
B = np.arange(0, 3)                     # 形状为 (3,)

# 将NumPy数组转换为PyTorch Tensor
A1 = torch.from_numpy(A)  # 形状为 (4, 1)
B1 = torch.from_numpy(B)  # 形状为 (3,)

# 使用广播机制自动扩展
C = A1 + B1
print("Using broadcasting:")
print(C)

# 手动实现广播
# 根据规则1,B1需要向A1看齐,把B变为(1, 3)
B2 = B1.unsqueeze(0)  # 形状变为 (1, 3)

# 使用expand函数重复数组,分别得到4x3的矩阵
A2 = A1.expand(4, 3)  # 形状变为 (4, 3)
B3 = B2.expand(4, 3)  # 形状变为 (4, 3)

# 然后进行相加,C1与C结果一致
C1 = A2 + B3
print("Manual broadcasting:")
print(C1)

无论你是通过自动广播还是手动模拟广播机制,最终的结果都是相同的:

Using broadcasting:
tensor([[ 0,  1,  2],
        [10, 11, 12],
        [20, 21, 22],
        [30, 31, 32]], dtype=torch.int32)

Manual broadcasting:
tensor([[ 0,  1,  2],
        [10, 11, 12],
        [20, 21, 22],
        [30, 31, 32]], dtype=torch.int32)

通过上述代码和解析,我们了解到:

  • 广播机制允许不同形状的张量进行元素级的操作,而无需显式地复制数据。
  • unsqueeze 函数可以在指定位置插入一个新的维度,这对于准备广播非常有用。
  • expand 方法可以将张量扩展到目标形状,但它不会分配新的内存,而是返回一个视图,除非必要时才会复制数据。
相关推荐
富 贵 儿 ¥16 分钟前
深度学习、宽度学习、持续学习与终身学习:全面解析与其在大模型方面的应用
人工智能·深度学习·学习
佛州小李哥18 分钟前
我代表中国受邀在亚马逊云科技全球云计算大会re:Invent中技术演讲
运维·人工智能·科技·云计算·aws·云安全·亚马逊云科技
今天炼丹了吗25 分钟前
RT-DETR融合YOLOv12中的R-ELAN结构
人工智能·深度学习·计算机视觉
视觉语言导航28 分钟前
微软具身智能感知交互多面手!Magma:基于基础模型的多模态AI智能体
人工智能·深度学习·具身智能
CITY_OF_MO_GY37 分钟前
RAG组件:向量数据库(Milvus)
人工智能·milvus
东临碣石8244 分钟前
【AI论文】MedVLM-R1:通过强化学习激励视觉语言模型(VLMs)的医疗推理能力
人工智能·语言模型·自然语言处理
飞3001 小时前
淘天集团算法岗-计算机视觉(T-Star Lab)内推
人工智能·算法·计算机视觉·业界资讯
正宗咸豆花1 小时前
【PromptCoder + Cursor】利用AI智能编辑器快速实现设计稿
前端·人工智能·编辑器·prompt·提示词
CP-DD1 小时前
Pycharm 远程执行无法显示 cv2.imshow() 的原因分析及解决方案
人工智能·opencv·计算机视觉
小青龙emmm1 小时前
机器学习(五)
人工智能·机器学习