Pytorch 张量的scatter_add_方法介绍

torch.Tensor.scatter_add_ 是 PyTorch 中的一个原地操作(in-place operation),用于将一个源张量(src)中的值根据指定的索引(index)累加到目标张量(self)中。它常用于分布式计算、加权聚合以及自定义深度学习层等场景。

函数签名

复制代码
Tensor.scatter_add_(dim, index, src) → Tensor
参数说明
  1. dim (int):指定沿着哪个维度进行索引和累加。

  2. index (LongTensor) :一个整数类型的张量,包含要累加的索引位置。index 的形状应与 src 相同,除了指定的维度 dim

  3. src (Tensor):源张量,包含要累加到目标张量的值。

功能

scatter_add_ 会根据 index 中的索引,将 src 中的值累加到目标张量 self 的指定位置。对于每个值,其目标位置由 index 指定,而其他维度的位置由其在 src 中的位置决定。

操作逻辑

对于一个三维张量,scatter_add_ 的更新规则如下:

复制代码
self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2

示例

以下是一个简单的二维张量示例:

Python复制

复制代码
import torch

# 初始化目标张量
input_tensor = torch.zeros(3, 5)

# 源张量
src = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], dtype=torch.float32)

# 索引张量
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], dtype=torch.long)

# 沿着维度 0 进行累加
input_tensor.scatter_add_(0, index, src)

print(input_tensor)

输出:

复制代码
tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 0.,  7.,  0.,  9.,  0.],
        [ 6.,  0.,  8.,  0., 10.]])

详细解析

  1. 目标张量input_tensor 是一个形状为 (3, 5) 的零张量。

  2. 源张量src 是一个形状为 (2, 5) 的张量,包含要累加的值。

  3. 索引张量index 是一个形状为 (2, 5) 的整数张量,指定 src 中的值应该累加到 input_tensor 的哪些位置。

  4. 累加操作

    • scatter_add_ 沿着维度 0 进行操作。

    • index 中的每个值指定了 src 中对应值的目标位置。

    • 例如:

      • index[0, 0] = 0,表示 src[0, 0] = 1 应该累加到 input_tensor[0, 0]

      • index[1, 1] = 0,表示 src[1, 1] = 7 应该累加到 input_tensor[0, 1]

注意事项

  1. 形状要求

    • indexsrc 的形状必须与目标张量 self 的形状兼容。

    • index.size(d) <= src.size(d) 对所有维度 d 成立。

    • index.size(d) <= self.size(d) 对所有维度 d != dim 成立。

  2. 非确定性行为

    • 在 CUDA 设备上,scatter_add_ 的行为可能是非确定性的。
  3. 反向传播

    • 反向传播仅在 src.shape == index.shape 时实现。
  4. 原地操作

    • scatter_add_ 是一个原地操作,会直接修改目标张量 self

总结

torch.Tensor.scatter_add_ 是一个强大的工具,用于将源张量中的值根据索引累加到目标张量中。它在处理稀疏更新和聚合操作时非常有用,尤其适合需要在特定位置累加值的场景。

相关推荐
杜子不疼.18 小时前
【Linux】进程的初步探险:基本概念与基本操作
linux·人工智能·ai
可触的未来,发芽的智生18 小时前
触摸未来2025.10.04:当神经网络拥有了内在记忆……
人工智能·python·神经网络·算法·架构
PKNLP18 小时前
深度学习之神经网络2(Neural Network)
人工智能·深度学习·神经网络
格林威18 小时前
常规的变焦镜头有哪些类型?能做什么?
人工智能·数码相机·opencv·计算机视觉·视觉检测·机器视觉·工业镜头
蔗理苦19 小时前
2025-10-07 Python不基础 20——全局变量与自由变量
开发语言·python
xiaohanbao0919 小时前
理解神经网络流程
python·神经网络
韩立学长19 小时前
【开题答辩实录分享】以《基于Python的旅游网站数据爬虫研究》为例进行答辩实录分享
python·旅游
心无旁骛~19 小时前
【OpenArm|Control】openarm机械臂ROS2仿真控制
人工智能·ros
程序员陆业聪19 小时前
AI智能体的未来:从语言泛化到交互革命
人工智能
小小程序媛(*^▽^*)19 小时前
第十二届全国社会媒体处理大会笔记
人工智能·笔记·学习·ai