tensor.topk 以及tensor.argmax
topk(self, k, dim=None, largest=True, sorted=True):,返回两个值,values与indices。
argmax(self, dim=None, keepdim=False): 返回Tensor
python
'''
具体使用方法定义里写的很清楚,topk中largest=False返回最小值,sorted打乱原有的元素顺序。
argmax中keepdim=True:保持维数不变,默认会减少一维。
'''
import torch
a = torch.randn(2, 2, 2)
a.argmax(1)
'''
tensor([[0, 1],
[1, 0]])
'''
a.topk(k=1)
'''
torch.return_types.topk(
values=tensor([[[ 0.8766],
[-0.1330]],
[[ 1.5773],
[ 0.8146]]]),
indices=tensor([[[0],
[1]],
[[1],
[0]]]))
'''
a.topk(k=1).values
'''
tensor([[[ 0.8766],
[-0.1330]],
[[ 1.5773],
[ 0.8146]]])
'''
a.topk(1).indices
'''
tensor([[[0],
[1]],
[[1],
[0]]])
'''