使用True False矩阵对torch.tensor切片

torch.tensor可以使用true、false矩阵进行切片,这里对切片结果记录一下

第一种情况:true、false矩阵和torch.tensor维度相同

一维情况

bash 复制代码
import torch
a = torch.randn((8,))
b = torch.randint(0,2,(8,))
slice_ = (b==0)
print(f"a = {a}")
print(f"slice_ = {slice_}")
print(f"a[slice_] = {a[slice_]}")

输出结果为:

bash 复制代码
a = tensor([-0.5343, -1.2582, -0.4511, -0.4338,  2.2691,  0.4879,  0.6847,  0.6235])
slice_ = tensor([ True,  True, False, False, False, False, False, False])
a[slice_] = tensor([-0.5343, -1.2582])

多维情况

bash 复制代码
import torch
a = torch.randn((2,3,4))
b = torch.randint(0,2,(2,3,4))
slice_ = (b==0)
print(f"a = {a}")
print(f"slice_ = {slice_}")
print(f"a[slice_] = {a[slice_]}")

输出结果为:

bash 复制代码
a = tensor([[[ 1.9325,  1.2950,  1.2434,  0.0564],
         [ 0.3010,  0.0343,  0.7497,  0.4019],
         [ 0.3159, -2.3188,  0.3495, -0.3471]],

        [[ 1.0270,  0.9790,  0.9406,  0.3484],
         [ 0.7881,  0.7568,  1.8638,  0.4024],
         [-0.5964, -2.3572, -0.6636,  0.8282]]])
slice_ = tensor([[[False,  True,  True, False],
         [ True,  True,  True, False],
         [ True,  True, False, False]],

        [[ True,  True, False,  True],
         [ True,  True,  True, False],
         [False,  True, False, False]]])
a[slice_] = torch.Size([14])

当两者维度相同时,在torch.tensor里面和true、false矩阵中true位置相同的元素会被保留,其余的值丢掉,最终切片结果为一维张量

第二种情况:true、false矩阵和torch.tensor维度不相同

首先要说明的是,两者维度可以不相同,但是每一个维度的值必须相同,比如说一个张量shape为(2,3,4,5),那true、false矩阵的shape可以是(2,)、(2,3)、(2,3,4)、(2,3,4,5)四种情况

bash 复制代码
import torch
a = torch.randn((2,3,4))
b = torch.randint(0,2,(2,3))
slice_ = (b==0)
print(f"a = {a},  a.shape = {a.shape}")
print(f"slice_ = {slice_},  slice_.shape = {slice_.shape}")
print(f"a[slice_] = {a[slice_]},  a[slice_].shape = {a[slice_].shape}")

输出结果为

bash 复制代码
a = tensor([[[-0.2952, -0.2619, -0.8608,  1.2657],
         [ 0.1895, -0.4806, -1.5506,  0.2752],
         [-0.2219,  0.2185, -0.7038,  0.1399]],

        [[-1.9745,  0.7333, -1.0359,  1.4674],
         [ 1.6730,  0.1612,  0.3537, -0.1737],
         [-0.9188,  3.0544,  1.4211,  0.9257]]]),  a.shape = torch.Size([2, 3, 4])

slice_ = tensor([[ True,  True,  True],
        [ True, False, False]]),  slice_.shape = torch.Size([2, 3])

a[slice_] = tensor([[-0.2952, -0.2619, -0.8608,  1.2657],
        [ 0.1895, -0.4806, -1.5506,  0.2752],
        [-0.2219,  0.2185, -0.7038,  0.1399],
        [-1.9745,  0.7333, -1.0359,  1.4674]]),  a[slice_].shape = torch.Size([4, 4])

当两者维度不相同时,在上面例子中,torch.tensor的shape为(2,3,4),true、false矩阵shape为(2,3),两者的前两个维度相同,那torch.tensor就保留true位置的元素,只不过这个被保留元素的shape为(4,),又因为存在四个Ture,所以最后切片结果shape为(4,4)

相关推荐
程序员小续1 小时前
React 多个 HOC 嵌套太深,会带来哪些隐患?
java·前端·javascript·vue.js·python·react.js·webpack
九转成圣3 小时前
windows10安装配置并使用Miniconda3
python·conda
Aerkui3 小时前
Python高阶函数-eval深入解析
开发语言·python
胖哥真不错3 小时前
数据分享:汽车测评数据
python·机器学习·数据分享·汽车测评数据·car evaluation
OreoCC4 小时前
第R3周:RNN-心脏病预测(pytorch版)
人工智能·pytorch·rnn
u0103731064 小时前
Django异步执行任务django-background-tasks
后端·python·django
杰瑞学AI5 小时前
LeetCode详解之如何一步步优化到最佳解法:21. 合并两个有序链表
数据结构·python·算法·leetcode·链表·面试·职场和发展
攻城狮7号5 小时前
Python爬虫第5节-urllib的异常处理、链接解析及 Robots 协议分析
爬虫·python·python爬虫
java1234_小锋5 小时前
一周学会Pandas2 Python数据处理与分析-Jupyter Notebook安装
开发语言·python·jupyter·pandas
skywalk81635 小时前
unittest测试模块:Python 标准库中的单元测试利器
开发语言·python·unittest