使用True False矩阵对torch.tensor切片

torch.tensor可以使用true、false矩阵进行切片,这里对切片结果记录一下

第一种情况:true、false矩阵和torch.tensor维度相同

一维情况

bash 复制代码
import torch
a = torch.randn((8,))
b = torch.randint(0,2,(8,))
slice_ = (b==0)
print(f"a = {a}")
print(f"slice_ = {slice_}")
print(f"a[slice_] = {a[slice_]}")

输出结果为:

bash 复制代码
a = tensor([-0.5343, -1.2582, -0.4511, -0.4338,  2.2691,  0.4879,  0.6847,  0.6235])
slice_ = tensor([ True,  True, False, False, False, False, False, False])
a[slice_] = tensor([-0.5343, -1.2582])

多维情况

bash 复制代码
import torch
a = torch.randn((2,3,4))
b = torch.randint(0,2,(2,3,4))
slice_ = (b==0)
print(f"a = {a}")
print(f"slice_ = {slice_}")
print(f"a[slice_] = {a[slice_]}")

输出结果为:

bash 复制代码
a = tensor([[[ 1.9325,  1.2950,  1.2434,  0.0564],
         [ 0.3010,  0.0343,  0.7497,  0.4019],
         [ 0.3159, -2.3188,  0.3495, -0.3471]],

        [[ 1.0270,  0.9790,  0.9406,  0.3484],
         [ 0.7881,  0.7568,  1.8638,  0.4024],
         [-0.5964, -2.3572, -0.6636,  0.8282]]])
slice_ = tensor([[[False,  True,  True, False],
         [ True,  True,  True, False],
         [ True,  True, False, False]],

        [[ True,  True, False,  True],
         [ True,  True,  True, False],
         [False,  True, False, False]]])
a[slice_] = torch.Size([14])

当两者维度相同时,在torch.tensor里面和true、false矩阵中true位置相同的元素会被保留,其余的值丢掉,最终切片结果为一维张量

第二种情况:true、false矩阵和torch.tensor维度不相同

首先要说明的是,两者维度可以不相同,但是每一个维度的值必须相同,比如说一个张量shape为(2,3,4,5),那true、false矩阵的shape可以是(2,)、(2,3)、(2,3,4)、(2,3,4,5)四种情况

bash 复制代码
import torch
a = torch.randn((2,3,4))
b = torch.randint(0,2,(2,3))
slice_ = (b==0)
print(f"a = {a},  a.shape = {a.shape}")
print(f"slice_ = {slice_},  slice_.shape = {slice_.shape}")
print(f"a[slice_] = {a[slice_]},  a[slice_].shape = {a[slice_].shape}")

输出结果为

bash 复制代码
a = tensor([[[-0.2952, -0.2619, -0.8608,  1.2657],
         [ 0.1895, -0.4806, -1.5506,  0.2752],
         [-0.2219,  0.2185, -0.7038,  0.1399]],

        [[-1.9745,  0.7333, -1.0359,  1.4674],
         [ 1.6730,  0.1612,  0.3537, -0.1737],
         [-0.9188,  3.0544,  1.4211,  0.9257]]]),  a.shape = torch.Size([2, 3, 4])

slice_ = tensor([[ True,  True,  True],
        [ True, False, False]]),  slice_.shape = torch.Size([2, 3])

a[slice_] = tensor([[-0.2952, -0.2619, -0.8608,  1.2657],
        [ 0.1895, -0.4806, -1.5506,  0.2752],
        [-0.2219,  0.2185, -0.7038,  0.1399],
        [-1.9745,  0.7333, -1.0359,  1.4674]]),  a[slice_].shape = torch.Size([4, 4])

当两者维度不相同时,在上面例子中,torch.tensor的shape为(2,3,4),true、false矩阵shape为(2,3),两者的前两个维度相同,那torch.tensor就保留true位置的元素,只不过这个被保留元素的shape为(4,),又因为存在四个Ture,所以最后切片结果shape为(4,4)

相关推荐
算法小白(真小白)2 小时前
低代码软件搭建自学第二天——构建拖拽功能
python·低代码·pyqt
唐小旭2 小时前
服务器建立-错误:pyenv环境建立后python版本不对
运维·服务器·python
007php0072 小时前
Go语言zero项目部署后启动失败问题分析与解决
java·服务器·网络·python·golang·php·ai编程
Chinese Red Guest3 小时前
python
开发语言·python·pygame
骑个小蜗牛3 小时前
Python 标准库:string——字符串操作
python
黄公子学安全5 小时前
Java的基础概念(一)
java·开发语言·python
程序员一诺6 小时前
【Python使用】嘿马python高级进阶全体系教程第10篇:静态Web服务器-返回固定页面数据,1. 开发自己的静态Web服务器【附代码文档】
后端·python
小木_.6 小时前
【Python 图片下载器】一款专门为爬虫制作的图片下载器,多线程下载,速度快,支持续传/图片缩放/图片压缩/图片转换
爬虫·python·学习·分享·批量下载·图片下载器
Jiude7 小时前
算法题题解记录——双变量问题的 “枚举右,维护左”
python·算法·面试