基于PyTorch的深度学习2——广播

PyTorch中的广播机制(Broadcasting Mechanism)是一种强大的功能,它允许不同形状的张量在进行算术运算时自动扩展其维度,从而使得这些操作成为可能,而无需显式地复制数据。这种机制极大地简化了代码,并提高了效率。

广播规则

广播机制遵循以下几条基本规则:

  1. 每个张量至少有一个维度
  2. 从后往前比较张量的各个维度(即从最后一个维度到第一个维度)。两个张量的对应维度要么相等,要么其中一个为1,或者一个张量在此维度上没有尺寸(即此维度不存在)。
  3. 如果某个维度上的大小是1,则该维度会被重复使用以匹配另一个张量的相应维度大小。
  4. 最终结果的形状由各输入张量中每个维度的最大值决定。

如果满足上述条件,那么这两个张量就是"广播兼容"的,可以执行元素级的操作如加法、减法等。

复制代码
import torch

A = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: (2, 3)
B = torch.tensor([10, 20, 30])            # shape: (3)

result = A + B
# 结果:
# tensor([[11, 22, 33],
#         [14, 25, 36]])

C = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])  # shape: (2, 2, 2)
D = torch.tensor([[10, 20], [30, 40]])                   # shape: (2, 2)
result = C + D
# 结果:
# tensor([[[11, 22],
#          [33, 44]],
#         [[15, 26],
#          [37, 48]]])

接下来进一步进行演示。

复制代码
import torch
import numpy as np

# 创建NumPy数组
A = np.arange(0, 40, 10).reshape(4, 1)  # 形状为 (4, 1)
B = np.arange(0, 3)                     # 形状为 (3,)

# 将NumPy数组转换为PyTorch Tensor
A1 = torch.from_numpy(A)  # 形状为 (4, 1)
B1 = torch.from_numpy(B)  # 形状为 (3,)

# 使用广播机制自动扩展
C = A1 + B1
print("Using broadcasting:")
print(C)

# 手动实现广播
# 根据规则1,B1需要向A1看齐,把B变为(1, 3)
B2 = B1.unsqueeze(0)  # 形状变为 (1, 3)

# 使用expand函数重复数组,分别得到4x3的矩阵
A2 = A1.expand(4, 3)  # 形状变为 (4, 3)
B3 = B2.expand(4, 3)  # 形状变为 (4, 3)

# 然后进行相加,C1与C结果一致
C1 = A2 + B3
print("Manual broadcasting:")
print(C1)

无论你是通过自动广播还是手动模拟广播机制,最终的结果都是相同的:

复制代码
Using broadcasting:
tensor([[ 0,  1,  2],
        [10, 11, 12],
        [20, 21, 22],
        [30, 31, 32]], dtype=torch.int32)

Manual broadcasting:
tensor([[ 0,  1,  2],
        [10, 11, 12],
        [20, 21, 22],
        [30, 31, 32]], dtype=torch.int32)

通过上述代码和解析,我们了解到:

  • 广播机制允许不同形状的张量进行元素级的操作,而无需显式地复制数据。
  • unsqueeze 函数可以在指定位置插入一个新的维度,这对于准备广播非常有用。
  • expand 方法可以将张量扩展到目标形状,但它不会分配新的内存,而是返回一个视图,除非必要时才会复制数据。
相关推荐
低调小一20 小时前
AI 时代旧敏捷开发的核心矛盾与系统困境
人工智能·敏捷流程
红目香薰20 小时前
GitCode-我的运气的可量化方案-更新v5版本
人工智能·开源·文心一言·gitcode
黑客思维者20 小时前
机器学习071:深度学习【卷积神经网络】目标检测“三剑客”:YOLO、SSD、Faster R-CNN对比
深度学习·yolo·目标检测·机器学习·cnn·ssd·faster r-cnn
草莓熊Lotso20 小时前
脉脉独家【AI创作者xAMA】|当豆包手机遭遇“全网封杀”:AI学会操作手机,我们的饭碗还保得住吗?
运维·开发语言·人工智能·智能手机·脉脉
C7211BA20 小时前
通义灵码和Qoder的差异
大数据·人工智能
杜子不疼.20 小时前
脉脉AI创作者活动:聊聊AI时代技术人的真实出路
人工智能
散峰而望20 小时前
【Coze - AI Agent 开发平台】-- 你真的了解 Coze 吗
开发语言·人工智能·python·aigc·ai编程·ai写作
鸽芷咕21 小时前
【2025年度总结】时光知味,三载同行:落笔皆是沉淀,前行自有光芒
linux·c++·人工智能·2025年度总结
北山小恐龙21 小时前
卷积神经网络(CNN)与Transformer
深度学习·cnn·transformer
tap.AI21 小时前
Deepseek(七)去“AI 味儿”进阶:如何输出更具人情味与专业度?
人工智能