torch.argsort 函数介绍

torch.argsort 是 PyTorch 中用于返回张量沿指定维度的排序索引的函数。它不会直接对张量的值进行排序,而是返回一个与张量相同形状的索引张量,指示出原张量中每个元素的排序顺序。

函数签名

torch.argsort(input, dim=-1, descending=False)

参数

  • input:输入的张量。
  • dim :指定进行排序的维度,默认为最后一个维度(-1)。
  • descending :如果为 True,则返回按降序排列的索引;如果为 False(默认),则按升序排列。

返回

返回与输入张量 input 形状相同的张量,其中每个元素为排序后的索引值。

使用场景

  • 需要知道张量中元素的排序顺序而不是排序后的实际值时,可以使用 argsort
  • 可以用于生成一种排序的掩码,比如对某些值进行有序操作。

示例代码

1. 基本用法
import torch

# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])

# 对张量进行升序排序,并返回排序后的索引
sorted_indices = torch.argsort(x)
print(sorted_indices)  # 输出:tensor([1, 3, 0, 2])

解释:torch.argsort 返回的是按升序排列元素的索引。索引 1 对应值 1.2(最小),依次类推。

2. 在多维张量上使用
import torch

# 创建一个 2D 张量
x = torch.tensor([[4, 1, 3],
                  [2, 8, 5]])

# 在每一行上对张量进行升序排序,dim=1 表示对行操作
sorted_indices = torch.argsort(x, dim=1)
print(sorted_indices)

# 在每一列上对张量进行升序排序,dim=0 表示对列操作
sorted_indices_col = torch.argsort(x, dim=0)
print(sorted_indices_col)

输出

tensor([[1, 2, 0],
        [0, 2, 1]])

tensor([[1, 0, 0],
        [0, 1, 1]])

解释:第一个例子返回的是按行升序排列的索引,第二个例子返回的是按列升序排列的索引。

3. 降序排序
import torch

# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])

# 对张量进行降序排序,并返回排序后的索引
sorted_indices = torch.argsort(x, descending=True)
print(sorted_indices)  # 输出:tensor([2, 0, 3, 1])

解释:这里使用了 descending=True,所以 torch.argsort 按降序返回索引,索引 2 对应值 4.8(最大),依次类推。

4. 与 gather 配合使用

你可以使用 torch.argsort 来获取排序后的索引,并结合 torch.gather 获取排序后的值:

import torch

# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])

# 获取排序后的索引
sorted_indices = torch.argsort(x)

# 根据索引获取排序后的值
sorted_x = torch.gather(x, 0, sorted_indices)
print(sorted_x)  # 输出:tensor([1.2000, 2.9000, 3.5000, 4.8000])

解释:通过 torch.gather 函数,可以根据排序后的索引来重新排列原张量的值。

总结

torch.argsort 函数非常适合需要知道张量元素排序顺序的场景。它返回的是元素在升序或降序排序中的索引,而不是排序后的实际值。结合其他 PyTorch 函数(如 gather),可以灵活实现很多数据处理操作。

相关推荐
无脑敲代码,bug漫天飞几秒前
COR 损失函数
人工智能·机器学习
幽兰的天空5 分钟前
Python 中的模式匹配:深入了解 match 语句
开发语言·python
HPC_fac130520678161 小时前
以科学计算为切入点:剖析英伟达服务器过热难题
服务器·人工智能·深度学习·机器学习·计算机视觉·数据挖掘·gpu算力
网易独家音乐人Mike Zhou3 小时前
【卡尔曼滤波】数据预测Prediction观测器的理论推导及应用 C语言、Python实现(Kalman Filter)
c语言·python·单片机·物联网·算法·嵌入式·iot
安静读书3 小时前
Python解析视频FPS(帧率)、分辨率信息
python·opencv·音视频
小陈phd4 小时前
OpenCV从入门到精通实战(九)——基于dlib的疲劳监测 ear计算
人工智能·opencv·计算机视觉
Guofu_Liao5 小时前
大语言模型---LoRA简介;LoRA的优势;LoRA训练步骤;总结
人工智能·语言模型·自然语言处理·矩阵·llama
小二·5 小时前
java基础面试题笔记(基础篇)
java·笔记·python
小喵要摸鱼6 小时前
Python 神经网络项目常用语法
python
一念之坤8 小时前
零基础学Python之数据结构 -- 01篇
数据结构·python