Pytorch 张量的scatter_add_方法介绍

torch.Tensor.scatter_add_ 是 PyTorch 中的一个原地操作(in-place operation),用于将一个源张量(src)中的值根据指定的索引(index)累加到目标张量(self)中。它常用于分布式计算、加权聚合以及自定义深度学习层等场景。

函数签名

复制代码
Tensor.scatter_add_(dim, index, src) → Tensor
参数说明
  1. dim (int):指定沿着哪个维度进行索引和累加。

  2. index (LongTensor) :一个整数类型的张量,包含要累加的索引位置。index 的形状应与 src 相同,除了指定的维度 dim

  3. src (Tensor):源张量,包含要累加到目标张量的值。

功能

scatter_add_ 会根据 index 中的索引,将 src 中的值累加到目标张量 self 的指定位置。对于每个值,其目标位置由 index 指定,而其他维度的位置由其在 src 中的位置决定。

操作逻辑

对于一个三维张量,scatter_add_ 的更新规则如下:

复制代码
self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2

示例

以下是一个简单的二维张量示例:

Python复制

复制代码
import torch

# 初始化目标张量
input_tensor = torch.zeros(3, 5)

# 源张量
src = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], dtype=torch.float32)

# 索引张量
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], dtype=torch.long)

# 沿着维度 0 进行累加
input_tensor.scatter_add_(0, index, src)

print(input_tensor)

输出:

复制代码
tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 0.,  7.,  0.,  9.,  0.],
        [ 6.,  0.,  8.,  0., 10.]])

详细解析

  1. 目标张量input_tensor 是一个形状为 (3, 5) 的零张量。

  2. 源张量src 是一个形状为 (2, 5) 的张量,包含要累加的值。

  3. 索引张量index 是一个形状为 (2, 5) 的整数张量,指定 src 中的值应该累加到 input_tensor 的哪些位置。

  4. 累加操作

    • scatter_add_ 沿着维度 0 进行操作。

    • index 中的每个值指定了 src 中对应值的目标位置。

    • 例如:

      • index[0, 0] = 0,表示 src[0, 0] = 1 应该累加到 input_tensor[0, 0]

      • index[1, 1] = 0,表示 src[1, 1] = 7 应该累加到 input_tensor[0, 1]

注意事项

  1. 形状要求

    • indexsrc 的形状必须与目标张量 self 的形状兼容。

    • index.size(d) <= src.size(d) 对所有维度 d 成立。

    • index.size(d) <= self.size(d) 对所有维度 d != dim 成立。

  2. 非确定性行为

    • 在 CUDA 设备上,scatter_add_ 的行为可能是非确定性的。
  3. 反向传播

    • 反向传播仅在 src.shape == index.shape 时实现。
  4. 原地操作

    • scatter_add_ 是一个原地操作,会直接修改目标张量 self

总结

torch.Tensor.scatter_add_ 是一个强大的工具,用于将源张量中的值根据索引累加到目标张量中。它在处理稀疏更新和聚合操作时非常有用,尤其适合需要在特定位置累加值的场景。

相关推荐
m0_564876843 分钟前
提示词工程手册学习
人工智能·python·深度学习·学习
波诺波21 分钟前
p1项目system_model.py代码
开发语言·python
AI精钢28 分钟前
谷歌时隔一年发布“更加开源“的 Gemma 4,意图何为?
人工智能·云原生·开源·aigc
静心观复30 分钟前
Python 虚拟环境与 pipx 详解
开发语言·python
卷心菜狗33 分钟前
Re.从零开始使用Python构建本地大模型网页智慧聊天机器人
开发语言·python·机器人
洞见新研社37 分钟前
从算力到电力,谁在搭建AI时代的“能源基座”?
人工智能·能源
书到用时方恨少!1 小时前
Python NumPy 使用指南:科学计算的基石
开发语言·python·numpy
小程故事多_801 小时前
自然语言智能体控制框架,重塑AI Agent的协作与执行范式
人工智能·架构·aigc·ai编程·harness
2501_933329551 小时前
技术深度拆解:Infoseek舆情系统的全链路架构与核心实现
开发语言·人工智能·分布式·架构
aosky1 小时前
OmniVoice:支持 600+ 语言的零样本语音克隆 TTS 系统
人工智能·tts