torch.searchsorted

torch.searchsorted

官方文档链接:torch.searchsorted --- PyTorch 2.3 documentation

该函数用于在已排序的序列中查找要插入的值的位置,以保持序列的顺序,

复制代码
torch.searchsorted(sorted_sequence, values, *, out_int32=False, right=False, side=None, out=None, sorter=None) → Tensor

参数如下,

  • sorted_sequence:这是一个N-D或1-D的张量,其中包含按最内部维度单调递增的序列。如果提供了sorter参数,则序列不需要按顺序排列

  • values:这是一个N-D张量或标量,包含要搜索的值

  • out_int32:这是一个可选参数,用于指示输出数据类型。如果为True,则输出数据类型为torch.int32,否则为torch.int64

  • right:这是一个可选参数,如果为False,则返回找到的第一个合适位置。如果为 True,则返回最后一个索引。如果找不到合适的索引,则对于非数值值(例如nan、inf),返回0,或者返回sorted_sequence内最内部维度的大小(超过最内部维度的最后一个索引)。如果为False,则获取每个值在sorted_sequence相应内部维度上的下限索引,如果为True,则获取上限索引。默认值为False

  • side:这是一个可选参数,"left" 对应于right为 False,"right" 对应于right为 True。如果将其设置为 "left",而right为 True,则会报错。默认值为None。

  • out:这是一个可选参数,输出张量,如果提供,则必须与 values 的大小相同

  • sorter:这是一个可选参数,如果提供,则是一个与未排序的sorted_sequence形状相匹配的张量,其中包含一个按最内部维度升序排列的索引序列

使用示例如下,

复制代码
sorted_sequence = torch.tensor([[1, 3, 5, 7, 9], [2, 4, 6, 8, 10]])
"""
tensor([[ 1,  3,  5,  7,  9],
        [ 2,  4,  6,  8, 10]])
"""

values = torch.tensor([[3, 6, 9], [3, 6, 9]])
"""
tensor([[3, 6, 9],
        [3, 6, 9]])
"""

torch.searchsorted(sorted_sequence, values)
"""
tensor([[1, 3, 4],
        [1, 2, 4]])
对于第一行 [3, 6, 9]:
数字3在第一行的sorted_sequence中的位置是索引1
数字6在第一行的sorted_sequence中的位置是索引3(6大于5而小于7,因此将6插入到索引3的位置时,能够使序列保持升序排序)
数字9在第一行的sorted_sequence中的位置是索引4
对于第二行 [3, 6, 9]:
数字3在第二行的sorted_sequence中的位置是索引1(3大于2而小于4,因此当索引为1时,不会改变序列的升序排序)
数字6在第二行的sorted_sequence中的位置是索引2
数字9在第二行的sorted_sequence中的位置是索引4(9大于8而小于10,因此当索引为4时,不会改变序列的升序排序)
"""

## 当side='right'时, 函数会返回每个值在对应行的sorted_sequence中的右侧插入位置索引
torch.searchsorted(sorted_sequence, values, side='right')
"""
tensor([[2, 3, 5],
        [1, 3, 4]])

对于第一行 [3, 6, 9]:
数字3在第一行的sorted_sequence中的右侧插入位置是索引2(数字3的右侧插入位置索引是2)
数字6在第一行的sorted_sequence中的右侧插入位置是索引3
数字9在第一行的sorted_sequence中的右侧插入位置是索引5(数字9的右侧插入位置索引是5)
对于第二行 [3, 6, 9]:
数字3在第二行的sorted_sequence中的右侧插入位置是索引1
数字6在第二行的sorted_sequence中的右侧插入位置是索引3(数字6的右侧插入位置索引是3)
数字9在第二行的sorted_sequence中的右侧插入位置是索引4
"""

sorted_sequence_1d = torch.tensor([1, 3, 5, 7, 9])
"""
tensor([1, 3, 5, 7, 9])
"""

torch.searchsorted(sorted_sequence_1d, values)
"""
tensor([[1, 3, 4],
        [1, 3, 4]])
"""
相关推荐
零号机2 分钟前
使用TRAE 30分钟极速开发一款划词中英互译浏览器插件
前端·人工智能
FunTester3 分钟前
基于 Cursor 的智能测试用例生成系统 - 项目介绍与实施指南
人工智能·ai·大模型·测试用例·实践指南·curor·智能测试用例
SEO_juper10 分钟前
LLMs.txt 创建指南:为大型语言模型优化您的网站
人工智能·ai·语言模型·自然语言处理·数字营销
B站_计算机毕业设计之家19 分钟前
大数据YOLOv8无人机目标检测跟踪识别系统 深度学习 PySide界面设计 大数据 ✅
大数据·python·深度学习·信息可视化·数据挖掘·数据分析·flask
淮雵的Blog25 分钟前
langGraph通俗易懂的解释、langGraph和使用API直接调用LLM的区别
人工智能
Mintopia28 分钟前
🚀 共绩算力:3分钟拥有自己的文生图AI服务-容器化部署 StableDiffusion1.5-WebUI 应用
前端·人工智能·aigc
HPC_C34 分钟前
SGLang: Efficient Execution of Structured Language Model Programs
人工智能·语言模型·自然语言处理
王哈哈^_^43 分钟前
【完整源码+数据集】草莓数据集,yolov8草莓成熟度检测数据集 3207 张,草莓成熟度数据集,目标检测草莓识别算法系统实战教程
人工智能·算法·yolo·目标检测·计算机视觉·视觉检测·毕业设计
songyuc1 小时前
《A Bilateral CFAR Algorithm for Ship Detection in SAR Images》译读笔记
人工智能·笔记·计算机视觉
码界奇点1 小时前
解密AI语言模型从原理到应用的全景解析
人工智能·语言模型·自然语言处理·架构