无脑入门pytorch系列(四)—— scatter_

本系列教程适用于没有任何pytorch的同学(简单的python语法还是要的),从代码的表层出发挖掘代码的深层含义,理解具体的意思和内涵。pytorch的很多函数看着非常简单,但是其中包含了很多内容,不了解其中的意思就只能【看懂代码】,无法【理解代码】。

目录

官方定义

torch.tensor.scatter_是PyTorch中的一个函数,用于将指定索引处的值替换为给定的值。

函数定义:

复制代码
Tensor.scatter_(dim, index, src, reduce=None) → Tensor

官方解释:

  • 将张量src中的所有值写入索引张量中指定的index处的self。

  • 对于src中的每个值,它的输出索引由其在src中的索引(dimension != dim)和在index中对应的值(dimension = dim)指定。

非常难以理解,十分抽象,从我个人的角度来说就是:

  • 第一个参数dim表示维度,即在第几维度处理数据,保持其它维度不变。
  • reduce参数是一个可选参数,用于指定如何在执行散射(scatter)操作时对重复的索引值进行合并或聚合。
  • index则是需要填充的列的索引,即根据维度从src中取对应的值填充到tensor中去。

怎么映射的,比如一个一个3维张量:

复制代码
self[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

官方的文档如下,TORCH.TENSOR.SCATTER_:

即使如此理解起来也是很复杂,下面从例子中去理解:

demo

下面是一个官方文档给出的例子:

python 复制代码
import torch

src = torch.Tensor([[-1.0276,  0.2673, -1.1752, -0.8823],
        [-0.6447, -0.8256,  0.1542, -0.4242]])
print(src)

output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])

output = output.scatter(1, index, src)
print(output)

输出的结果:

我们一步步理解代码:

  1. 首先,定义了一个src张量,后续output即从src中取值。
  2. 其次,定义了output,其值为二行五列的全零张量,后续对output进行修改。
  3. 接着,定义了index,即从src取值的索引。
  4. 最后,根据index从src取值填充到output中,即完成操作。

那么具体是如何取值的呢?

首先,dim = 1,意味着从维度值为1的地方取值,维度值为0的地方不变,那就是:

复制代码
self[i][index[i][j]] = src[i][j]  # if dim == 1

具体来说:

i = 0, j = 0时,output[0][index[0][0]] = src[0][0],因为index[0][0] = 3,所以output[0][3] = src[0][0] = -1.0276,这时候我们检查输出的output值,确实是-1.0276

同理:

i = 0, j = 1: output[0][index[0][1]] = output[0][1] = src[0][1] = 0.2673

i = 0, j = 2: output[0][index[0][2]] = output[0][2] = src[0][2] = -1.1752

one-hot

作者在学习该函数时实在遇到one-hot编码时遇到的,而该函数在one-hot中应用很广:

复制代码
index = torch.tensor([[3], [2], [0], [1]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)
相关推荐
芝士爱知识a8 分钟前
2026年教资备考数字化生存指南:主流App深度测评与AI技术应用分析
人工智能·教资·ai教育·教育技术·教资面试·app测评·2026教资
AIArchivist8 分钟前
攻坚肝胆疑难病例,AI成为诊疗决策的“智慧大脑”
人工智能
jake don14 分钟前
GPU服务器搭建大模型指南
服务器·人工智能
乔江seven24 分钟前
【Flask 进阶】3 从同步到异步:基于 Redis 任务队列解决 API 高并发与长耗时任务阻塞
redis·python·flask
JicasdC123asd28 分钟前
【深度学习实战】基于Mask-RCNN和HRNetV2P的腰果智能分级系统_1
人工智能·深度学习
pchaoda37 分钟前
基本面因子计算入门
python·matplotlib·量化
Wpa.wk42 分钟前
接口自动化测试 - 请求构造和响应断言 -Rest-assure
开发语言·python·测试工具·接口自动化
星爷AG I43 分钟前
9-28 视觉工作记忆(AGI基础理论)
人工智能·计算机视觉·agi
陈天伟教授1 小时前
人工智能应用- 语言理解:07.大语言模型
人工智能·深度学习·语言模型
岱宗夫up1 小时前
机器学习:标准化流模型(NF)
人工智能·python·机器学习·生成对抗网络