在 PyTorch 中,torch.topk
函数用于在输入张量中找到最大的k
个值及其索引。
一、函数语法
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None)
input
:输入张量。k
:要返回的最大或最小元素的数量。dim
(可选):要进行操作的维度。如果为None
,则在扁平的输入张量上进行操作。largest
(可选):如果为True
,则返回最大的k
个值;如果为False
,则返回最小的k
个值。sorted
(可选):如果为True
,则返回的k
个值将按降序(如果largest=True
)或升序(如果largest=False
)排列;如果为False
,则返回的k
个值的顺序是未定义的。out
(可选):输出张量,可以是一个已存在的张量,用于存储结果。
二、返回值
该函数返回一个包含两个张量的元组:
- 第一个张量是包含最大或最小的
k
个值的张量。 - 第二个张量是包含这些值在输入张量中的索引的张量。
三、使用示例
import torch
# 创建一个二维张量
tensor = torch.tensor([[4, 2, 3], [1, 5, 6]])
# 找到每行中的最大的两个值及其索引
values, indices = torch.topk(tensor, k=2, dim=1, largest=True)
print("最大的两个值:", values)
print("对应的索引:", indices)
# 找到每列中的最小的两个值及其索引
values, indices = torch.topk(tensor, k=2, dim=0, largest=False)
print("最小的两个值:", values)
print("对应的索引:", indices)
在上述示例中,首先创建了一个二维张量。然后,分别在行维度和列维度上使用torch.topk
函数找到最大的两个值及其索引和最小的两个值及其索引,并打印出结果。