PyTorch中的take_along_dim

技术背景

在此前的一篇博客中,我们介绍过take_along_axis这个算子的具体使用方法。这里针对于Pytorch的take_along_dim算子,再重新介绍一次。

Numpy版本使用

这里我们展示的案例是基于numpy-2.0.1版本实现的:

bash 复制代码
$ python3 -m pip show numpy
Name: numpy
Version: 2.0.1
Summary: Fundamental package for array computing in Python
Home-page: https://numpy.org
Author: Travis E. Oliphant et al.

示例如下:

python 复制代码
In [1]: import numpy as np

In [2]: a = np.arange(12).reshape((1,4,3))

In [3]: a
Out[3]: 
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8],
        [ 9, 10, 11]]])

In [4]: idx = np.array([1,2])

In [6]: b = np.take_along_axis(a, idx[None,:,None], axis=1)

In [7]: b
Out[7]: 
array([[[3, 4, 5],
        [6, 7, 8]]])

In [8]: b = np.take_along_axis(a, idx[None,None,:], axis=2)

In [9]: b
Out[9]: 
array([[[ 1,  2],
        [ 4,  5],
        [ 7,  8],
        [10, 11]]])

在这个基础示例中,我们分别展示了同一个索引矩阵,在不同的维度上进行索引的结果。使用take_along_axis有一个默认的要求:原始数组和索引数组的维度数量需要保持一致。但是因为这里的索引矩阵是一维的,那么我们只要用slice的方法对索引矩阵进行扩维就好了。例如,我们需要在第二个维度进行提取,那么就可以用arr[None,:,None]来进行扩维。

PyTorch版实现

这里我们使用的torch是2.5.1的稳定版:

bash 复制代码
$ python3 -m pip show torch
Name: torch
Version: 2.5.1
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3-Clause
Location: /miniconda3/envs/pytorch/lib/python3.9/site-packages
Requires: filelock, fsspec, jinja2, networkx, sympy, typing-extensions
Required-by: torchaudio, torchmetrics, torchvision

相关的API接口文档如下:

其实实现起来跟numpy的操作是非常类似的:

python 复制代码
In [1]: import torch as tc

In [2]: a = tc.arange(12).reshape((1,4,3))

In [3]: idx = tc.tensor([1,2])

In [4]: b = tc.take_along_dim(a, idx[None,:,None], dim=1)

In [5]: b
Out[5]: 
tensor([[[3, 4, 5],
         [6, 7, 8]]])

In [6]: b = tc.take_along_dim(a, idx[None,None,:], dim=2)

In [7]: b
Out[7]: 
tensor([[[ 1,  2],
         [ 4,  5],
         [ 7,  8],
         [10, 11]]])

可以说是基本一致。那么同样的,也是要做一个扩维的处理。唯一一个不同的地方就是,在torch中是take_along_dim而不是像numpy或者mindspore中的take_along_axis,在torch中用dim替代了axis,包括函数名称和传入的关键词参数。

总结概要

接前面一篇take_along_axis的文章,本文主要介绍在PyTorch框架下,功能基本一样的函数take_along_dim。二者除了命名和一些关键词参数不一致之外,用法是一样的。需要注意的是,两者都要求输入的数组和索引数组维度数量一致。在特定场景下,需要手动进行扩维。

版权声明

本文首发链接为:https://www.cnblogs.com/dechinphy/p/take_along_dim.html

作者ID:DechinPhy

更多原著文章请参考:https://www.cnblogs.com/dechinphy/

打赏专用链接:https://www.cnblogs.com/dechinphy/gallery/image/379634.html

腾讯云专栏同步:https://cloud.tencent.com/developer/column/91958

相关推荐
smj2302_7968265216 分钟前
解决leetcode第3782题交替删除操作后最后剩下的整数
python·算法·leetcode
好奇龙猫28 分钟前
【AI学习-comfyUI学习-第十九节-comtrolnet艺术线处理器工作流-各个部分学习】
人工智能·学习
老蒋新思维1 小时前
从「流量算法」到「增长算法」:AI智能体如何重构企业增长的内在逻辑
大数据·网络·人工智能·重构·创始人ip·创客匠人·知识变现
苍何1 小时前
在全世界都教你做小红书图片的时候,我基于秒哒Pro做了个一键生成的网站。
人工智能
苍何1 小时前
用即梦视频3.5pro复刻爆款AI探班视频,直接发现一个AI片场!
人工智能
dulu~dulu1 小时前
机器学习题目总结(一)
人工智能·神经网络·决策树·机器学习·学习笔记·线性模型·模型评估与选择
gCode Teacher 格码致知1 小时前
Python基础教学:Python 3中的字符串在解释运行时的内存编码表示-由Deepseek产生
python·内存编码
苍何1 小时前
免费!漫画 PPT + 全文档讲解,这谁顶得住啊。。。
人工智能
苍何1 小时前
用 LiblibAI 做爆款动态海报,绝了!(附教程)
人工智能
翔云 OCR API1 小时前
承兑汇票识别接口技术解析与应用实践
开发语言·人工智能·python·计算机视觉·ocr