无脑入门pytorch系列(四)—— scatter_

本系列教程适用于没有任何pytorch的同学(简单的python语法还是要的),从代码的表层出发挖掘代码的深层含义,理解具体的意思和内涵。pytorch的很多函数看着非常简单,但是其中包含了很多内容,不了解其中的意思就只能【看懂代码】,无法【理解代码】。

目录

官方定义

torch.tensor.scatter_是PyTorch中的一个函数,用于将指定索引处的值替换为给定的值。

函数定义:

复制代码
Tensor.scatter_(dim, index, src, reduce=None) → Tensor

官方解释:

  • 将张量src中的所有值写入索引张量中指定的index处的self。

  • 对于src中的每个值,它的输出索引由其在src中的索引(dimension != dim)和在index中对应的值(dimension = dim)指定。

非常难以理解,十分抽象,从我个人的角度来说就是:

  • 第一个参数dim表示维度,即在第几维度处理数据,保持其它维度不变。
  • reduce参数是一个可选参数,用于指定如何在执行散射(scatter)操作时对重复的索引值进行合并或聚合。
  • index则是需要填充的列的索引,即根据维度从src中取对应的值填充到tensor中去。

怎么映射的,比如一个一个3维张量:

复制代码
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

官方的文档如下,TORCH.TENSOR.SCATTER_:

即使如此理解起来也是很复杂,下面从例子中去理解:

demo

下面是一个官方文档给出的例子:

python 复制代码
import torch

src = torch.Tensor([[-1.0276,  0.2673, -1.1752, -0.8823],
        [-0.6447, -0.8256,  0.1542, -0.4242]])
print(src)

output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])

output = output.scatter(1, index, src)
print(output)

输出的结果:

我们一步步理解代码:

  1. 首先,定义了一个src张量,后续output即从src中取值。
  2. 其次,定义了output,其值为二行五列的全零张量,后续对output进行修改。
  3. 接着,定义了index,即从src取值的索引。
  4. 最后,根据index从src取值填充到output中,即完成操作。

那么具体是如何取值的呢?

首先,dim = 1,意味着从维度值为1的地方取值,维度值为0的地方不变,那就是:

复制代码
self[i][index[i][j]] = src[i][j]  # if dim == 1

具体来说:

i = 0, j = 0时,output[0][index[0][0]] = src[0][0],因为index[0][0] = 3,所以output[0][3] = src[0][0] = -1.0276,这时候我们检查输出的output值,确实是-1.0276

同理:

i = 0, j = 1: output[0][index[0][1]] = output[0][1] = src[0][1] = 0.2673

i = 0, j = 2: output[0][index[0][2]] = output[0][2] = src[0][2] = -1.1752

one-hot

作者在学习该函数时实在遇到one-hot编码时遇到的,而该函数在one-hot中应用很广:

复制代码
index = torch.tensor([[3], [2], [0], [1]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)
相关推荐
专注数据的痴汉2 分钟前
「数据获取」内蒙古地理基础数据(道路、水系、四级行政边界、地级城市、DEM等)
大数据·人工智能·信息可视化
aopstudio2 分钟前
HuggingFace Tokenizer 的进化:从分词器到智能对话引擎
人工智能·自然语言处理·llm·huggingface
JQLvopkk3 分钟前
全栈可视化数字孪生开发平台阐述
人工智能·自动化
b***25114 分钟前
动力电池半自动生产线如何平衡自动化投入与规模化需求
人工智能
Hernon4 分钟前
AI智能体 - 优先级排序
大数据·人工智能
智航GIS7 分钟前
11.7 使用Pandas 模块中describe()、groupby()进行简单分析
python·pandas
Pyeako11 分钟前
机器学习--矿物数据清洗(六种填充方法)
人工智能·python·随机森林·机器学习·pycharm·线性回归·数据清洗
后端小肥肠12 分钟前
告别“抽卡”!n8n+Coze:固定IP角色+老纪先生风格漫画全自动生产
人工智能·aigc·coze
光羽隹衡13 分钟前
深度学习——卷积神经网络CNN
人工智能·深度学习·cnn
爱看科技17 分钟前
苹果Siri借谷歌Gemini大模型发力,微美全息全栈技术领航AI多元场景变革应用
人工智能