pytorch比较操作

文章目录


常用的比较操作


1.torch.allclose()

torch.allclose() 是 PyTorch 中用于比较两个张量是否在给定的容差范围内近似相等的函数。它可以用于比较浮点数张量之间的相等性。

python 复制代码
torch.allclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False)
"""
input:第一个输入张量。
other:第二个输入张量。
rtol:相对容差(relative tolerance),默认为 1e-05。
atol:绝对容差(absolute tolerance),默认为 1e-08。
equal_nan:一个布尔值,指示是否将 NaN 视为相等,默认为 False。
"""
python 复制代码
import torch

# 比较两个张量是否近似相等
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.0001, 2.0002, 3.0003])
is_close = torch.allclose(x, y, rtol=1e-03, atol=1e-05)

print(is_close)# True

2.torch.argsort()

torch.argsort() 是 PyTorch 中用于对张量进行排序并返回排序后的索引的函数。它返回一个新的张量,其中每个元素表示原始张量中对应位置的元素在排序后的顺序中的索引值。

python 复制代码
torch.argsort(input, dim=-1, descending=False, *, out=None)
"""
input:输入张量。
dim:指定排序的维度,默认为 -1,表示最后一个维度。
descending:一个布尔值,指示是否按降序排序,默认为 False。
out:可选参数,用于指定输出张量的位置。
"""
python 复制代码
import torch

# 对张量进行排序并返回索引
x = torch.tensor([3, 1, 4, 2])
sorted_indices = torch.argsort(x)

print(sorted_indices)
# tensor([1, 3, 0, 2])

3.torch.eq()

torch.eq() 是 PyTorch 中用于执行元素级别相等性比较的函数。它比较两个张量的对应元素,并返回一个新的布尔张量,其中元素为 True 表示对应位置的元素相等,元素为 False 表示对应位置的元素不相等。

python 复制代码
torch.eq(input, other, out=None)
"""
input:第一个输入张量。
other:第二个输入张量。
out:可选参数,用于指定输出张量的位置。
"""
python 复制代码
import torch

# 执行元素级别的相等性比较
x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 2, 4])
result = torch.eq(x, y)

print(result)# tensor([ True,  True, False])

4.torch.equal()

torch.equal() 是 PyTorch 中用于检查两个张量是否在元素级别上完全相等的函数。它返回一个布尔值,指示两个张量是否具有相同的形状和相同的元素值。

python 复制代码
torch.equal(input, other)
"""
input:第一个输入张量。
other:第二个输入张量。
"""
python 复制代码
import torch

# 检查两个张量是否完全相等
x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 2, 3])
is_equal = torch.equal(x, y)

print(is_equal)# True

5.torch.greater_equal()

torch.greater_equal() 是 PyTorch 中用于执行元素级别的大于等于比较的函数。它比较两个张量的对应元素,并返回一个新的布尔张量,其中元素为 True 表示对应位置的元素大于或等于,元素为 False 表示对应位置的元素小于。

python 复制代码
torch.greater_equal(input, other, out=None)
"""
input:第一个输入张量。
other:第二个输入张量。
out:可选参数,用于指定输出张量的位置。
"""
python 复制代码
import torch

# 执行元素级别的大于等于比较
x = torch.tensor([1, 2, 3])
y = torch.tensor([2, 2, 2])
result = torch.greater_equal(x, y)

print(result)
python 复制代码
tensor([False,  True,  True])

6.torch.gt()

torch.gt() 是 PyTorch 中用于执行元素级别的大于比较的函数。它比较两个张量的对应元素,并返回一个新的布尔张量,其中元素为 True 表示对应位置的元素大于,元素为 False 表示对应位置的元素小于或等于。

python 复制代码
torch.gt(input, other, out=None)
"""
input:第一个输入张量。
other:第二个输入张量。
out:可选参数,用于指定输出张量的位置。
"""
python 复制代码
import torch

# 执行元素级别的大于比较
x = torch.tensor([1, 2, 3])
y = torch.tensor([2, 2, 2])
result = torch.gt(x, y)

print(result)#tensor([False, False,  True])

7.torch.isclose()

torch.isclose() 是 PyTorch 中用于比较两个张量是否在给定的容差范围内近似相等的函数。它可以用于比较浮点数张量之间的相等性。

python 复制代码
torch.isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False)
"""
input:第一个输入张量。
other:第二个输入张量。
rtol:相对容差(relative tolerance),默认为 1e-05。
atol:绝对容差(absolute tolerance),默认为 1e-08。
equal_nan:一个布尔值,指示是否将 NaN 视为相等,默认为 False。
"""
python 复制代码
import torch

# 比较两个张量是否近似相等
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.tensor([1.0001, 2.0002, 3.0003])
is_close = torch.isclose(x, y, rtol=1e-03, atol=1e-05)

print(is_close)
python 复制代码
tensor([True, True, True])

8.torch.isfinite()

torch.isfinite() 是 PyTorch 中用于检查张量中的元素是否为有限数(finite number)的函数。它返回一个新的布尔张量,其中每个元素表示对应位置的元素是否为有限数。

python 复制代码
torch.isfinite(input, out=None)
"""
input:输入张量。
out:可选参数,用于指定输出张量的位置。
"""
python 复制代码
import torch

# 检查张量中的元素是否为有限数
x = torch.tensor([1.0, float('inf'), float('-inf'), float('nan')])
is_finite = torch.isfinite(x)

print(is_finite)# tensor([ True, False, False, False])

9.torch.isif()

torch.isinf() 是 PyTorch 中用于检查张量中的元素是否为无穷大的函数。它返回一个新的布尔张量,其中每个元素表示对应位置的元素是否为无穷大。

python 复制代码
torch.isinf(input, out=None)
"""
input:输入张量。
out:可选参数,用于指定输出张量的位置。
"""
python 复制代码
import torch
# 检查张量中的元素是否为无穷大
x = torch.tensor([1.0, float('inf'), float('-inf'), float('nan')])
is_inf = torch.isinf(x)

print(is_inf)
python 复制代码
tensor([False,  True,  True, False])

10.torch.isposinf()

torch.isposinf() 是 PyTorch 中用于检查张量中的元素是否为正无穷大的函数。它返回一个新的布尔张量,其中每个元素表示对应位置的元素是否为正无穷大。

python 复制代码
torch.isposinf(input, out=None)
"""
input:输入张量。
out:可选参数,用于指定输出张量的位置。
"""
python 复制代码
import torch

# 检查张量中的元素是否为正无穷大
x = torch.tensor([1.0, float('inf'), float('-inf'), float('nan')])
is_posinf = torch.isposinf(x)

print(is_posinf)# tensor([False,  True, False, False])

11.torch.isneginf()

torch.isneginf() 是 PyTorch 中用于检查张量中的元素是否为负无穷大的函数。它返回一个新的布尔张量,其中每个元素表示对应位置的元素是否为负无穷大。

python 复制代码
torch.isneginf(input, out=None)
"""
input:输入张量。
out:可选参数,用于指定输出张量的位置。
"""
python 复制代码
import torch

# 检查张量中的元素是否为负无穷大
x = torch.tensor([1.0, float('inf'), float('-inf'), float('nan')])
is_neginf = torch.isneginf(x)

print(is_neginf)# tensor([False, False,  True, False])

12.torch.isnan()

torch.isnan() 是 PyTorch 中用于检查张量中的元素是否为 NaN(Not a Number)的函数。它返回一个新的布尔张量,其中每个元素表示对应位置的元素是否为 NaN。

python 复制代码
torch.isnan(input, out=None)
"""
input:输入张量。
out:可选参数,用于指定输出张量的位置。
"""
python 复制代码
import torch

# 检查张量中的元素是否为 NaN
x = torch.tensor([1.0, float('inf'), float('-inf'), float('nan')])
is_nan = torch.isnan(x)

print(is_nan)# tensor([False, False, False,  True])

13.torch.kthvalue()

torch.kthvalue() 函数用于找出张量中的第 k 小值,而 torch.topk() 函数用于找出张量中的前 k 个最大值(或最小值)及其对应的索引。

python 复制代码
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None)
"""
input:输入张量。
k:要找到的最大值(或最小值)的数量。
dim:可选参数,指定在哪个维度上进行查找。如果未指定,则默认在最后一个维度上查找。
largest:可选参数,指定是找到最大值还是最小值。默认为 True,表示找到最大值。
sorted:可选参数,指定结果张量是否按降序排列。默认为 True。
out:可选参数,用于指定输出张量的位置。
"""
python 复制代码
import torch

# 找出张量中的前 3 个最大值及其索引
x = torch.tensor([1, 3, 2, 4, 6, 5])
values, indices = torch.topk(x, k=3)

print(values)#tensor([6, 5, 4])
print(indices)#tensor([4, 5, 3])

14.torch.less_equal()

torch.less_equal() 是 PyTorch 中用于执行逐元素的小于等于(<=)比较的函数。它比较两个张量的对应元素,并返回一个新的布尔张量,其中每个元素表示对应位置的元素是否满足小于等于的条件。

python 复制代码
torch.less_equal(input, other, out=None)
"""
input:输入张量。
other:用于比较的另一个张量或标量值。
out:可选参数,用于指定输出张量的位置。
"""
python 复制代码
import torch

# 执行逐元素的小于等于比较
x = torch.tensor([1, 2, 3])
y = torch.tensor([2, 2, 2])
result = torch.less_equal(x, y)

print(result)# tensor([ True,  True, False])

15.torch.maximum()

torch.maximum() 是 PyTorch 中用于执行逐元素的最大值比较的函数。它比较两个张量的对应元素,并返回一个新的张量,其中每个元素是对应位置的最大值。

python 复制代码
torch.maximum(input, other, out=None)
"""
input:输入张量。
other:用于比较的另一个张量或标量值。
out:可选参数,用于指定输出张量的位置。
"""
python 复制代码
import torch

# 执行逐元素的最大值比较
x = torch.tensor([1, 2, 3])
y = torch.tensor([2, 1, 4])
result = torch.maximum(x, y)

print(result)# tensor([2, 2, 4])

16.torch.fmax()

torch.fmax() 是 PyTorch 中用于执行逐元素的最大值比较的函数,专门用于处理浮点数类型。它比较两个张量的对应元素,并返回一个新的张量,其中每个元素是对应位置的最大值。与 torch.maximum() 不同,torch.fmax() 函数在处理浮点数时会保留 NaN 值。如果其中一个张量的元素为 NaN,那么在对应位置上将返回另一个张量的值。

python 复制代码
torch.fmax(input, other, out=None)
"""
input:输入张量。
other:用于比较的另一个张量或标量值。
out:可选参数,用于指定输出张量的位置。
"""
python 复制代码
import torch

# 执行逐元素的最大值比较
x = torch.tensor([1.0, 2.0, float('nan')])
y = torch.tensor([2.0, 1.0, 3.0])
result = torch.fmax(x, y)

print(result)# tensor([2., 2., 3.])

17.torch.ne()

torch.ne() 是 PyTorch 中用于执行逐元素的不等于(!=)比较的函数。它比较两个张量的对应元素,并返回一个新的布尔张量,其中每个元素表示对应位置的元素是否不相等。

python 复制代码
torch.ne(input, other, out=None)
"""
input:输入张量。
other:用于比较的另一个张量或标量值。
out:可选参数,用于指定输出张量的位置。
"""
python 复制代码
import torch

# 执行逐元素的不等于比较
x = torch.tensor([1, 2, 3])
y = torch.tensor([2, 2, 2])
result = torch.ne(x, y)

print(result)# tensor([ True, False,  True])

18.torch.sort()

torch.sort() 是 PyTorch 中用于对张量进行排序的函数。它返回一个元组,包含排序后的值张量和对应的索引张量。

python 复制代码
torch.sort(input, dim=None, descending=False, out=None)
"""
input:输入张量。
dim:可选参数,指定在哪个维度上进行排序。如果未指定,则默认在最后一个维度上进行排序。
descending:可选参数,指定是否按降序排列。默认为 False,表示按升序排列。
out:可选参数,用于指定输出张量的位置。

torch.sort() 函数返回一个元组 (sorted_values, sorted_indices),其中:
sorted_values 是排序后的值张量。
sorted_indices 是排序后的值在原始张量中对应的索引张量。
"""
python 复制代码
import torch

# 对张量进行排序
x = torch.tensor([3, 1, 2])
sorted_values, sorted_indices = torch.sort(x)

print(sorted_values)# tensor([1, 2, 3])
print(sorted_indices)# tensor([1, 2, 0])

19.torch.topk()

torch.topk() 是 PyTorch 中用于获取张量中最大值或最小值的 k 个元素的函数。它返回一个元组,包含排序后的值张量和对应的索引张量。

python 复制代码
torch.topk(input, k, dim=None, largest=True, sorted=True, out=None)
"""
input:输入张量。
k:要获取的最大或最小值的个数。
dim:可选参数,指定在哪个维度上进行操作。如果未指定,则默认在最后一个维度上进行操作。
largest:可选参数,指定是否获取最大值。默认为 True,表示获取最大值。如果设置为 False,则获取最小值。
sorted:可选参数,指定是否返回排序结果。默认为 True,表示返回排序结果。如果设置为 False,则返回未排序的结果。
out:可选参数,用于指定输出张量的位置。
"""
python 复制代码
import torch

# 获取张量中的最大值和对应的索引
x = torch.tensor([3, 1, 2, 5, 4])
top_values, top_indices = torch.topk(x, k=3)

print(top_values)#tensor([5, 4, 3])
print(top_indices)# tensor([3, 4, 0])
相关推荐
Adolf_19934 分钟前
Flask-JWT-Extended登录验证, 不用自定义
后端·python·flask
冯宝宝^4 分钟前
基于mongodb+flask(Python)+vue的实验室器材管理系统
vue.js·python·flask
大耳朵爱学习7 分钟前
掌握Transformer之注意力为什么有效
人工智能·深度学习·自然语言处理·大模型·llm·transformer·大语言模型
TAICHIFEI9 分钟前
目标检测-数据集
人工智能·目标检测·目标跟踪
qq_153214526415 分钟前
【2023工业异常检测文献】SimpleNet
图像处理·人工智能·深度学习·神经网络·机器学习·计算机视觉·视觉检测
叫我:松哥15 分钟前
基于Python flask的医院管理学院,医生能够增加/删除/修改/删除病人的数据信息,有可视化分析
javascript·后端·python·mysql·信息可视化·flask·bootstrap
洛阳泰山18 分钟前
如何使用Chainlit让所有网站快速嵌入一个AI聊天助手Copilot
人工智能·ai·llm·copilot·网站·chainlit·copliot
儿创社ErChaungClub27 分钟前
解锁编程新境界:GitHub Copilot 让效率翻倍
人工智能·算法
乙真仙人32 分钟前
AIGC时代!AI的“iPhone时刻”与投资机遇
人工智能·aigc·iphone
Eiceblue1 小时前
Python 复制Excel 中的行、列、单元格
开发语言·python·excel