以如下 tensor a
为例,展示常用的 indxing
, slicing
及其他高阶操作
python3
>>> a = torch.rand(4,3,28,28)
>>> a.shape
torch.Size([4, 3, 28, 28])
-
Indexing: 使用索引获取目标对象,
[x,x,x,....]
python3>>> a[0].shape torch.Size([3, 28, 28]) >>> a[0,0].shape torch.Size([28, 28]) >>> a[0,0,0].shape torch.Size([28]) >>> a[0,0,0,0].shape torch.Size([])
-
Slicing: 使用切片获取
一截
目标对象,::step
python3>>> a[:2].shape torch.Size([2, 3, 28, 28]) >>> a[0, :2].shape torch.Size([2, 28, 28]) >>> a[0, 0, :2].shape torch.Size([2, 28]) >>> a[0, 0, 0, :2].shape torch.Size([2])
-
其他汇总:
python3>>> a.index_select(dim, torch.tensor([idx_1,idx_2, ...])) ## by specific idx >>> torch.take(a, torch.tensor([idx_1, idx2, ...])) ## 不指定 dim 先打平 a 再按序提取 >>> a[a.ge(0.5)] ## by mask=a.ge(0.5),该方法没有保持 shape