引言:Loss 为 Nan ?别慌,这是深度学习的 "常见病"
在深度学习训练过程中,最让人崩溃的场景之一莫过于:前几轮迭代 Loss 还在正常下降,突然某个时刻直接变成 Nan(Not a Number),训练直接 "罢工"。这种情况看似突发,实则背后藏着数据、模型、超参数等多方面的问题。本文整理了 10 种最常见的 Loss 为 Nan 原因,搭配具体场景和解决方案,帮你快速定位问题、高效解决。
一、数据层面:"源头污染" 是罪魁祸首
数据是模型训练的基础,若数据存在异常,后续训练必然 "跑偏"。
1. 输入数据含 Nan/Inf 值
场景:数据集未经过严格清洗,存在缺失值(未处理直接变成 Nan)、极端异常值(如数值过大溢出为 Inf),或数据预处理时除法分母为 0(如归一化时除以标准差,而某特征标准差为 0)。
解决方案:
- 数据清洗:用 pandas 的 df.isnull().sum() 检查 Nan 值,用 np.isinf(df).sum() 检查 Inf 值,根据场景填充(均值 / 中位数 / 众数)或删除异常样本;
- 预处理校验:对归一化、标准化步骤添加 "防零校验",例如 std = max(std, 1e-8),避免分母为 0;
- 数据范围限制:对输入数据做截断处理,如 np.clip(x, a_min=-1e6, a_max=1e6),防止极端值溢出。
2. 标签(Label)异常
场景:分类任务中标签超出类别范围(如二分类标签出现 2)、回归任务中标签为 Nan/Inf,或标签数值过大(如回归目标值为 1e10,导致 Loss 计算时溢出)。
解决方案:
- 标签校验:训练前检查标签的取值范围(如 np.unique(label))、是否存在 Nan/Inf(np.isnan(label).sum());
- 标签缩放:回归任务中若标签数值过大,可采用标准化(Z-Score)或归一化(Min-Max)将标签缩放到合理范围(如 [-1, 1] 或 [0, 1])。
二、模型层面:"结构缺陷" 导致计算异常
模型的网络结构、激活函数、初始化方式等不当,会导致梯度爆炸或计算溢出。
1.激活函数选择不当
场景:深层网络中使用 Sigmoid/Tanh 激活函数,输入值过大时函数梯度趋近于 0(梯度消失),但输入值极端时可能导致计算溢出;或在输出层使用不适合的激活函数(如回归任务用 Sigmoid 导致输出范围受限,Loss 计算异常)。
解决方案:
- 隐藏层优先使用 ReLU 及其变体(Leaky ReLU、ELU、GELU),避免梯度消失 / 爆炸;
- 输出层按需选择:分类任务用 Softmax(多分类)/Sigmoid(二分类),回归任务用 Linear(无激活);
- 限制激活函数输入范围:对 ReLU 输入做截断(如 torch.clamp(x, max=1e2)),避免数值过大。
2.权重初始化不合理
场景:权重初始化过大(如全零初始化、随机初始化时方差过大),导致网络前向传播时输出值溢出(变成 Inf),进而 Loss 计算为 Nan;或初始化过小,导致梯度消失,后续训练中数值异常。
解决方案:
- 采用标准化初始化方法:如 Xavier 初始化(适用于 Tanh/Sigmoid)、He 初始化(适用于 ReLU 及其变体);
- 避免全零初始化:全零初始化会导致所有神经元输出相同,梯度更新无效,最终可能引发数值异常;
- 初始化后校验:打印网络第一层权重的均值和方差(如 model.fc1.weight.mean(), model.fc1.weight.std()),确保在合理范围(如均值接近 0,方差接近 1)。
3.网络结构过深 / 宽,导致梯度爆炸
场景:深层网络(如深度超过 100 层的 CNN、Transformer)未使用残差连接(Residual Connection)、LayerNorm 等技术,梯度在反向传播时累积放大(梯度爆炸),导致权重更新时数值溢出,Loss 变为 Nan。
解决方案:
- 加入残差连接:缓解深层网络的梯度消失 / 爆炸(如 ResNet 结构);
- 使用归一化层:在网络层间加入 BatchNorm 或 LayerNorm,稳定中间层输出的均值和方差,避免数值漂移;
- 简化网络:若无需过深结构,适当减少网络层数或神经元数量,降低训练难度。
三、训练过程:"参数不当" 引发数值崩溃
训练时的超参数设置、优化器选择、Loss 函数设计等,是导致 Loss 为 Nan 的高频原因。
1.学习率(Learning Rate)过大
场景:学习率是最影响训练稳定性的超参数 ------ 学习率过大时,权重更新步长太大,导致权重数值急剧震荡、溢出(变成 Inf),进而 Loss 计算为 Nan;即使初期 Loss 正常,几轮后也可能突然崩溃。
解决方案:
- 减小学习率:从较小的学习率开始尝试(如 1e-4、1e-5),而非默认的 1e-3;
- 使用学习率调度器:如 StepLR(步长衰减)、ReduceLROnPlateau(根据验证集 Loss 衰减),避免后期学习率过大;
- 梯度裁剪(Gradient Clipping):对梯度进行截断,限制梯度的最大范数(如 torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)),防止梯度爆炸导致权重更新异常。
2.Loss 函数设计不当
场景:自定义 Loss 函数时存在数值不稳定的计算(如对数运算中输入为 0 或负数,log(0) 会得到 -Inf,进而 Loss 为 Nan);或选择的 Loss 函数与任务不匹配(如多分类任务用 MSE Loss,导致梯度更新异常)。
解决方案:
- 避免数值不稳定运算:自定义 Loss 时,对对数输入添加微小偏移(如 torch.log(x + 1e-8)),对除法添加防零处理(如 x / (y + 1e-8));
- 选择匹配任务的 Loss:多分类用 CrossEntropyLoss,二分类用 BCEWithLogitsLoss(自带 Sigmoid,更稳定),回归用 MSELoss 或 L1Loss;
- 监控 Loss 组成:若 Loss 由多个部分加权求和(如分类 Loss + 正则化 Loss),可分别打印各部分数值,定位哪一部分导致 Nan。
3.优化器参数不当
场景:使用动量(Momentum)、Adam 等优化器时,参数设置不当导致训练不稳定。例如 Adam 的 eps 参数(数值稳定性项)过小(默认 1e-8,若改得更小可能导致数值溢出);或动量参数过大(如 Momentum=0.999),导致梯度累积过度。
解决方案:
- 保持优化器默认参数:新手建议先使用优化器默认参数(如 Adam 的 lr=1e-3, betas=(0.9, 0.999), eps=1e-8),再根据训练情况调整;
- 调整 eps 参数:若出现 Nan,可适当增大 eps(如 1e-6),增强数值稳定性;
- 避免使用过大的动量:动量参数建议在 0.9~0.99 之间,过大可能导致训练震荡。
四、其他细节:"隐性问题" 容易被忽略
除了上述核心原因,一些训练细节的疏忽也会导致 Loss 为 Nan。
1.batch size 过小或样本分布不均
场景:batch size 过小时(如 1、2),BatchNorm 计算的均值 / 方差波动极大(甚至为 Nan,当 batch 中所有样本某特征值相同时),导致中间层输出异常;或样本分布不均(如某类样本占比 99%),导致 Loss 计算时出现极端值。
解决方案:
- 增大 batch size:建议 batch size 至少为 8 或 16,若 GPU 显存不足,可使用梯度累积(Gradient Accumulation)模拟大 batch;
- 禁用 BatchNorm 或改用 LayerNorm:小 batch 场景下,BatchNorm 稳定性差,可替换为 LayerNorm(不受 batch 大小影响);
- 平衡样本分布:通过过采样少数类、欠采样多数类或使用加权 Loss(如 class_weight 参数),缓解样本不均导致的 Loss 异常。
2.数值精度问题
场景:使用 FP16(半精度)训练时,数值范围较小,容易出现溢出(尤其是梯度或 Loss 数值较大时);或 CPU/GPU 计算时的浮点精度误差累积,导致最终 Loss 变为 Nan。
解决方案:
- 优先使用 FP32(单精度)训练:新手避免直接使用 FP16,若需使用,需配合混合精度训练(如 PyTorch 的 torch.cuda.amp),设置梯度缩放(Gradient Scaling)避免溢出;
- 监控数值范围:训练时打印中间层输出、梯度的最大值 / 最小值,若发现数值超出 FP32 范围(如大于 1e38 或小于 -1e38),及时做截断处理;
- 避免累积微小误差:在循环计算中(如 RNN 中的隐藏状态更新),定期对数值做归一化或截断,防止误差累积。
总结:快速定位 Loss 为 Nan 的排查流程
遇到 Loss 为 Nan 时,无需盲目调参,可按以下步骤快速定位:
- 检查数据:先校验输入数据和标签是否存在 Nan/Inf、极端值,这是最常见原因;
- 降低学习率:将学习率缩小 10~100 倍,重新训练,排除学习率过大的问题;
- 监控中间值:打印网络中间层输出、梯度的数值范围,判断是前向传播(输出溢出)还是反向传播(梯度爆炸)的问题;
- 简化模型:临时移除 BatchNorm、减少网络层数,用简单模型(如 2 层全连接)测试,排除模型结构问题;
- 禁用自定义组件:若使用自定义 Loss、激活函数,先替换为 PyTorch/TensorFlow 内置组件,定位是否为自定义代码的数值不稳定问题。
深度学习训练的核心是 "稳定",Loss 为 Nan 本质是数值计算的崩溃。只要从数据、模型、训练过程三个维度逐一排查,就能快速找到问题根源。