torch.argsort
是 PyTorch 中用于返回张量沿指定维度的排序索引的函数。它不会直接对张量的值进行排序,而是返回一个与张量相同形状的索引张量,指示出原张量中每个元素的排序顺序。
函数签名
torch.argsort(input, dim=-1, descending=False)
参数
input
:输入的张量。dim
:指定进行排序的维度,默认为最后一个维度(-1
)。descending
:如果为True
,则返回按降序排列的索引;如果为False
(默认),则按升序排列。
返回
返回与输入张量 input
形状相同的张量,其中每个元素为排序后的索引值。
使用场景
- 需要知道张量中元素的排序顺序而不是排序后的实际值时,可以使用
argsort
。 - 可以用于生成一种排序的掩码,比如对某些值进行有序操作。
示例代码
1. 基本用法
import torch
# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])
# 对张量进行升序排序,并返回排序后的索引
sorted_indices = torch.argsort(x)
print(sorted_indices) # 输出:tensor([1, 3, 0, 2])
解释:torch.argsort
返回的是按升序排列元素的索引。索引 1
对应值 1.2
(最小),依次类推。
2. 在多维张量上使用
import torch
# 创建一个 2D 张量
x = torch.tensor([[4, 1, 3],
[2, 8, 5]])
# 在每一行上对张量进行升序排序,dim=1 表示对行操作
sorted_indices = torch.argsort(x, dim=1)
print(sorted_indices)
# 在每一列上对张量进行升序排序,dim=0 表示对列操作
sorted_indices_col = torch.argsort(x, dim=0)
print(sorted_indices_col)
输出:
tensor([[1, 2, 0],
[0, 2, 1]])
tensor([[1, 0, 0],
[0, 1, 1]])
解释:第一个例子返回的是按行升序排列的索引,第二个例子返回的是按列升序排列的索引。
3. 降序排序
import torch
# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])
# 对张量进行降序排序,并返回排序后的索引
sorted_indices = torch.argsort(x, descending=True)
print(sorted_indices) # 输出:tensor([2, 0, 3, 1])
解释:这里使用了 descending=True
,所以 torch.argsort
按降序返回索引,索引 2
对应值 4.8
(最大),依次类推。
4. 与 gather
配合使用
你可以使用 torch.argsort
来获取排序后的索引,并结合 torch.gather
获取排序后的值:
import torch
# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])
# 获取排序后的索引
sorted_indices = torch.argsort(x)
# 根据索引获取排序后的值
sorted_x = torch.gather(x, 0, sorted_indices)
print(sorted_x) # 输出:tensor([1.2000, 2.9000, 3.5000, 4.8000])
解释:通过 torch.gather
函数,可以根据排序后的索引来重新排列原张量的值。
总结
torch.argsort
函数非常适合需要知道张量元素排序顺序的场景。它返回的是元素在升序或降序排序中的索引,而不是排序后的实际值。结合其他 PyTorch 函数(如 gather
),可以灵活实现很多数据处理操作。