【深度学习】Pytorch训练过程中损失值出现NaN

项目场景

利用Pytorch框架,结合FEDformer开源代码(https://github.com/MAZiqing/FEDformer),将自己的数据集作为输入训练模型。

问题描述

训练过程中,发现打印出来的Train loss, Test loss, Test loss中,Test loss从第一个epoch开始就为nan。

输出中间结果后,发现第一个epoch训练到了第二个batch时,模型输出开始出现了nan。

原因分析

查阅了相关资料,有这样一些说法:

  • 梯度爆炸:batch size较大、学习率较大、数据特征之间值的差异较大
  • 数据本身有缺失值

之后针对数据的缺失值进行了统计,发现并没有缺失值。所以初步认为是发生了梯度爆炸。

随后我做了多组实验,观察每次epoch的每个batch的预测结果是否存在nan:

  • 对比实验a: 不断减小batch size
  • 对比实验b: 不断减小学习率
  • 对比实验c: 减少数据集特征的个数

最终发现,是数据集特征的问题。数据集的某个特征和其他特征数值差异较大,导致模型在反向传播计算梯度的时候计算出的梯度值过大,从而导致了梯度爆炸。

解决方案

经过理论分析,这一列特征对于实验结果的影响不会很大,故直接将这一列特征从数据中删除。之后的实验结果也表明确实是这一列的引入导致了模型训练出现了NaN。

总结

深度学习训练过程中损失值出现NaN的情况:

  • 梯度爆炸:batch size较大、学习率较大、数据特征之间值的差异较大
  • 数据本身有缺失值
相关推荐
hqyjzsb26 分钟前
2025文职转行AI管理岗:衔接型认证成为关键路径
大数据·c语言·人工智能·信息可视化·媒体·caie
恒点虚拟仿真28 分钟前
赋能成长,聚力前行——恒点启动核心产品知识与高品质方案系列专题培训
人工智能·虚拟仿真实验·恒点虚拟仿真·人工智能+x
weixin_4569042737 分钟前
Transformer架构发展历史
深度学习·架构·transformer
番茄寿司1 小时前
具身智能六大前沿创新思路深度解析
论文阅读·人工智能·深度学习·计算机网络·机器学习
rengang661 小时前
25-TensorFlow:概述Google开发的流行机器学习框架
人工智能·机器学习·tensorflow
递归尽头是星辰1 小时前
Spring AI 1.0 核心功能脉络
人工智能·spring ai·‌大模型应用开发‌·ai‌应用开发‌·java + ai
mit6.8242 小时前
[sam2图像分割] 视频追踪API | VideoPredictor | `inference_state`记忆
人工智能·计算机视觉·音视频
YangYang9YangYan2 小时前
大专计算机技术专业就业方向:解读、规划与提升指南
大数据·人工智能·数据分析
mwq301232 小时前
GPT监督微调SFT:在损失计算中屏蔽指令和填充 Token
人工智能
富唯智能2 小时前
智慧物流新篇章:复合机器人重塑装配车间物料配送
人工智能·工业机器人·复合机器人