torch.argsort 函数介绍

torch.argsort 是 PyTorch 中用于返回张量沿指定维度的排序索引的函数。它不会直接对张量的值进行排序,而是返回一个与张量相同形状的索引张量,指示出原张量中每个元素的排序顺序。

函数签名

torch.argsort(input, dim=-1, descending=False)

参数

  • input:输入的张量。
  • dim :指定进行排序的维度,默认为最后一个维度(-1)。
  • descending :如果为 True,则返回按降序排列的索引;如果为 False(默认),则按升序排列。

返回

返回与输入张量 input 形状相同的张量,其中每个元素为排序后的索引值。

使用场景

  • 需要知道张量中元素的排序顺序而不是排序后的实际值时,可以使用 argsort
  • 可以用于生成一种排序的掩码,比如对某些值进行有序操作。

示例代码

1. 基本用法
import torch

# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])

# 对张量进行升序排序,并返回排序后的索引
sorted_indices = torch.argsort(x)
print(sorted_indices)  # 输出:tensor([1, 3, 0, 2])

解释:torch.argsort 返回的是按升序排列元素的索引。索引 1 对应值 1.2(最小),依次类推。

2. 在多维张量上使用
import torch

# 创建一个 2D 张量
x = torch.tensor([[4, 1, 3],
                  [2, 8, 5]])

# 在每一行上对张量进行升序排序,dim=1 表示对行操作
sorted_indices = torch.argsort(x, dim=1)
print(sorted_indices)

# 在每一列上对张量进行升序排序,dim=0 表示对列操作
sorted_indices_col = torch.argsort(x, dim=0)
print(sorted_indices_col)

输出

tensor([[1, 2, 0],
        [0, 2, 1]])

tensor([[1, 0, 0],
        [0, 1, 1]])

解释:第一个例子返回的是按行升序排列的索引,第二个例子返回的是按列升序排列的索引。

3. 降序排序
import torch

# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])

# 对张量进行降序排序,并返回排序后的索引
sorted_indices = torch.argsort(x, descending=True)
print(sorted_indices)  # 输出:tensor([2, 0, 3, 1])

解释:这里使用了 descending=True,所以 torch.argsort 按降序返回索引,索引 2 对应值 4.8(最大),依次类推。

4. 与 gather 配合使用

你可以使用 torch.argsort 来获取排序后的索引,并结合 torch.gather 获取排序后的值:

import torch

# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])

# 获取排序后的索引
sorted_indices = torch.argsort(x)

# 根据索引获取排序后的值
sorted_x = torch.gather(x, 0, sorted_indices)
print(sorted_x)  # 输出:tensor([1.2000, 2.9000, 3.5000, 4.8000])

解释:通过 torch.gather 函数,可以根据排序后的索引来重新排列原张量的值。

总结

torch.argsort 函数非常适合需要知道张量元素排序顺序的场景。它返回的是元素在升序或降序排序中的索引,而不是排序后的实际值。结合其他 PyTorch 函数(如 gather),可以灵活实现很多数据处理操作。

相关推荐
changwan4 分钟前
基于Celery+Supervisord的异步任务管理方案
后端·python·性能优化
君秋水4 分钟前
Python异步编程指南:asyncio从入门到精通(Python 3.10+)
后端·python
訾博ZiBo12 分钟前
AI日报 - 2025年3月7日
人工智能
梓羽玩Python14 分钟前
一夜刷屏AI圈!Manus:这不是聊天机器人,是你的“AI打工仔”!
人工智能
Gene_INNOCENT15 分钟前
大型语言模型训练的三个阶段:Pre-Train、Instruction Fine-tuning、RLHF (PPO / DPO / GRPO)
人工智能·深度学习·语言模型
游戏智眼16 分钟前
中国团队发布通用型AI Agent产品Manus;GPT-4.5正式面向Plus用户推出;阿里发布并开源推理模型通义千问QwQ-32B...|游戏智眼日报
人工智能·游戏·游戏引擎·aigc
挣扎与觉醒中的技术人17 分钟前
如何优化FFmpeg拉流性能及避坑指南
人工智能·深度学习·性能优化·ffmpeg·aigc·ai编程
君秋水17 分钟前
FastAPI教程:20个核心概念从入门到 happy使用
后端·python·程序员
watersink21 分钟前
Dify框架下的基于RAG流程的政务检索平台
人工智能·深度学习·机器学习
脑极体24 分钟前
在MWC2025,读懂华为如何以行践言
大数据·人工智能·华为