基于PyTorch的深度学习2——广播

PyTorch中的广播机制(Broadcasting Mechanism)是一种强大的功能,它允许不同形状的张量在进行算术运算时自动扩展其维度,从而使得这些操作成为可能,而无需显式地复制数据。这种机制极大地简化了代码,并提高了效率。

广播规则

广播机制遵循以下几条基本规则:

  1. 每个张量至少有一个维度
  2. 从后往前比较张量的各个维度(即从最后一个维度到第一个维度)。两个张量的对应维度要么相等,要么其中一个为1,或者一个张量在此维度上没有尺寸(即此维度不存在)。
  3. 如果某个维度上的大小是1,则该维度会被重复使用以匹配另一个张量的相应维度大小。
  4. 最终结果的形状由各输入张量中每个维度的最大值决定。

如果满足上述条件,那么这两个张量就是"广播兼容"的,可以执行元素级的操作如加法、减法等。

复制代码
import torch

A = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: (2, 3)
B = torch.tensor([10, 20, 30])            # shape: (3)

result = A + B
# 结果:
# tensor([[11, 22, 33],
#         [14, 25, 36]])

C = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])  # shape: (2, 2, 2)
D = torch.tensor([[10, 20], [30, 40]])                   # shape: (2, 2)
result = C + D
# 结果:
# tensor([[[11, 22],
#          [33, 44]],
#         [[15, 26],
#          [37, 48]]])

接下来进一步进行演示。

复制代码
import torch
import numpy as np

# 创建NumPy数组
A = np.arange(0, 40, 10).reshape(4, 1)  # 形状为 (4, 1)
B = np.arange(0, 3)                     # 形状为 (3,)

# 将NumPy数组转换为PyTorch Tensor
A1 = torch.from_numpy(A)  # 形状为 (4, 1)
B1 = torch.from_numpy(B)  # 形状为 (3,)

# 使用广播机制自动扩展
C = A1 + B1
print("Using broadcasting:")
print(C)

# 手动实现广播
# 根据规则1,B1需要向A1看齐,把B变为(1, 3)
B2 = B1.unsqueeze(0)  # 形状变为 (1, 3)

# 使用expand函数重复数组,分别得到4x3的矩阵
A2 = A1.expand(4, 3)  # 形状变为 (4, 3)
B3 = B2.expand(4, 3)  # 形状变为 (4, 3)

# 然后进行相加,C1与C结果一致
C1 = A2 + B3
print("Manual broadcasting:")
print(C1)

无论你是通过自动广播还是手动模拟广播机制,最终的结果都是相同的:

复制代码
Using broadcasting:
tensor([[ 0,  1,  2],
        [10, 11, 12],
        [20, 21, 22],
        [30, 31, 32]], dtype=torch.int32)

Manual broadcasting:
tensor([[ 0,  1,  2],
        [10, 11, 12],
        [20, 21, 22],
        [30, 31, 32]], dtype=torch.int32)

通过上述代码和解析,我们了解到:

  • 广播机制允许不同形状的张量进行元素级的操作,而无需显式地复制数据。
  • unsqueeze 函数可以在指定位置插入一个新的维度,这对于准备广播非常有用。
  • expand 方法可以将张量扩展到目标形状,但它不会分配新的内存,而是返回一个视图,除非必要时才会复制数据。
相关推荐
亚马逊云开发者2 分钟前
GenDev 智能开发:Amazon Q Developer CLI 赋能Amazon Code Family实现代码审核
人工智能
CoovallyAIHub7 分钟前
全球OCR新标杆!百度0.9B小模型斩获四项SOTA,读懂复杂文档像人一样自然
深度学习·算法·计算机视觉
weixin_3776348410 分钟前
【强化学习】RLMT强制 CoT提升训练效果
人工智能·算法·机器学习
Francek Chen17 分钟前
【深度学习计算机视觉】14:实战Kaggle比赛:狗的品种识别(ImageNet Dogs)
人工智能·pytorch·深度学习·计算机视觉·kaggle·imagenet dogs
dxnb2221 分钟前
Datawhale25年10月组队学习:math for AI+Task3线性代数(下)
人工智能·学习·线性代数
渡我白衣37 分钟前
《未来的 AI 操作系统(四)——AgentOS 的内核设计:调度、记忆与自我反思机制》
人工智能·深度学习·机器学习·语言模型·数据挖掘·人机交互·语音识别
飞哥数智坊1 小时前
Claude Skills 实测体验:不用翻墙,GLM-4.6 也能玩转
人工智能·claude·chatglm (智谱)
FreeBuf_1 小时前
微软数字防御报告:AI成为新型威胁,自动化漏洞利用技术颠覆传统
人工智能·microsoft·自动化
MoRanzhi12031 小时前
Pillow 基础图像操作与数据预处理
图像处理·python·深度学习·机器学习·numpy·pillow·数据预处理
IT_陈寒1 小时前
Vue3性能优化实战:这7个技巧让我的应用加载速度提升50%!
前端·人工智能·后端