Pytorch 张量的scatter_add_方法介绍

torch.Tensor.scatter_add_ 是 PyTorch 中的一个原地操作(in-place operation),用于将一个源张量(src)中的值根据指定的索引(index)累加到目标张量(self)中。它常用于分布式计算、加权聚合以及自定义深度学习层等场景。

函数签名

复制代码
Tensor.scatter_add_(dim, index, src) → Tensor
参数说明
  1. dim (int):指定沿着哪个维度进行索引和累加。

  2. index (LongTensor) :一个整数类型的张量,包含要累加的索引位置。index 的形状应与 src 相同,除了指定的维度 dim

  3. src (Tensor):源张量,包含要累加到目标张量的值。

功能

scatter_add_ 会根据 index 中的索引,将 src 中的值累加到目标张量 self 的指定位置。对于每个值,其目标位置由 index 指定,而其他维度的位置由其在 src 中的位置决定。

操作逻辑

对于一个三维张量,scatter_add_ 的更新规则如下:

复制代码
self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2

示例

以下是一个简单的二维张量示例:

Python复制

复制代码
import torch

# 初始化目标张量
input_tensor = torch.zeros(3, 5)

# 源张量
src = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], dtype=torch.float32)

# 索引张量
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], dtype=torch.long)

# 沿着维度 0 进行累加
input_tensor.scatter_add_(0, index, src)

print(input_tensor)

输出:

复制代码
tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 0.,  7.,  0.,  9.,  0.],
        [ 6.,  0.,  8.,  0., 10.]])

详细解析

  1. 目标张量input_tensor 是一个形状为 (3, 5) 的零张量。

  2. 源张量src 是一个形状为 (2, 5) 的张量,包含要累加的值。

  3. 索引张量index 是一个形状为 (2, 5) 的整数张量,指定 src 中的值应该累加到 input_tensor 的哪些位置。

  4. 累加操作

    • scatter_add_ 沿着维度 0 进行操作。

    • index 中的每个值指定了 src 中对应值的目标位置。

    • 例如:

      • index[0, 0] = 0,表示 src[0, 0] = 1 应该累加到 input_tensor[0, 0]

      • index[1, 1] = 0,表示 src[1, 1] = 7 应该累加到 input_tensor[0, 1]

注意事项

  1. 形状要求

    • indexsrc 的形状必须与目标张量 self 的形状兼容。

    • index.size(d) <= src.size(d) 对所有维度 d 成立。

    • index.size(d) <= self.size(d) 对所有维度 d != dim 成立。

  2. 非确定性行为

    • 在 CUDA 设备上,scatter_add_ 的行为可能是非确定性的。
  3. 反向传播

    • 反向传播仅在 src.shape == index.shape 时实现。
  4. 原地操作

    • scatter_add_ 是一个原地操作,会直接修改目标张量 self

总结

torch.Tensor.scatter_add_ 是一个强大的工具,用于将源张量中的值根据索引累加到目标张量中。它在处理稀疏更新和聚合操作时非常有用,尤其适合需要在特定位置累加值的场景。

相关推荐
全栈技术负责人9 小时前
AI驱动开发 (AI-DLC) 实战经验分享:重构人机协作的上下文工程
人工智能·重构
Wu_Dylan9 小时前
智能体系列(二):规划(Planning):从 CoT、ToT 到动态采样与搜索
人工智能·算法
一招定胜负9 小时前
OpenCV轮廓检测完全指南:从原理到实战
人工智能·opencv·计算机视觉
毕设源码-郭学长9 小时前
【开题答辩全过程】以 基于python电商商城系统为例,包含答辩的问题和答案
开发语言·python
black0moonlight9 小时前
win11 isaacsim 5.1.0 和lab配置
python
知乎的哥廷根数学学派9 小时前
基于多尺度注意力机制融合连续小波变换与原型网络的滚动轴承小样本故障诊断方法(Pytorch)
网络·人工智能·pytorch·python·深度学习·算法·机器学习
网安CILLE9 小时前
PHP四大输出语句
linux·开发语言·python·web安全·网络安全·系统安全·php
xiatianxy9 小时前
云酷科技用智能化方案破解行业难题
人工智能·科技·安全·智能安全带
jjjddfvv9 小时前
超级简单启动llamafactory!
windows·python·深度学习·神经网络·微调·audiolm·llamafactory
星云数灵9 小时前
大模型高级工程师考试练习题8
人工智能·机器学习·大模型·大模型考试题库·阿里云aca·阿里云acp大模型考试题库·大模型高级工程师acp