以如下 tensor a 为例,展示常用的 indxing, slicing 及其他高阶操作
python3
>>> a = torch.rand(4,3,28,28)
>>> a.shape
torch.Size([4, 3, 28, 28])
-
Indexing: 使用索引获取目标对象,
[x,x,x,....]python3>>> a[0].shape torch.Size([3, 28, 28]) >>> a[0,0].shape torch.Size([28, 28]) >>> a[0,0,0].shape torch.Size([28]) >>> a[0,0,0,0].shape torch.Size([]) -
Slicing: 使用切片获取
一截目标对象,::steppython3>>> a[:2].shape torch.Size([2, 3, 28, 28]) >>> a[0, :2].shape torch.Size([2, 28, 28]) >>> a[0, 0, :2].shape torch.Size([2, 28]) >>> a[0, 0, 0, :2].shape torch.Size([2]) -
其他汇总:
python3>>> a.index_select(dim, torch.tensor([idx_1,idx_2, ...])) ## by specific idx >>> torch.take(a, torch.tensor([idx_1, idx2, ...])) ## 不指定 dim 先打平 a 再按序提取 >>> a[a.ge(0.5)] ## by mask=a.ge(0.5),该方法没有保持 shape