使用True False矩阵对torch.tensor切片

torch.tensor可以使用true、false矩阵进行切片,这里对切片结果记录一下

第一种情况:true、false矩阵和torch.tensor维度相同

一维情况

bash 复制代码
import torch
a = torch.randn((8,))
b = torch.randint(0,2,(8,))
slice_ = (b==0)
print(f"a = {a}")
print(f"slice_ = {slice_}")
print(f"a[slice_] = {a[slice_]}")

输出结果为:

bash 复制代码
a = tensor([-0.5343, -1.2582, -0.4511, -0.4338,  2.2691,  0.4879,  0.6847,  0.6235])
slice_ = tensor([ True,  True, False, False, False, False, False, False])
a[slice_] = tensor([-0.5343, -1.2582])

多维情况

bash 复制代码
import torch
a = torch.randn((2,3,4))
b = torch.randint(0,2,(2,3,4))
slice_ = (b==0)
print(f"a = {a}")
print(f"slice_ = {slice_}")
print(f"a[slice_] = {a[slice_]}")

输出结果为:

bash 复制代码
a = tensor([[[ 1.9325,  1.2950,  1.2434,  0.0564],
         [ 0.3010,  0.0343,  0.7497,  0.4019],
         [ 0.3159, -2.3188,  0.3495, -0.3471]],

        [[ 1.0270,  0.9790,  0.9406,  0.3484],
         [ 0.7881,  0.7568,  1.8638,  0.4024],
         [-0.5964, -2.3572, -0.6636,  0.8282]]])
slice_ = tensor([[[False,  True,  True, False],
         [ True,  True,  True, False],
         [ True,  True, False, False]],

        [[ True,  True, False,  True],
         [ True,  True,  True, False],
         [False,  True, False, False]]])
a[slice_] = torch.Size([14])

当两者维度相同时,在torch.tensor里面和true、false矩阵中true位置相同的元素会被保留,其余的值丢掉,最终切片结果为一维张量

第二种情况:true、false矩阵和torch.tensor维度不相同

首先要说明的是,两者维度可以不相同,但是每一个维度的值必须相同,比如说一个张量shape为(2,3,4,5),那true、false矩阵的shape可以是(2,)、(2,3)、(2,3,4)、(2,3,4,5)四种情况

bash 复制代码
import torch
a = torch.randn((2,3,4))
b = torch.randint(0,2,(2,3))
slice_ = (b==0)
print(f"a = {a},  a.shape = {a.shape}")
print(f"slice_ = {slice_},  slice_.shape = {slice_.shape}")
print(f"a[slice_] = {a[slice_]},  a[slice_].shape = {a[slice_].shape}")

输出结果为

bash 复制代码
a = tensor([[[-0.2952, -0.2619, -0.8608,  1.2657],
         [ 0.1895, -0.4806, -1.5506,  0.2752],
         [-0.2219,  0.2185, -0.7038,  0.1399]],

        [[-1.9745,  0.7333, -1.0359,  1.4674],
         [ 1.6730,  0.1612,  0.3537, -0.1737],
         [-0.9188,  3.0544,  1.4211,  0.9257]]]),  a.shape = torch.Size([2, 3, 4])

slice_ = tensor([[ True,  True,  True],
        [ True, False, False]]),  slice_.shape = torch.Size([2, 3])

a[slice_] = tensor([[-0.2952, -0.2619, -0.8608,  1.2657],
        [ 0.1895, -0.4806, -1.5506,  0.2752],
        [-0.2219,  0.2185, -0.7038,  0.1399],
        [-1.9745,  0.7333, -1.0359,  1.4674]]),  a[slice_].shape = torch.Size([4, 4])

当两者维度不相同时,在上面例子中,torch.tensor的shape为(2,3,4),true、false矩阵shape为(2,3),两者的前两个维度相同,那torch.tensor就保留true位置的元素,只不过这个被保留元素的shape为(4,),又因为存在四个Ture,所以最后切片结果shape为(4,4)

相关推荐
Mantanmu8 分钟前
Python训练day40
人工智能·python·机器学习
天天爱吃肉821811 分钟前
新能源汽车热管理核心技术解析:冬季续航提升40%的行业方案
android·python·嵌入式硬件·汽车
ss.li14 分钟前
TripGenie:畅游济南旅行规划助手:个人工作纪实(二十二)
javascript·人工智能·python
l木本I27 分钟前
大模型低秩微调技术 LoRA 深度解析与实践
python·深度学习·自然语言处理·lstm·transformer
哆啦A梦的口袋呀31 分钟前
基于Python学习《Head First设计模式》第七章 适配器和外观模式
python·学习·设计模式
十月狐狸34 分钟前
Python字符串进化史:从青涩到成熟的蜕变
python
狐凄1 小时前
Python实例题:Python计算线性代数
开发语言·python·线性代数
西猫雷婶1 小时前
pytorch基本运算-导数和f-string
人工智能·pytorch·python
顽强卖力1 小时前
第二十八课:深度学习及pytorch简介
人工智能·pytorch·深度学习
述雾学java1 小时前
深入理解 transforms.Normalize():PyTorch 图像预处理中的关键一步
人工智能·pytorch·python