Pytorch 张量的scatter_add_方法介绍

torch.Tensor.scatter_add_ 是 PyTorch 中的一个原地操作(in-place operation),用于将一个源张量(src)中的值根据指定的索引(index)累加到目标张量(self)中。它常用于分布式计算、加权聚合以及自定义深度学习层等场景。

函数签名

复制代码
Tensor.scatter_add_(dim, index, src) → Tensor
参数说明
  1. dim (int):指定沿着哪个维度进行索引和累加。

  2. index (LongTensor) :一个整数类型的张量,包含要累加的索引位置。index 的形状应与 src 相同,除了指定的维度 dim

  3. src (Tensor):源张量,包含要累加到目标张量的值。

功能

scatter_add_ 会根据 index 中的索引,将 src 中的值累加到目标张量 self 的指定位置。对于每个值,其目标位置由 index 指定,而其他维度的位置由其在 src 中的位置决定。

操作逻辑

对于一个三维张量,scatter_add_ 的更新规则如下:

复制代码
self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2

示例

以下是一个简单的二维张量示例:

Python复制

复制代码
import torch

# 初始化目标张量
input_tensor = torch.zeros(3, 5)

# 源张量
src = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], dtype=torch.float32)

# 索引张量
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], dtype=torch.long)

# 沿着维度 0 进行累加
input_tensor.scatter_add_(0, index, src)

print(input_tensor)

输出:

复制代码
tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 0.,  7.,  0.,  9.,  0.],
        [ 6.,  0.,  8.,  0., 10.]])

详细解析

  1. 目标张量input_tensor 是一个形状为 (3, 5) 的零张量。

  2. 源张量src 是一个形状为 (2, 5) 的张量,包含要累加的值。

  3. 索引张量index 是一个形状为 (2, 5) 的整数张量,指定 src 中的值应该累加到 input_tensor 的哪些位置。

  4. 累加操作

    • scatter_add_ 沿着维度 0 进行操作。

    • index 中的每个值指定了 src 中对应值的目标位置。

    • 例如:

      • index[0, 0] = 0,表示 src[0, 0] = 1 应该累加到 input_tensor[0, 0]

      • index[1, 1] = 0,表示 src[1, 1] = 7 应该累加到 input_tensor[0, 1]

注意事项

  1. 形状要求

    • indexsrc 的形状必须与目标张量 self 的形状兼容。

    • index.size(d) <= src.size(d) 对所有维度 d 成立。

    • index.size(d) <= self.size(d) 对所有维度 d != dim 成立。

  2. 非确定性行为

    • 在 CUDA 设备上,scatter_add_ 的行为可能是非确定性的。
  3. 反向传播

    • 反向传播仅在 src.shape == index.shape 时实现。
  4. 原地操作

    • scatter_add_ 是一个原地操作,会直接修改目标张量 self

总结

torch.Tensor.scatter_add_ 是一个强大的工具,用于将源张量中的值根据索引累加到目标张量中。它在处理稀疏更新和聚合操作时非常有用,尤其适合需要在特定位置累加值的场景。

相关推荐
rongcj12 小时前
2026,“硅基经济”的时代正在悄然来临
人工智能
狼叔也疯狂12 小时前
英语启蒙SSS绘本第一辑50册高清PDF可打印
人工智能·全文检索
心静财富之门12 小时前
退出 for 循环,break和continue 语句
开发语言·python
WJSKad123512 小时前
YOLO11-FDPN-DASI实现羽毛球拍与球的实时检测与识别研究
python
万行12 小时前
机器学习&第四章支持向量机
人工智能·机器学习·支持向量机
幻云201012 小时前
Next.js之道:从入门到精通
人工智能·python
0和1的舞者12 小时前
GUI自动化测试详解(三):测试框架pytest完全指南
自动化测试·python·测试开发·自动化·pytest·测试
予枫的编程笔记12 小时前
【Java集合】深入浅出 Java HashMap:从链表到红黑树的“进化”之路
java·开发语言·数据结构·人工智能·链表·哈希算法
llddycidy12 小时前
峰值需求预测中的机器学习:基础、趋势和见解(最新文献)
网络·人工智能·深度学习
larance12 小时前
机器学习的一些基本知识
人工智能·机器学习