torch.gather的使用

torch.gather 函数的作用是按照指定的维度 dim 和索引 index 从输入张量 input 中收集数值。这个操作通常用于根据索引从一个维度中选择元素,并生成一个新的张量作为输出

1. 介绍

1.1 参数说明

  • input: 需要从中选取元素的原始张量。
  • dim: 沿着此维度选取元素。例如,如果 dim=0,则沿着第一个维度(通常是)选取;如果 dim=1,则沿着第二个维度(通常是)选取。
  • index: 一个长整型张量,包含要选取的索引。index 的形状应该与 input 的形状相同,或者可以广播到 input 的形状。

1.2. 索引张量 index 的作用

  • index 张量中的每个元素指定了在 input 张量中 dim 维度上的位置。例如,如果 dim=1(列) 并且 index[i, j] 的值为 k,则从第 i 行的第 k 列选取元素
  • 根据 index 张量中的索引,在 input 张量中沿着 dim 维度收集元素。
  • 输出张量的形状与 index 张量的形状相同。这意味着除了 dim 维度之外,其他所有维度的大小都与 index 相同。

2. 示例

py 复制代码
import torch

# 创建一个输入张量
input_tensor = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])

# 创建一个索引张量,其形状与输入张量相同
index_tensor = torch.tensor([[0, 2, 1], [2, 0, 1], [1, 0, 2]])

# 使用 torch.gather 收集元素,沿着列(dim=1)
output_tensor = torch.gather(input_tensor, 1, index_tensor)

print(output_tensor)

说明

在上面的示例中,torch.gather(input_tensor, 1, index_tensor) 的输出将是:

  • 对于第 0 行,列索引(dim为1)为 [0, 2, 1],所以收集的元素是 [1, 3, 2]。
  • 对于第 1 行,列索引为 [2, 0, 1],所以收集的元素是 [6, 4, 5]。
  • 对于第 2 行,列索引为 [1, 0, 2],所以收集的元素是 [8, 7, 9]。

因此,输出张量将是:

shell 复制代码
tensor([[1, 3, 2],
        [6, 4, 5],
        [8, 7, 9]])

注意事项:确保 index 中的所有值都在有效范围内,即从 0 到 input.size(dim) - 1。如果 index 中有任何值超出了这个范围,将会引发错误。

相关推荐
中杯可乐多加冰1 小时前
【AI落地应用实战】AIGC赋能职场PPT汇报:从效率工具到辅助优化
人工智能·深度学习·神经网络·aigc·powerpoint·ai赋能
烟锁池塘柳01 小时前
【大模型】解码策略:Greedy Search、Beam Search、Top-k/Top-p、Temperature Sampling等
人工智能·深度学习·机器学习
风逸hhh1 小时前
python打卡day58@浙大疏锦行
开发语言·python
盼小辉丶1 小时前
PyTorch实战(14)——条件生成对抗网络(conditional GAN,cGAN)
人工智能·pytorch·生成对抗网络
zzc9212 小时前
时频图数据集更正程序,去除坐标轴白边及调整对应的标签值
人工智能·深度学习·数据集·标签·时频图·更正·白边
烛阴2 小时前
一文搞懂 Python 闭包:让你的代码瞬间“高级”起来!
前端·python
JosieBook2 小时前
【Java编程动手学】Java中的数组与集合
java·开发语言·python
Blossom.1183 小时前
机器学习在智能供应链中的应用:需求预测与物流优化
人工智能·深度学习·神经网络·机器学习·计算机视觉·机器人·语音识别
Gyoku Mint3 小时前
深度学习×第4卷:Pytorch实战——她第一次用张量去拟合你的轨迹
人工智能·pytorch·python·深度学习·神经网络·算法·聚类
m0_751336395 小时前
突破性进展:超短等离子体脉冲实现单电子量子干涉,为飞行量子比特奠定基础
人工智能·深度学习·量子计算·材料科学·光子器件·光子学·无线电电子