torch.Tensor.scatter_add_
是 PyTorch 中的一个原地操作(in-place operation),用于将一个源张量(src
)中的值根据指定的索引(index
)累加到目标张量(self
)中。它常用于分布式计算、加权聚合以及自定义深度学习层等场景。
函数签名
Tensor.scatter_add_(dim, index, src) → Tensor
参数说明
-
dim
(int):指定沿着哪个维度进行索引和累加。 -
index
(LongTensor) :一个整数类型的张量,包含要累加的索引位置。index
的形状应与src
相同,除了指定的维度dim
。 -
src
(Tensor):源张量,包含要累加到目标张量的值。
功能
scatter_add_
会根据 index
中的索引,将 src
中的值累加到目标张量 self
的指定位置。对于每个值,其目标位置由 index
指定,而其他维度的位置由其在 src
中的位置决定。
操作逻辑
对于一个三维张量,scatter_add_
的更新规则如下:
self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2
示例
以下是一个简单的二维张量示例:
Python复制
import torch
# 初始化目标张量
input_tensor = torch.zeros(3, 5)
# 源张量
src = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]], dtype=torch.float32)
# 索引张量
index = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]], dtype=torch.long)
# 沿着维度 0 进行累加
input_tensor.scatter_add_(0, index, src)
print(input_tensor)
输出:
tensor([[ 1., 2., 3., 4., 5.],
[ 0., 7., 0., 9., 0.],
[ 6., 0., 8., 0., 10.]])
详细解析
-
目标张量 :
input_tensor
是一个形状为(3, 5)
的零张量。 -
源张量 :
src
是一个形状为(2, 5)
的张量,包含要累加的值。 -
索引张量 :
index
是一个形状为(2, 5)
的整数张量,指定src
中的值应该累加到input_tensor
的哪些位置。 -
累加操作:
-
scatter_add_
沿着维度0
进行操作。 -
index
中的每个值指定了src
中对应值的目标位置。 -
例如:
-
index[0, 0] = 0
,表示src[0, 0] = 1
应该累加到input_tensor[0, 0]
。 -
index[1, 1] = 0
,表示src[1, 1] = 7
应该累加到input_tensor[0, 1]
。
-
-
注意事项
-
形状要求:
-
index
和src
的形状必须与目标张量self
的形状兼容。 -
index.size(d) <= src.size(d)
对所有维度d
成立。 -
index.size(d) <= self.size(d)
对所有维度d != dim
成立。
-
-
非确定性行为:
- 在 CUDA 设备上,
scatter_add_
的行为可能是非确定性的。
- 在 CUDA 设备上,
-
反向传播:
- 反向传播仅在
src.shape == index.shape
时实现。
- 反向传播仅在
-
原地操作:
scatter_add_
是一个原地操作,会直接修改目标张量self
。
总结
torch.Tensor.scatter_add_
是一个强大的工具,用于将源张量中的值根据索引累加到目标张量中。它在处理稀疏更新和聚合操作时非常有用,尤其适合需要在特定位置累加值的场景。