本系列教程适用于没有任何pytorch的同学(简单的python语法还是要的),从代码的表层出发挖掘代码的深层含义,理解具体的意思和内涵。pytorch的很多函数看着非常简单,但是其中包含了很多内容,不了解其中的意思就只能【看懂代码】,无法【理解代码】。
目录
官方定义
torch.tensor.scatter_
是PyTorch中的一个函数,用于将指定索引处的值替换为给定的值。
函数定义:
Tensor.scatter_(dim, index, src, reduce=None) → Tensor
官方解释:
-
将张量
src
中的所有值写入索引张量中指定的index
处的self。 -
对于
src
中的每个值,它的输出索引由其在src
中的索引(dimension != dim)
和在index中对应的值(dimension = dim)
指定。
非常难以理解,十分抽象,从我个人的角度来说就是:
- 第一个参数
dim
表示维度,即在第几维度处理数据,保持其它维度不变。 reduce
参数是一个可选参数,用于指定如何在执行散射(scatter)操作时对重复的索引值进行合并或聚合。- index则是需要填充的列的索引,即根据维度从src中取对应的值填充到tensor中去。
怎么映射的,比如一个一个3维张量:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
官方的文档如下,TORCH.TENSOR.SCATTER_:
即使如此理解起来也是很复杂,下面从例子中去理解:
demo
下面是一个官方文档给出的例子:
python
import torch
src = torch.Tensor([[-1.0276, 0.2673, -1.1752, -0.8823],
[-0.6447, -0.8256, 0.1542, -0.4242]])
print(src)
output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output = output.scatter(1, index, src)
print(output)
输出的结果:
我们一步步理解代码:
- 首先,定义了一个
src
张量,后续output即从src中取值。 - 其次,定义了
output
,其值为二行五列的全零张量,后续对output
进行修改。 - 接着,定义了index,即从src取值的索引。
- 最后,根据index从src取值填充到output中,即完成操作。
那么具体是如何取值的呢?
首先,dim = 1
,意味着从维度值为1的地方取值,维度值为0的地方不变,那就是:
self[i][index[i][j]] = src[i][j] # if dim == 1
具体来说:
当i = 0, j = 0
时,output[0][index[0][0]] = src[0][0]
,因为index[0][0] = 3
,所以output[0][3] = src[0][0] = -1.0276
,这时候我们检查输出的output
值,确实是-1.0276
。
同理:
i = 0, j = 1
: output[0][index[0][1]] = output[0][1] = src[0][1] = 0.2673
i = 0, j = 2
: output[0][index[0][2]] = output[0][2] = src[0][2] = -1.1752
one-hot
作者在学习该函数时实在遇到one-hot编码时遇到的,而该函数在one-hot中应用很广:
index = torch.tensor([[3], [2], [0], [1]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)