文章目录
- [PyTorch 张量核心操作------比较、排序与数据校验](#PyTorch 张量核心操作——比较、排序与数据校验)
-
- 一、张量比较运算:判断元素间的关系
-
- [1. 基础比较运算符](#1. 基础比较运算符)
- [2. 比较函数:`torch.eq()` 与 `torch.ne()` 等](#2. 比较函数:
torch.eq()与torch.ne()等) - [3. 张量相等性判断 `torch.equal()`](#3. 张量相等性判断
torch.equal()) -
- [3.1 `torch.equal()` 函数原型](#3.1
torch.equal()函数原型) - [3.2 核心特点:判断"整体相等性"](#3.2 核心特点:判断“整体相等性”)
- [3.3 代码示例与结果分析](#3.3 代码示例与结果分析)
- [3.4 与 `==` 运算符的区别](#3.4 与
==运算符的区别) - [3.5 使用场景与注意事项](#3.5 使用场景与注意事项)
- [3.1 `torch.equal()` 函数原型](#3.1
- 二、排序操作:`torch.sort()`
-
- [1. 一维张量排序](#1. 一维张量排序)
- [2. 多维张量排序(指定维度)](#2. 多维张量排序(指定维度))
- [三、Top-K 选取:`torch.topk()`](#三、Top-K 选取:
torch.topk()) -
- [1. 一维张量的 Top-K 选取](#1. 一维张量的 Top-K 选取)
- [2. 多维张量的 Top-K 选取(指定维度)](#2. 多维张量的 Top-K 选取(指定维度))
- [四、K-th 值选取:`torch.kthvalue()`](#四、K-th 值选取:
torch.kthvalue()) -
- [1. 一维张量的 K-th 值](#1. 一维张量的 K-th 值)
- [2. 多维张量的 K-th 值(指定维度)](#2. 多维张量的 K-th 值(指定维度))
- 五、数据合法性校验:检测异常值
-
- [1. 检测 `NaN`:`torch.isnan()`](#1. 检测
NaN:torch.isnan()) - [2. 检测 `Inf`:`torch.isinf()`](#2. 检测
Inf:torch.isinf()) - [3. 检测有限值:`torch.isfinite()`](#3. 检测有限值:
torch.isfinite()) -
- [`torch.isfinite()` 与其他检测函数的关系](#
torch.isfinite()与其他检测函数的关系)
- [`torch.isfinite()` 与其他检测函数的关系](#
- [4. 综合校验:同时检测 `NaN` 和 `Inf`](#4. 综合校验:同时检测
NaN和Inf)
- [1. 检测 `NaN`:`torch.isnan()`](#1. 检测
- 六、总结与应用场景
- 六、总结与应用场景
PyTorch 张量核心操作------比较、排序与数据校验
在深度学习开发中,张量(Tensor)的比较、排序和数据校验是基础且高频的操作。无论是模型训练中的数据预处理,还是推理阶段的结果分析,这些操作都扮演着重要角色。本文将系统讲解 PyTorch 中与张量相关的比较运算、排序方法、Top-K 选取、K-th 值提取以及数据合法性校验,涵盖函数原型、参数详解、代码示例和结果分析,帮助初学者全面掌握这些核心技能。
一、张量比较运算:判断元素间的关系
比较运算用于判断张量元素之间的大小关系或相等性,返回与输入张量形状相同的布尔型张量(True/False)。PyTorch 提供了丰富的比较运算符和函数,支持元素级别的逐元素比较。
1. 基础比较运算符
PyTorch 支持与 Python 类似的比较运算符,包括 ==(等于)、!=(不等于)、>(大于)、<(小于)、>=(大于等于)、<=(小于等于)。这些运算符均为元素级运算,即对两个张量的对应元素逐一进行比较。
运算符特点:
- 要求两个张量形状相同或可广播(广播机制见后文补充)。
- 返回布尔型张量,
True表示满足条件,False表示不满足。
运算原理:
张量的比较运算遵循 "位置对应" 原则 :两个张量必须形状相同(或可通过广播机制扩展为相同形状),然后对相同位置的元素 逐一进行比较,最终生成一个形状相同的布尔张量(True/False)。
- "大小" 的含义 :张量中元素的大小就是其数值大小(如
3 > 2、-1 < 0等),与元素在张量中的位置无关。 - 布尔张量的意义 :结果中
True表示对应位置的元素满足比较条件,False表示不满足。
运算演示:
示例:比较两个一维张量 a = [1, 3, 5, 7] 和 b = [2, 3, 4, 8]
-
步骤 1:确认张量形状
-
a的形状:(4,)(1 维,4 个元素) -
b的形状:(4,)(1 维,4 个元素)形状相同,可直接比较(无需广播)。
-
-
步骤 2:逐元素比较(以
a > b为例)比较逻辑:对每个索引
i(0 ≤ i < 4),判断a[i] > b[i]是否成立。-
i=0 :
a[0] = 1,b[0] = 2→1 > 2?→ False -
i=1 :
a[1] = 3,b[1] = 3→3 > 3?→ False -
i=2 :
a[2] = 5,b[2] = 4→5 > 4?→ True -
i=3 :
a[3] = 7,b[3] = 8→7 > 8?→ False
-
-
步骤 3:生成结果张量
- 将上述判断结果按原位置组合,得到布尔张量:
a > b的结果为[False, False, True, False]
- 将上述判断结果按原位置组合,得到布尔张量:
代码示例:
python
import torch
# 定义两个形状相同的张量
a = torch.tensor([1, 3, 5, 7])
b = torch.tensor([2, 3, 4, 8])
# 比较运算
print("a == b:", a == b) # 等于
print("a != b:", a != b) # 不等于
print("a > b: ", a > b) # 大于
print("a < b: ", a < b) # 小于
print("a >= b:", a >= b) # 大于等于
print("a <= b:", a <= b) # 小于等于
运行结果:
a == b: tensor([False, True, False, False])
a != b: tensor([ True, False, True, True])
a > b: tensor([False, False, True, False])
a < b: tensor([ True, False, False, True])
a >= b: tensor([False, True, True, False])
a <= b: tensor([ True, True, False, True])
结果分析:
- 每个运算符都对
a和b的对应元素进行比较(如a[0]=1与b[0]=2比较,1 < 2故a < b的第 0 位为True)。 - 布尔张量的形状与输入张量一致(均为
(4,)),便于后续基于条件筛选元素(如a[a > b]可提取a中大于b对应元素的值)。
2. 比较函数:torch.eq() 与 torch.ne() 等
除运算符外,PyTorch 还提供了对应的函数形式,如 torch.eq()(等于)、torch.ne()(不等于)、torch.gt()(大于)、torch.lt()(小于)、torch.ge()(大于等于)、torch.le()(小于等于)。这些函数与运算符功能一致,但支持更灵活的参数设置(如广播)。
函数原型:
python
torch.eq(input, other, *, out=None) → Tensor
torch.ne(input, other, *, out=None) → Tensor
torch.gt(input, other, *, out=None) → Tensor
# 其余函数参数类似
参数说明:
input:输入张量(第一个比较对象)。other:第二个比较对象(可以是张量或标量)。out(可选):输出张量,用于存储结果(需与预期输出形状一致)。
代码示例(支持广播机制):
python
# 形状不同但可广播的张量比较
a = torch.tensor([[1, 2, 3], [4, 5, 6]]) # 形状 (2, 3)
b = torch.tensor([3]) # 标量张量(形状 ()),可广播为 (2, 3)
# 使用函数进行比较(等价于 a < 3)
lt_result = torch.lt(a, b)
print("a < 3 的结果:\n", lt_result)
运行结果:
a < 3 的结果:
tensor([[ True, True, False],
[False, False, False]])
结果分析:
b是标量张量,通过广播机制自动扩展为与a同形状的张量[[3, 3, 3], [3, 3, 3]]。torch.lt(a, b)逐元素比较a和扩展后的b,返回布尔张量(a中元素小于 3 的位置为True)。
比较运算的核心应用:
- 条件筛选:通过
tensor[布尔张量]提取满足条件的元素(如a[a > 5]提取a中大于 5 的元素)。 - 掩码操作:生成掩码张量用于数据过滤或加权计算。
- 结果验证:在模型推理中判断预测结果与标签的匹配情况(如计算准确率时统计
pred == label的数量)。
3. 张量相等性判断 torch.equal()
在 PyTorch 中,torch.equal() 是一个用于判断两个张量是否完全相等 的函数。它与我们之前讲过的元素级比较运算符(如 ==)不同,后者返回一个布尔张量,而 torch.equal() 会返回一个单一的布尔值(True 或 False),表示两个张量是否在所有元素和形状上都完全一致。
3.1 torch.equal() 函数原型
python
torch.equal(input1, input2) → bool
参数说明:
input1:第一个待比较的张量。input2:第二个待比较的张量。
返回值:
- 布尔值(
True或False):如果两个张量的形状相同 且所有对应元素都相等 ,则返回True;否则返回False。
3.2 核心特点:判断"整体相等性"
torch.equal() 的核心是整体判断,而非元素级判断。它有两个严格条件:
- 两个张量的形状必须完全相同;
- 两个张量所有对应位置的元素必须完全相等。
只有同时满足这两个条件,才会返回 True。
3.3 代码示例与结果分析
示例 1:形状相同且元素全相等
python
import torch
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 2, 3])
print(torch.equal(a, b)) # 输出:True
分析 :a 和 b 形状均为 (3,),且所有元素对应相等,因此返回 True。
示例 2:形状相同但元素不全相等
python
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 4, 3])
print(torch.equal(a, b)) # 输出:False
分析 :虽然形状相同,但 a[1]=2 与 b[1]=4 不相等,因此返回 False。
示例 3:形状不同(即使元素"看起来"对应)
python
a = torch.tensor([[1, 2], [3, 4]]) # 形状 (2, 2)
b = torch.tensor([1, 2, 3, 4]) # 形状 (4,)
print(torch.equal(a, b)) # 输出:False
分析 :a 是 2×2 的二维张量,b 是长度为 4 的一维张量,形状不同,直接返回 False。
示例 4:浮点数的相等判断(需注意精度)
python
a = torch.tensor([1.0, 2.0])
b = torch.tensor([1.0 + 1e-9, 2.0]) # 第一个元素有微小差异
print(torch.equal(a, b)) # 输出:False
分析 :浮点数由于精度问题,即使差异极小(如 1e-9),torch.equal() 也会判定为不相等。如果需要忽略微小误差,应使用 torch.allclose()(后续会介绍)。
示例 5:包含 NaN 的张量(特殊情况)
python
a = torch.tensor([1.0, torch.nan])
b = torch.tensor([1.0, torch.nan])
print(torch.equal(a, b)) # 输出:False
分析 :由于 NaN 与任何值(包括自身)都不相等(数学定义),因此即使两个张量都含 NaN,torch.equal() 也会返回 False。
3.4 与 == 运算符的区别
| 操作 | 返回值类型 | 核心逻辑 | 典型用途 |
|---|---|---|---|
torch.equal(a, b) |
单一布尔值(bool) |
判断整体是否完全相等 | 验证两个张量是否完全一致 |
a == b |
布尔张量(Tensor) |
元素级比较,返回每个位置的结果 | 筛选特定位置的元素 |
对比示例:
python
a = torch.tensor([1, 2, 3])
b = torch.tensor([1, 4, 3])
print("a == b 的结果:", a == b) # 元素级比较
print("torch.equal(a, b):", torch.equal(a, b)) # 整体比较
输出:
a == b 的结果: tensor([ True, False, True])
torch.equal(a, b): False
3.5 使用场景与注意事项
场景:
- 验证张量是否完全一致:如检查模型参数在训练前后是否发生预期变化,或验证两个计算结果是否完全相同。
- 单元测试:在测试代码中,判断函数输出是否与预期张量完全一致。
- 调试:排查代码中张量是否在传递过程中被意外修改(形状或元素)。
注意:
- 形状优先 :即使两个张量的元素"数量相同"但形状不同(如
(2,2)和(4,)),也会返回False。 - 浮点数精度 :对浮点数张量,微小的精度误差(如数值计算中的舍入误差)会导致返回
False,此时应使用torch.allclose()并设置合理的误差容限(如atol=1e-5)。 NaN特殊处理 :含NaN的张量几乎不可能被判定为相等,需单独处理NaN位置(如先通过torch.isnan()检测)。
总之,torch.equal() 是判断两个张量"完全一致性"的便捷工具,适合需要严格验证张量是否相同的场景。但使用时需注意形状、浮点数精度和 NaN 等特殊情况。
二、排序操作:torch.sort()
排序是将张量元素按升序或降序重新排列的操作。PyTorch 中 torch.sort() 函数不仅返回排序后的张量,还返回原元素在排序后的位置索引,这对后续分析元素来源至关重要。
函数原型:
python
torch.sort(input, dim=-1, descending=False, *, out=None) → (Tensor, Tensor)
参数说明:
input:需要排序的输入张量。dim(默认-1):指定排序的维度(如dim=0按行排序,dim=1按列排序)。descending(默认False):排序方式,False为升序,True为降序。out(可选):元组(sorted_tensor, indices),用于存储输出结果。
返回值:
- 元组
(sorted_tensor, indices):sorted_tensor:排序后的张量,形状与输入一致。indices:整数张量,记录原张量元素在排序后张量中的位置索引。
1. 一维张量排序
代码示例:
python
x = torch.tensor([3, 1, 4, 2, 5])
# 升序排序(默认)
sorted_x, indices = torch.sort(x)
print("升序排序结果:", sorted_x)
print("原索引位置:", indices)
# 降序排序
sorted_x_desc, indices_desc = torch.sort(x, descending=True)
print("降序排序结果:", sorted_x_desc)
print("原索引位置:", indices_desc)
运行结果:
升序排序结果: tensor([1, 2, 3, 4, 5])
原索引位置: tensor([1, 3, 0, 2, 4])
降序排序结果: tensor([5, 4, 3, 2, 1])
原索引位置: tensor([4, 2, 0, 3, 1])
结果分析:
- 升序排序后,
sorted_x为[1, 2, 3, 4, 5],indices表示原张量中元素的位置(如sorted_x[0] = 1来自原张量的索引1)。 - 降序排序通过
descending=True实现,结果为[5, 4, 3, 2, 1],索引对应原元素位置。
2. 多维张量排序(指定维度)
多维张量排序需通过 dim 参数指定排序维度,不同维度的排序结果差异显著。
代码示例:
python
x = torch.tensor([[3, 1, 2],
[6, 4, 5]]) # 形状 (2, 3)
# 按列维度(dim=1)升序排序(每行内部排序)
sorted_row, indices_row = torch.sort(x, dim=1)
print("按行内元素排序结果:\n", sorted_row)
print("行内排序索引:\n", indices_row)
# 按行维度(dim=0)升序排序(每列内部排序)
sorted_col, indices_col = torch.sort(x, dim=0)
print("按列内元素排序结果:\n", sorted_col)
print("列内排序索引:\n", indices_col)
运行结果:
按行内元素排序结果:
tensor([[1, 2, 3],
[4, 5, 6]])
行内排序索引:
tensor([[1, 2, 0],
[1, 2, 0]])
按列内元素排序结果:
tensor([[3, 1, 2],
[6, 4, 5]])
列内排序索引:
tensor([[0, 0, 0],
[1, 1, 1]])
结果分析:
dim=1表示按列维度排序(每行内部元素重新排列),第一行[3,1,2]排序后为[1,2,3],索引[1,2,0]对应原元素位置。dim=0表示按行维度排序(每列内部元素重新排列),由于原张量列元素已按升序排列(如第一列[3,6]),排序后结果不变,索引[0,1]表示原行位置。
三、Top-K 选取:torch.topk()
在很多场景中,我们不需要对整个张量排序,只需获取最大或最小的 k 个元素(如推荐系统中的 Top-N 物品)。torch.topk() 函数可高效实现这一功能,无需全量排序,计算效率更高。
函数原型:
python
torch.topk(input, k, dim=-1, largest=True, sorted=True, *, out=None) → (Tensor, Tensor)
参数说明:
input:输入张量。k:需要选取的元素数量(必须满足1 ≤ k ≤ 输入张量在指定维度的大小)。dim(默认-1):选取元素的维度。largest(默认True):选取方式,True表示取最大的k个元素,False表示取最小的k个元素。sorted(默认True):返回结果是否按大小排序(True表示排序,False表示不保证顺序)。out(可选):元组(values, indices),用于存储输出结果。
返回值:
- 元组
(values, indices):values:选取的k个元素值。indices:这些元素在原张量中的位置索引。
1. 一维张量的 Top-K 选取
代码示例:
python
x = torch.tensor([7, 2, 8, 1, 9, 3])
# 取最大的 3 个元素(默认)
top3_vals, top3_indices = torch.topk(x, k=3)
print("最大的 3 个元素:", top3_vals)
print("对应原索引:", top3_indices)
# 取最小的 2 个元素(largest=False)
bottom2_vals, bottom2_indices = torch.topk(x, k=2, largest=False)
print("最小的 2 个元素:", bottom2_vals)
print("对应原索引:", bottom2_indices)
运行结果:
最大的 3 个元素: tensor([9, 8, 7])
对应原索引: tensor([4, 2, 0])
最小的 2 个元素: tensor([1, 2])
对应原索引: tensor([3, 1])
结果分析:
k=3且largest=True时,返回最大的 3 个元素[9,8,7],其原索引分别为4(x[4]=9)、2(x[2]=8)、0(x[0]=7)。largest=False时,返回最小的 2 个元素[1,2],对应原索引3和1。
2. 多维张量的 Top-K 选取(指定维度)
代码示例:
python
x = torch.tensor([[5, 2, 8],
[3, 9, 1]]) # 形状 (2, 3)
# 对每行取最大的 2 个元素(dim=1)
row_top2_vals, row_top2_indices = torch.topk(x, k=2, dim=1)
print("每行最大的 2 个元素:\n", row_top2_vals)
print("对应列索引:\n", row_top2_indices)
# 对每列取最小的 1 个元素(dim=0, largest=False)
col_bottom1_vals, col_bottom1_indices = torch.topk(x, k=1, dim=0, largest=False)
print("每列最小的 1 个元素:\n", col_bottom1_vals)
print("对应行索引:\n", col_bottom1_indices)
运行结果:
每行最大的 2 个元素:
tensor([[8, 5],
[9, 3]])
对应列索引:
tensor([[2, 0],
[1, 0]])
每列最小的 1 个元素:
tensor([[3],
[2],
[1]])
对应行索引:
tensor([[1],
[0],
[1]])
结果分析:
dim=1表示按行取 Top-K,第一行[5,2,8]最大的 2 个元素是8(列索引 2)和5(列索引 0)。dim=0且largest=False表示按列取最小元素,第一列[5,3]最小元素是3(行索引 1),以此类推。
四、K-th 值选取:torch.kthvalue()
torch.kthvalue() 用于获取张量中第 k 小的元素 (按升序排列后的第 k 个元素,索引从 1 开始)。与 torch.topk() 不同,它聚焦于"特定排名"的元素,而非前 k 个元素。
函数原型:
python
torch.kthvalue(input, k, dim=-1, *, out=None) → (Tensor, Tensor)
参数说明:
input:输入张量。k:第 k 小的元素(1 ≤ k ≤ 输入张量在指定维度的大小,注意 k 从 1 开始计数)。dim(默认-1):选取元素的维度。out(可选):元组(value, index),用于存储输出结果。
返回值:
- 元组
(value, index):value:第 k 小的元素值。index:该元素在原张量中的位置索引。
1. 一维张量的 K-th 值
代码示例:
python
x = torch.tensor([3, 1, 4, 2, 5]) # 升序排列后为 [1, 2, 3, 4, 5]
# 取第 3 小的元素(k=3)
k3_val, k3_idx = torch.kthvalue(x, k=3)
print("第 3 小的元素值:", k3_val)
print("对应原索引:", k3_idx)
# 取第 1 小的元素(k=1,即最小值)
k1_val, k1_idx = torch.kthvalue(x, k=1)
print("第 1 小的元素值(最小值):", k1_val)
运行结果:
第 3 小的元素值: tensor(3)
对应原索引: tensor(0)
第 1 小的元素值(最小值): tensor(1)
结果分析:
- 原张量升序排列后为
[1,2,3,4,5],第 3 小的元素是3,对应原张量的索引0(x[0]=3)。 k=1时返回最小值1,验证了函数的正确性。
2. 多维张量的 K-th 值(指定维度)
代码示例:
python
x = torch.tensor([[5, 2, 8],
[3, 9, 1]]) # 形状 (2, 3)
# 对每行取第 2 小的元素(dim=1, k=2)
row_k2_val, row_k2_idx = torch.kthvalue(x, k=2, dim=1)
print("每行第 2 小的元素值:", row_k2_val)
print("对应列索引:", row_k2_idx)
运行结果:
每行第 2 小的元素值: tensor([5, 3])
对应列索引: tensor([0, 0])
结果分析:
- 第一行
[5,2,8]升序后为[2,5,8],第 2 小的元素是5,对应原列索引0。 - 第二行
[3,9,1]升序后为[1,3,9],第 2 小的元素是3,对应原列索引0。
五、数据合法性校验:检测异常值
在深度学习中,张量中若存在 NaN(非数)、Inf(无穷大)等异常值,会导致模型训练发散或推理结果错误。因此,数据预处理和训练过程中需对异常值进行检测和处理。PyTorch 提供了专门的函数用于异常值检测。
1. 检测 NaN:torch.isnan()
NaN(Not a Number)通常由无效运算产生(如 0/0、sqrt(-1) 等),torch.isnan() 可标记张量中所有 NaN 元素。
函数原型:
python
torch.isnan(input, *, out=None) → Tensor
参数说明:
input:输入张量(通常为浮点型)。out(可选):输出布尔张量,用于存储结果。
代码示例:
python
x = torch.tensor([1.0, float('nan'), 3.0, torch.nan]) # 包含 NaN 的张量
# 检测 NaN
is_nan = torch.isnan(x)
print("NaN 位置标记:", is_nan)
print("非 NaN 元素:", x[~is_nan]) # ~ 表示逻辑取反
运行结果:
NaN 位置标记: tensor([False, True, False, True])
非 NaN 元素: tensor([1., 3.])
结果分析:
float('nan')(Python 原生)和torch.nan(PyTorch 定义)均会被检测为NaN。- 通过
x[~is_nan]可筛选出所有非NaN元素,实现数据清洗。
2. 检测 Inf:torch.isinf()
Inf(无穷大)由溢出运算产生(如 1/0),torch.isinf() 可标记所有 Inf 元素(包括正无穷 +inf 和负无穷 -inf)。
函数原型:
python
torch.isinf(input, *, out=None) → Tensor
代码示例:
python
x = torch.tensor([1.0, float('inf'), -float('inf'), 5.0]) # 包含 Inf 的张量
# 检测 Inf
is_inf = torch.isinf(x)
print("Inf 位置标记:", is_inf)
print("非 Inf 元素:", x[~is_inf])
运行结果:
Inf 位置标记: tensor([False, True, True, False])
非 Inf 元素: tensor([1., 5.])
结果分析:
- 正无穷
float('inf')和负无穷-float('inf')均被标记为True。 - 筛选后仅保留正常元素
[1.0, 5.0]。
3. 检测有限值:torch.isfinite()
有限值 指的是既不是 NaN 也不是 Inf(包括正、负无穷)的正常数值(如 1.0、-3.5 等)。torch.isfinite() 函数返回一个布尔张量,其中 True 表示对应元素是有限值,False 表示元素是 NaN 或 Inf。
函数原型:
python
torch.isfinite(input, *, out=None) → Tensor
参数说明:
input:输入张量(通常为浮点型)。out(可选):输出布尔张量,用于存储结果。
代码示例:
python
x = torch.tensor([
1.0, # 有限值
torch.nan, # NaN(非有限值)
float('inf'), # 正无穷(非有限值)
-float('inf'),# 负无穷(非有限值)
3.14 # 有限值
])
# 检测有限值
is_finite = torch.isfinite(x)
print("有限值位置标记:", is_finite)
print("所有有限值元素:", x[is_finite]) # 直接筛选有限值
运行结果:
plaintext
有限值位置标记: tensor([ True, False, False, False, True])
所有有限值元素: tensor([1.0000, 3.1400])
结果分析:
-
torch.isfinite(x)直接标记出所有有限值元素(1.0和3.14),返回True;对NaN、+inf、-inf均返回False。 -
与 "先检测异常值再取反"(
~(is_nan | is_inf))相比,torch.isfinite()是更简洁的方式,直接筛选有限值元素。
torch.isfinite() 与其他检测函数的关系
torch.isfinite() 的结果等价于对 torch.isnan() 和 torch.isinf() 取反的逻辑与,即:
python
is_finite = ~torch.isnan(x) & ~torch.isinf(x)
但 torch.isfinite() 是专门优化的函数,计算效率更高,且代码更简洁。
4. 综合校验:同时检测 NaN 和 Inf
实际应用中,通常需要同时检测 NaN 和 Inf,可通过逻辑运算符 |(或)实现。
代码示例:
python
x = torch.tensor([2.0, torch.nan, float('inf'), -float('inf'), 3.0])
# 同时检测 NaN 和 Inf
is_abnormal = torch.isnan(x) | torch.isinf(x)
print("异常值位置标记:", is_abnormal)
print("正常元素:", x[~is_abnormal])
运行结果:
异常值位置标记: tensor([False, True, True, True, False])
正常元素: tensor([2., 3.])
结果分析:
torch.isnan(x) | torch.isinf(x)标记所有NaN或Inf元素,返回布尔张量。- 筛选后仅保留正常元素
[2.0, 3.0],确保数据合法性。
六、总结与应用场景
| 操作类型 | 核心函数/运算符 | 关键功能 | 典型应用场景 |
|---|---|---|---|
| 比较运算 | ==/!=/torch.eq() 等 |
元素级关系判断(等于、大于等) | 条件筛选、掩码生成、结果验证 |
| 排序 | torch.sort() |
按指定维度升序/降序排序,返回索引 | 全量排序、元素顺序分析 |
| Top-K 选取 | torch.topk() |
高效获取最大/最小的 k 个元素 | 推荐系统、模型推理加速 |
| K-th 值选取 | torch.kthvalue() |
获取第 k 小的元素 | 统计分析(如中位数计算) |
| 数据校验 | torch.isnan()/torch.isinf() |
检测 NaN/Inf 异常值 |
数据预处理、训练过程异常监控 |
掌握这些操作后,你可以:
- 快速筛选符合条件的张量元素;
- 高效获取排序后的结果或Top-K元素,优化模型推理;
- 检测并处理异常值,保障模型训练稳定性。
保数据合法性。
六、总结与应用场景
| 操作类型 | 核心函数/运算符 | 关键功能 | 典型应用场景 |
|---|---|---|---|
| 比较运算 | ==/!=/torch.eq() 等 |
元素级关系判断(等于、大于等) | 条件筛选、掩码生成、结果验证 |
| 排序 | torch.sort() |
按指定维度升序/降序排序,返回索引 | 全量排序、元素顺序分析 |
| Top-K 选取 | torch.topk() |
高效获取最大/最小的 k 个元素 | 推荐系统、模型推理加速 |
| K-th 值选取 | torch.kthvalue() |
获取第 k 小的元素 | 统计分析(如中位数计算) |
| 数据校验 | torch.isnan()/torch.isinf() |
检测 NaN/Inf 异常值 |
数据预处理、训练过程异常监控 |
掌握这些操作后,你可以:
- 快速筛选符合条件的张量元素;
- 高效获取排序后的结果或Top-K元素,优化模型推理;
- 检测并处理异常值,保障模型训练稳定性。
这些操作是深度学习开发的基础工具,无论是数据预处理、模型训练还是结果分析,都离不开它们的灵活应用。建议结合实际场景多做练习,加深理解。