pytorch torch.nan_to_num函数介绍

torch.nan_to_num 函数简介

torch.nan_to_num 是 PyTorch 中的一个函数,用于将张量中的特殊浮点值(如 NaN+Inf-Inf)替换为指定的数值,或使用默认替代值。

函数签名

复制代码
torch.nan_to_num(input, nan=0.0, posinf=None, neginf=None)

参数

  1. input:

    • 输入张量。
    • 可以包含 NaN、正无穷(+Inf)、负无穷(-Inf)等特殊值。
  2. nan (可选):

    • 替换 NaN 的值。
    • 默认是 0.0
  3. posinf (可选):

    • 替换正无穷 (+Inf) 的值。
    • 默认是张量元素的最大有限值 (torch.finfo(input.dtype).max)。
  4. neginf (可选):

    • 替换负无穷 (-Inf) 的值。
    • 默认是张量元素的最小有限值 (torch.finfo(input.dtype).min)。

返回值

  • 返回一个张量,其中的 NaN+Inf-Inf 被替换为指定的值。
  • 输出张量与输入张量的形状和数据类型相同。

工作原理

  • NaN : 检测到 NaN 后,替换为参数 nan 指定的值。
  • +Inf-Inf : 检测到无穷值后,分别替换为参数 posinfneginf 指定的值。

简单示例

复制代码
import torch

# 创建包含 NaN、+Inf 和 -Inf 的张量
x = torch.tensor([float('nan'), float('inf'), -float('inf'), 1.0, -2.0])

# 替换 NaN 和 Inf
result = torch.nan_to_num(x, nan=0.0, posinf=10.0, neginf=-10.0)
print(result)

输出:

复制代码
tensor([  0.,  10., -10.,   1.,  -2.])

使用默认值

如果没有指定 posinfneginf,函数会使用数据类型的最大或最小值。

复制代码
x = torch.tensor([float('nan'), float('inf'), -float('inf')], dtype=torch.float32)

result = torch.nan_to_num(x)
print(result)

输出:

复制代码
tensor([ 0.0000e+00,  3.4028e+38, -3.4028e+38])

其中 3.4028e+38-3.4028e+38 分别是 float32 类型的最大和最小有限值。

广播支持

torch.nan_to_num 支持广播机制,当输入包含多维张量时同样可以逐元素替换:

复制代码
x = torch.tensor([[float('nan'), float('inf')], [-float('inf'), 1.0]])
result = torch.nan_to_num(x, nan=0.0, posinf=100.0, neginf=-100.0)
print(result)

输出:

复制代码
tensor([[   0.,  100.],
        [-100.,    1.]])

应用场景

1. 清洗数据 : 替换缺失值(NaN)或异常值(+Inf-Inf)以便进一步处理。

复制代码
x = torch.tensor([float('nan'), 5.0, float('inf'), -float('inf')])
clean_x = torch.nan_to_num(x, nan=0.0)
print(clean_x)  # tensor([ 0.,  5.,  max_value, min_value])

2. 防止计算异常 : 在模型训练或推理过程中,防止出现 NaN 或无穷值导致的计算失败。

3. 图像/信号处理: 在处理图像或信号数据时,用于替换缺失的像素值或异常值。

注意事项

  1. 数据类型兼容性:

    • 如果输入张量的类型为整数,使用 torch.nan_to_num 会报错,因为整数类型无法表示 NaN 或无穷值。
    • 函数只能用于浮点类型张量(如 torch.float32, torch.float64)。
  2. 默认替换值:

    • 对于正无穷和负无穷,默认替换值依赖于张量的数据类型。
  3. 性能开销:

    • 对大张量来说,函数调用会带来一定的计算开销,需在实际应用中注意性能。

总结

torch.nan_to_num 是处理数据异常(如缺失值和溢出值)的重要工具,特别适用于数据预处理和深度学习模型的训练过程。通过灵活的参数设置,可以有效替换各种特殊值,保证后续计算的稳定性和可靠性。

相关推荐
扫地的小何尚几秒前
Isaac Lab 2.3深度解析:全身控制与增强遥操作如何重塑机器人学习
arm开发·人工智能·学习·自然语言处理·机器人·gpu·nvidia
元基时代1 分钟前
视频图文矩阵发布系统企业
大数据·人工智能·矩阵
岁月宁静8 分钟前
AI聊天系统 实战:打造优雅的聊天记录复制与批量下载功能
前端·vue.js·人工智能
IT_陈寒15 分钟前
SpringBoot性能飞跃:5个关键优化让你的应用吞吐量提升300%
前端·人工智能·后端
kunge1v51 小时前
学习爬虫第三天:数据提取
前端·爬虫·python·学习
爱学习的小鱼gogo1 小时前
python 矩阵中寻找就接近的目标值 (矩阵-中等)含源码(八)
开发语言·经验分享·python·算法·职场和发展·矩阵
聚客AI1 小时前
系统提示的“消亡”?上下文工程正在重新定义人机交互规则
图像处理·人工智能·pytorch·语言模型·自然语言处理·chatgpt·gpt-3
Hello.Reader1 小时前
Flink 状态模式演进(State Schema Evolution)从原理到落地的一站式指南
python·flink·状态模式
红纸2811 小时前
Subword算法之WordPiece、Unigram与SentencePiece
人工智能·python·深度学习·神经网络·算法·机器学习·自然语言处理
golang学习记1 小时前
Crush:新一代基于Go语言构建的开源 AI 编程CLI工具
人工智能