使用True False矩阵对torch.tensor切片

torch.tensor可以使用true、false矩阵进行切片,这里对切片结果记录一下

第一种情况:true、false矩阵和torch.tensor维度相同

一维情况

bash 复制代码
import torch
a = torch.randn((8,))
b = torch.randint(0,2,(8,))
slice_ = (b==0)
print(f"a = {a}")
print(f"slice_ = {slice_}")
print(f"a[slice_] = {a[slice_]}")

输出结果为:

bash 复制代码
a = tensor([-0.5343, -1.2582, -0.4511, -0.4338,  2.2691,  0.4879,  0.6847,  0.6235])
slice_ = tensor([ True,  True, False, False, False, False, False, False])
a[slice_] = tensor([-0.5343, -1.2582])

多维情况

bash 复制代码
import torch
a = torch.randn((2,3,4))
b = torch.randint(0,2,(2,3,4))
slice_ = (b==0)
print(f"a = {a}")
print(f"slice_ = {slice_}")
print(f"a[slice_] = {a[slice_]}")

输出结果为:

bash 复制代码
a = tensor([[[ 1.9325,  1.2950,  1.2434,  0.0564],
         [ 0.3010,  0.0343,  0.7497,  0.4019],
         [ 0.3159, -2.3188,  0.3495, -0.3471]],

        [[ 1.0270,  0.9790,  0.9406,  0.3484],
         [ 0.7881,  0.7568,  1.8638,  0.4024],
         [-0.5964, -2.3572, -0.6636,  0.8282]]])
slice_ = tensor([[[False,  True,  True, False],
         [ True,  True,  True, False],
         [ True,  True, False, False]],

        [[ True,  True, False,  True],
         [ True,  True,  True, False],
         [False,  True, False, False]]])
a[slice_] = torch.Size([14])

当两者维度相同时,在torch.tensor里面和true、false矩阵中true位置相同的元素会被保留,其余的值丢掉,最终切片结果为一维张量

第二种情况:true、false矩阵和torch.tensor维度不相同

首先要说明的是,两者维度可以不相同,但是每一个维度的值必须相同,比如说一个张量shape为(2,3,4,5),那true、false矩阵的shape可以是(2,)、(2,3)、(2,3,4)、(2,3,4,5)四种情况

bash 复制代码
import torch
a = torch.randn((2,3,4))
b = torch.randint(0,2,(2,3))
slice_ = (b==0)
print(f"a = {a},  a.shape = {a.shape}")
print(f"slice_ = {slice_},  slice_.shape = {slice_.shape}")
print(f"a[slice_] = {a[slice_]},  a[slice_].shape = {a[slice_].shape}")

输出结果为

bash 复制代码
a = tensor([[[-0.2952, -0.2619, -0.8608,  1.2657],
         [ 0.1895, -0.4806, -1.5506,  0.2752],
         [-0.2219,  0.2185, -0.7038,  0.1399]],

        [[-1.9745,  0.7333, -1.0359,  1.4674],
         [ 1.6730,  0.1612,  0.3537, -0.1737],
         [-0.9188,  3.0544,  1.4211,  0.9257]]]),  a.shape = torch.Size([2, 3, 4])

slice_ = tensor([[ True,  True,  True],
        [ True, False, False]]),  slice_.shape = torch.Size([2, 3])

a[slice_] = tensor([[-0.2952, -0.2619, -0.8608,  1.2657],
        [ 0.1895, -0.4806, -1.5506,  0.2752],
        [-0.2219,  0.2185, -0.7038,  0.1399],
        [-1.9745,  0.7333, -1.0359,  1.4674]]),  a[slice_].shape = torch.Size([4, 4])

当两者维度不相同时,在上面例子中,torch.tensor的shape为(2,3,4),true、false矩阵shape为(2,3),两者的前两个维度相同,那torch.tensor就保留true位置的元素,只不过这个被保留元素的shape为(4,),又因为存在四个Ture,所以最后切片结果shape为(4,4)

相关推荐
B站_计算机毕业设计之家3 分钟前
计算机毕业设计:Python农业数据可视化分析系统 气象数据 农业生产 粮食数据 播种数据 爬虫 Django框架 天气数据 降水量(源码+文档)✅
大数据·爬虫·python·机器学习·信息可视化·课程设计·农业
Q_Q51100828516 分钟前
python+uniapp基于微信小程序的旅游信息系统
spring boot·python·微信小程序·django·flask·uni-app·node.js
鄃鳕18 分钟前
python迭代器解包【python】
开发语言·python
懷淰メ1 小时前
python3GUI--模仿百度网盘的本地文件管理器 By:PyQt5(详细分享)
开发语言·python·pyqt·文件管理·百度云·百度网盘·ui设计
Q_Q5110082851 小时前
python基于web的汽车班车车票管理系统/火车票预订系统/高铁预定系统 可在线选座
spring boot·python·django·flask·node.js·汽车·php
新子y1 小时前
【小白笔记】普通二叉树(General Binary Tree)和二叉搜索树的最近公共祖先(LCA)
开发语言·笔记·python
囚生CY1 小时前
【速写】优化的深度与广度(Adam & Moun)
人工智能·python·算法
Query*1 小时前
Java 设计模式——工厂模式:从原理到实战的系统指南
java·python·设计模式
爱学习的uu1 小时前
CURSOR最新使用指南及使用思路
人工智能·笔记·python·软件工程
叶凡要飞2 小时前
RTX5060Ti安装双系统ubuntu22.04各种踩坑点(黑屏,引导区修复、装驱动、server版本安装)
人工智能·python·yolo·ubuntu·机器学习·操作系统