torch.argsort 函数介绍

torch.argsort 是 PyTorch 中用于返回张量沿指定维度的排序索引的函数。它不会直接对张量的值进行排序,而是返回一个与张量相同形状的索引张量,指示出原张量中每个元素的排序顺序。

函数签名

复制代码
torch.argsort(input, dim=-1, descending=False)

参数

  • input:输入的张量。
  • dim :指定进行排序的维度,默认为最后一个维度(-1)。
  • descending :如果为 True,则返回按降序排列的索引;如果为 False(默认),则按升序排列。

返回

返回与输入张量 input 形状相同的张量,其中每个元素为排序后的索引值。

使用场景

  • 需要知道张量中元素的排序顺序而不是排序后的实际值时,可以使用 argsort
  • 可以用于生成一种排序的掩码,比如对某些值进行有序操作。

示例代码

1. 基本用法
复制代码
import torch

# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])

# 对张量进行升序排序,并返回排序后的索引
sorted_indices = torch.argsort(x)
print(sorted_indices)  # 输出:tensor([1, 3, 0, 2])

解释:torch.argsort 返回的是按升序排列元素的索引。索引 1 对应值 1.2(最小),依次类推。

2. 在多维张量上使用
复制代码
import torch

# 创建一个 2D 张量
x = torch.tensor([[4, 1, 3],
                  [2, 8, 5]])

# 在每一行上对张量进行升序排序,dim=1 表示对行操作
sorted_indices = torch.argsort(x, dim=1)
print(sorted_indices)

# 在每一列上对张量进行升序排序,dim=0 表示对列操作
sorted_indices_col = torch.argsort(x, dim=0)
print(sorted_indices_col)

输出

复制代码
tensor([[1, 2, 0],
        [0, 2, 1]])

tensor([[1, 0, 0],
        [0, 1, 1]])

解释:第一个例子返回的是按行升序排列的索引,第二个例子返回的是按列升序排列的索引。

3. 降序排序
复制代码
import torch

# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])

# 对张量进行降序排序,并返回排序后的索引
sorted_indices = torch.argsort(x, descending=True)
print(sorted_indices)  # 输出:tensor([2, 0, 3, 1])

解释:这里使用了 descending=True,所以 torch.argsort 按降序返回索引,索引 2 对应值 4.8(最大),依次类推。

4. 与 gather 配合使用

你可以使用 torch.argsort 来获取排序后的索引,并结合 torch.gather 获取排序后的值:

复制代码
import torch

# 创建一个张量
x = torch.tensor([3.5, 1.2, 4.8, 2.9])

# 获取排序后的索引
sorted_indices = torch.argsort(x)

# 根据索引获取排序后的值
sorted_x = torch.gather(x, 0, sorted_indices)
print(sorted_x)  # 输出:tensor([1.2000, 2.9000, 3.5000, 4.8000])

解释:通过 torch.gather 函数,可以根据排序后的索引来重新排列原张量的值。

总结

torch.argsort 函数非常适合需要知道张量元素排序顺序的场景。它返回的是元素在升序或降序排序中的索引,而不是排序后的实际值。结合其他 PyTorch 函数(如 gather),可以灵活实现很多数据处理操作。

相关推荐
每天都要写算法(努力版)几秒前
【神经网络与深度学习】训练集与验证集的功能解析与差异探究
人工智能·深度学习·神经网络
cloudy4912 分钟前
强化学习:历史基金净产值,学习最大化长期收益
python·强化学习
Bruce_Liuxiaowei13 分钟前
使用Python脚本在Mac上彻底清除Chrome浏览历史:开发实战与隐私保护指南
chrome·python·macos
vocal19 分钟前
谷歌第七版Prompt Engineering—第一部分
人工智能
MonkeyKing_sunyuhua20 分钟前
5.6 Microsoft Semantic Kernel:专注于将LLM集成到现有应用中的框架
人工智能·microsoft·agent
ruyingcai66666625 分钟前
用python进行OCR识别
开发语言·python·ocr
arbboter28 分钟前
【AI插件开发】Notepad++ AI插件开发1.0发布和使用说明
人工智能·大模型·notepad++·ai助手·ai插件·aicoder·notepad++插件开发
Niuguangshuo28 分钟前
Python设计模式:MVC模式
python·设计模式·mvc
TOMGRIL32 分钟前
文件的读取操作
python
liuweidong080236 分钟前
【Pandas】pandas DataFrame radd
开发语言·python·pandas