Pytorch 张量的scatter_add_方法介绍

torch.Tensor.scatter_add_ 是 PyTorch 中的一个原地操作(in-place operation),用于将一个源张量(src)中的值根据指定的索引(index)累加到目标张量(self)中。它常用于分布式计算、加权聚合以及自定义深度学习层等场景。

函数签名

复制代码
Tensor.scatter_add_(dim, index, src) → Tensor
参数说明
  1. dim (int):指定沿着哪个维度进行索引和累加。

  2. index (LongTensor) :一个整数类型的张量,包含要累加的索引位置。index 的形状应与 src 相同,除了指定的维度 dim

  3. src (Tensor):源张量,包含要累加到目标张量的值。

功能

scatter_add_ 会根据 index 中的索引,将 src 中的值累加到目标张量 self 的指定位置。对于每个值,其目标位置由 index 指定,而其他维度的位置由其在 src 中的位置决定。

操作逻辑

对于一个三维张量,scatter_add_ 的更新规则如下:

复制代码
self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2

示例

以下是一个简单的二维张量示例:

Python复制

复制代码
import torch

# 初始化目标张量
input_tensor = torch.zeros(3, 5)

# 源张量
src = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], dtype=torch.float32)

# 索引张量
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], dtype=torch.long)

# 沿着维度 0 进行累加
input_tensor.scatter_add_(0, index, src)

print(input_tensor)

输出:

复制代码
tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 0.,  7.,  0.,  9.,  0.],
        [ 6.,  0.,  8.,  0., 10.]])

详细解析

  1. 目标张量input_tensor 是一个形状为 (3, 5) 的零张量。

  2. 源张量src 是一个形状为 (2, 5) 的张量,包含要累加的值。

  3. 索引张量index 是一个形状为 (2, 5) 的整数张量,指定 src 中的值应该累加到 input_tensor 的哪些位置。

  4. 累加操作

    • scatter_add_ 沿着维度 0 进行操作。

    • index 中的每个值指定了 src 中对应值的目标位置。

    • 例如:

      • index[0, 0] = 0,表示 src[0, 0] = 1 应该累加到 input_tensor[0, 0]

      • index[1, 1] = 0,表示 src[1, 1] = 7 应该累加到 input_tensor[0, 1]

注意事项

  1. 形状要求

    • indexsrc 的形状必须与目标张量 self 的形状兼容。

    • index.size(d) <= src.size(d) 对所有维度 d 成立。

    • index.size(d) <= self.size(d) 对所有维度 d != dim 成立。

  2. 非确定性行为

    • 在 CUDA 设备上,scatter_add_ 的行为可能是非确定性的。
  3. 反向传播

    • 反向传播仅在 src.shape == index.shape 时实现。
  4. 原地操作

    • scatter_add_ 是一个原地操作,会直接修改目标张量 self

总结

torch.Tensor.scatter_add_ 是一个强大的工具,用于将源张量中的值根据索引累加到目标张量中。它在处理稀疏更新和聚合操作时非常有用,尤其适合需要在特定位置累加值的场景。

相关推荐
凛铄linshuo1 小时前
爬虫简单实操2——以贴吧为例爬取“某吧”前10页的网页代码
爬虫·python·学习
牛客企业服务1 小时前
2025年AI面试推荐榜单,数字化招聘转型优选
人工智能·python·算法·面试·职场和发展·金融·求职招聘
胡斌附体1 小时前
linux测试端口是否可被外部访问
linux·运维·服务器·python·测试·端口测试·临时服务器
视觉语言导航1 小时前
RAL-2025 | 清华大学数字孪生驱动的机器人视觉导航!VR-Robo:面向视觉机器人导航与运动的现实-模拟-现实框架
人工智能·深度学习·机器人·具身智能
**梯度已爆炸**2 小时前
自然语言处理入门
人工智能·自然语言处理
likeGhee2 小时前
python缓存装饰器实现方案
开发语言·python·缓存
ctrlworks2 小时前
楼宇自控核心功能:实时监控设备运行,快速诊断故障,赋能设备寿命延长
人工智能·ba系统厂商·楼宇自控系统厂家·ibms系统厂家·建筑管理系统厂家·能耗监测系统厂家
项目題供诗2 小时前
黑马python(二十五)
开发语言·python
读书点滴2 小时前
笨方法学python -练习14
java·前端·python
笑衬人心。2 小时前
Ubuntu 22.04 修改默认 Python 版本为 Python3 笔记
笔记·python·ubuntu