使用True False矩阵对torch.tensor切片

torch.tensor可以使用true、false矩阵进行切片,这里对切片结果记录一下

第一种情况:true、false矩阵和torch.tensor维度相同

一维情况

bash 复制代码
import torch
a = torch.randn((8,))
b = torch.randint(0,2,(8,))
slice_ = (b==0)
print(f"a = {a}")
print(f"slice_ = {slice_}")
print(f"a[slice_] = {a[slice_]}")

输出结果为:

bash 复制代码
a = tensor([-0.5343, -1.2582, -0.4511, -0.4338,  2.2691,  0.4879,  0.6847,  0.6235])
slice_ = tensor([ True,  True, False, False, False, False, False, False])
a[slice_] = tensor([-0.5343, -1.2582])

多维情况

bash 复制代码
import torch
a = torch.randn((2,3,4))
b = torch.randint(0,2,(2,3,4))
slice_ = (b==0)
print(f"a = {a}")
print(f"slice_ = {slice_}")
print(f"a[slice_] = {a[slice_]}")

输出结果为:

bash 复制代码
a = tensor([[[ 1.9325,  1.2950,  1.2434,  0.0564],
         [ 0.3010,  0.0343,  0.7497,  0.4019],
         [ 0.3159, -2.3188,  0.3495, -0.3471]],

        [[ 1.0270,  0.9790,  0.9406,  0.3484],
         [ 0.7881,  0.7568,  1.8638,  0.4024],
         [-0.5964, -2.3572, -0.6636,  0.8282]]])
slice_ = tensor([[[False,  True,  True, False],
         [ True,  True,  True, False],
         [ True,  True, False, False]],

        [[ True,  True, False,  True],
         [ True,  True,  True, False],
         [False,  True, False, False]]])
a[slice_] = torch.Size([14])

当两者维度相同时,在torch.tensor里面和true、false矩阵中true位置相同的元素会被保留,其余的值丢掉,最终切片结果为一维张量

第二种情况:true、false矩阵和torch.tensor维度不相同

首先要说明的是,两者维度可以不相同,但是每一个维度的值必须相同,比如说一个张量shape为(2,3,4,5),那true、false矩阵的shape可以是(2,)、(2,3)、(2,3,4)、(2,3,4,5)四种情况

bash 复制代码
import torch
a = torch.randn((2,3,4))
b = torch.randint(0,2,(2,3))
slice_ = (b==0)
print(f"a = {a},  a.shape = {a.shape}")
print(f"slice_ = {slice_},  slice_.shape = {slice_.shape}")
print(f"a[slice_] = {a[slice_]},  a[slice_].shape = {a[slice_].shape}")

输出结果为

bash 复制代码
a = tensor([[[-0.2952, -0.2619, -0.8608,  1.2657],
         [ 0.1895, -0.4806, -1.5506,  0.2752],
         [-0.2219,  0.2185, -0.7038,  0.1399]],

        [[-1.9745,  0.7333, -1.0359,  1.4674],
         [ 1.6730,  0.1612,  0.3537, -0.1737],
         [-0.9188,  3.0544,  1.4211,  0.9257]]]),  a.shape = torch.Size([2, 3, 4])

slice_ = tensor([[ True,  True,  True],
        [ True, False, False]]),  slice_.shape = torch.Size([2, 3])

a[slice_] = tensor([[-0.2952, -0.2619, -0.8608,  1.2657],
        [ 0.1895, -0.4806, -1.5506,  0.2752],
        [-0.2219,  0.2185, -0.7038,  0.1399],
        [-1.9745,  0.7333, -1.0359,  1.4674]]),  a[slice_].shape = torch.Size([4, 4])

当两者维度不相同时,在上面例子中,torch.tensor的shape为(2,3,4),true、false矩阵shape为(2,3),两者的前两个维度相同,那torch.tensor就保留true位置的元素,只不过这个被保留元素的shape为(4,),又因为存在四个Ture,所以最后切片结果shape为(4,4)

相关推荐
databook7 小时前
Manim实现脉冲闪烁特效
后端·python·动效
程序设计实验室7 小时前
2025年了,在 Django 之外,Python Web 框架还能怎么选?
python
倔强青铜三9 小时前
苦练Python第46天:文件写入与上下文管理器
人工智能·python·面试
用户25191624271112 小时前
Python之语言特点
python
刘立军13 小时前
使用pyHugeGraph查询HugeGraph图数据
python·graphql
数据智能老司机16 小时前
精通 Python 设计模式——创建型设计模式
python·设计模式·架构
数据智能老司机17 小时前
精通 Python 设计模式——SOLID 原则
python·设计模式·架构
c8i19 小时前
django中的FBV 和 CBV
python·django
c8i19 小时前
python中的闭包和装饰器
python
这里有鱼汤1 天前
小白必看:QMT里的miniQMT入门教程
后端·python