torch.argsort 函数介绍

torch.argsort 是 PyTorch 中用于返回张量沿指定维度的排序索引的函数。它不会直接对张量的值进行排序,而是返回一个与张量相同形状的索引张量,指示出原张量中每个元素的排序顺序。

函数签名

复制代码
torch.argsort(input, dim=-1, descending=False)

参数

  • input:输入的张量。
  • dim :指定进行排序的维度,默认为最后一个维度(-1)。
  • descending :如果为 True,则返回按降序排列的索引;如果为 False(默认),则按升序排列。

返回

返回与输入张量 input 形状相同的张量,其中每个元素为排序后的索引值。

使用场景

  • 需要知道张量中元素的排序顺序而不是排序后的实际值时,可以使用 argsort
  • 可以用于生成一种排序的掩码,比如对某些值进行有序操作。

示例代码

1. 基本用法
复制代码
import torch

# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])

# 对张量进行升序排序,并返回排序后的索引
sorted_indices = torch.argsort(x)
print(sorted_indices)  # 输出:tensor([1, 3, 0, 2])

解释:torch.argsort 返回的是按升序排列元素的索引。索引 1 对应值 1.2(最小),依次类推。

2. 在多维张量上使用
复制代码
import torch

# 创建一个 2D 张量
x = torch.tensor([[4, 1, 3],
                  [2, 8, 5]])

# 在每一行上对张量进行升序排序,dim=1 表示对行操作
sorted_indices = torch.argsort(x, dim=1)
print(sorted_indices)

# 在每一列上对张量进行升序排序,dim=0 表示对列操作
sorted_indices_col = torch.argsort(x, dim=0)
print(sorted_indices_col)

输出

复制代码
tensor([[1, 2, 0],
        [0, 2, 1]])

tensor([[1, 0, 0],
        [0, 1, 1]])

解释:第一个例子返回的是按行升序排列的索引,第二个例子返回的是按列升序排列的索引。

3. 降序排序
复制代码
import torch

# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])

# 对张量进行降序排序,并返回排序后的索引
sorted_indices = torch.argsort(x, descending=True)
print(sorted_indices)  # 输出:tensor([2, 0, 3, 1])

解释:这里使用了 descending=True,所以 torch.argsort 按降序返回索引,索引 2 对应值 4.8(最大),依次类推。

4. 与 gather 配合使用

你可以使用 torch.argsort 来获取排序后的索引,并结合 torch.gather 获取排序后的值:

复制代码
import torch

# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])

# 获取排序后的索引
sorted_indices = torch.argsort(x)

# 根据索引获取排序后的值
sorted_x = torch.gather(x, 0, sorted_indices)
print(sorted_x)  # 输出:tensor([1.2000, 2.9000, 3.5000, 4.8000])

解释:通过 torch.gather 函数,可以根据排序后的索引来重新排列原张量的值。

总结

torch.argsort 函数非常适合需要知道张量元素排序顺序的场景。它返回的是元素在升序或降序排序中的索引,而不是排序后的实际值。结合其他 PyTorch 函数(如 gather),可以灵活实现很多数据处理操作。

相关推荐
小白狮ww7 小时前
RStudio 教程:以抑郁量表测评数据分析为例
人工智能·算法·机器学习
沧海一粟青草喂马7 小时前
抖音批量上传视频怎么弄?抖音矩阵账号管理的专业指南
大数据·人工智能·矩阵
demaichuandong7 小时前
详细讲解锥齿轮丝杆升降机的加工制造工艺
人工智能·自动化·制造
ZZHow10247 小时前
02OpenCV基本操作
python·opencv·计算机视觉
理智的煎蛋8 小时前
CentOS/Ubuntu安装显卡驱动与GPU压力测试
大数据·人工智能·ubuntu·centos·gpu算力
计算机学长felix8 小时前
基于Django的“酒店推荐系统”设计与开发(源码+数据库+文档+PPT)
数据库·python·mysql·django·vue
站大爷IP8 小时前
Python随机数函数全解析:5个核心工具的实战指南
python
知来者逆8 小时前
视觉语言模型应用开发——Qwen 2.5 VL模型视频理解与定位能力深度解析及实践指南
人工智能·语言模型·自然语言处理·音视频·视觉语言模型·qwen 2.5 vl
IT_陈寒8 小时前
Java性能优化:10个让你的Spring Boot应用提速300%的隐藏技巧
前端·人工智能·后端
Android出海8 小时前
Android 15重磅升级:16KB内存页机制详解与适配指南
android·人工智能·新媒体运营·产品运营·内容运营