深度学习中的对数似然损失函数

对数似然损失函数 (Log-Likelihood Loss)

1. 适用场景

对数似然损失(通常表现为负对数似然损失,Negative Log-Likelihood Loss, NLL)是统计学与深度学习中的核心损失函数。其主要适用场景包括:

  • 分类问题(Classification)
    • 原理 :最小化负对数似然损失等价于最小化交叉熵损失(Cross-Entropy Loss)
    • 二分类:假设服从伯努利分布(Bernoulli),配合 Sigmoid 使用。
    • 多分类:假设服从类别分布(Categorical),配合 Softmax 使用。
  • 带不确定性的回归问题(Regression with Uncertainty)
    • 原理 :传统回归使用均方误差(MSE),假设误差方差恒定。若需模型同时预测数值 及其不确定性(方差),则使用高斯负对数似然。
    • 优势:能够处理**异方差(Heteroscedasticity)**噪声,即不同输入样本的噪声水平不同。
  • 计数数据建模(Count Data)
    • 原理:假设数据服从泊松分布(Poisson Distribution)。
    • 场景:预测事件发生次数(如客流量、事故数)。
  • 生存分析(Survival Analysis)
    • 场景:预测事件发生的时间间隔,常用 Cox 比例风险模型或参数化生存模型。

2. 核心概念

似然(Likelihood)

在已知观测数据 yyy 的情况下,模型参数 θ\thetaθ(或预测分布)"产生"这组数据的概率密度。

  • 视角 :它是关于参数 θ\thetaθ 的函数,衡量参数对数据的解释能力(可信程度)。
  • 记号 :L(θ;y)=p(y∣θ)L(\theta; y) = p(y \mid \theta)L(θ;y)=p(y∣θ)。

对数似然(Log-Likelihood)

将似然函数取自然对数(ln⁡\lnln 或 log⁡\loglog)。
ℓ(θ;y)=log⁡L(θ;y) \ell(\theta; y) = \log L(\theta; y) ℓ(θ;y)=logL(θ;y)

  • 作用:将连乘运算转化为加法运算,便于计算梯度,且能防止数值下溢。

同方差 vs 异方差(Homoscedasticity vs Heteroscedasticity)

  • 同方差 :误差方差在所有样本上相同,即 Var⁡(ϵi)=σ2\operatorname{Var}(\epsilon_i) = \sigma^2Var(ϵi)=σ2(常数)。
  • 异方差 :误差方差随输入 xi\mathbf{x}_ixi 变化,即 Var⁡(ϵi∣xi)=σi2\operatorname{Var}(\epsilon_i \mid \mathbf{x}_i) = \sigma_i^2Var(ϵi∣xi)=σi2(函数)。

最大似然估计(MLE)

寻找参数 θ\thetaθ,使得观测数据出现的似然最大:
θ^MLE=arg⁡max⁡θℓ(θ;y) \hat{\theta}{\text{MLE}} = \arg\max{\theta} \ell(\theta; y) θ^MLE=argθmaxℓ(θ;y)

  • 深度学习转化 :优化器通常用于最小化 损失,因此目标转化为最小化负对数似然(NLL)
    LNLL=−1N∑i=1Nlog⁡p(yi∣xi) \mathcal{L}{\text{NLL}} = -\frac{1}{N} \sum{i=1}^{N} \log p(y_i \mid \mathbf{x}_i) LNLL=−N1i=1∑Nlogp(yi∣xi)

3. 高斯负对数似然推导

假设预测值 yiy_iyi 来自条件高斯分布:
yi∼N(μi,σi2) y_i \sim \mathcal{N}(\mu_i, \sigma_i^2) yi∼N(μi,σi2)

其中 μi=fθ(xi)\mu_i = f_\theta(\mathbf{x}_i)μi=fθ(xi) 为预测均值,σi2\sigma_i^2σi2 为预测方差。

单样本概率密度

p(yi∣xi)=12πσi2exp⁡(−(yi−μi)22σi2) p(y_i \mid \mathbf{x}_i) = \frac{1}{\sqrt{2\pi\sigma_i^2}} \exp\left( -\frac{(y_i - \mu_i)^2}{2\sigma_i^2} \right) p(yi∣xi)=2πσi2 1exp(−2σi2(yi−μi)2)

单样本对数似然

取对数后:
log⁡p(yi∣xi)=−12log⁡(2π)−12log⁡(σi2)−(yi−μi)22σi2 \log p(y_i \mid \mathbf{x}_i) = -\frac{1}{2}\log(2\pi) - \frac{1}{2}\log(\sigma_i^2) - \frac{(y_i - \mu_i)^2}{2\sigma_i^2} logp(yi∣xi)=−21log(2π)−21log(σi2)−2σi2(yi−μi)2

损失函数(Negative Log-Likelihood)

取负号并忽略常数项 −12log⁡(2π)-\frac{1}{2}\log(2\pi)−21log(2π)(不影响梯度),得到单个样本的损失:
Li=12log⁡(σi2)+(yi−μi)22σi2 \mathcal{L}_i = \frac{1}{2}\log(\sigma_i^2) + \frac{(y_i - \mu_i)^2}{2\sigma_i^2} Li=21log(σi2)+2σi2(yi−μi)2


4. 两种形式的损失函数

根据方差 σ2\sigma^2σ2 是否随输入变化,分为两种形式。

4.1 同方差损失 (Homoscedastic NLL)

方差 σ2\sigma^2σ2 是一个全局可学习标量参数 ,不随输入 xi\mathbf{x}_ixi 变化。

  • 模型输出 :仅预测均值 μ\muμ。
  • 参数 :log⁡(σ2)\log(\sigma^2)log(σ2) 作为独立参数在训练中被优化。

Lhomo=12(log⁡(σ2)+1N∑i=1N(yi−μi)2σ2) \mathcal{L}{\text{homo}} = \frac{1}{2} \left( \log(\sigma^2) + \frac{1}{N}\sum{i=1}^{N} \frac{(y_i - \mu_i)^2}{\sigma^2} \right) Lhomo=21(log(σ2)+N1i=1∑Nσ2(yi−μi)2)

4.2 异方差损失 (Heteroscedastic NLL)

方差 σi2\sigma_i^2σi2 随输入变化,由模型动态预测。

  • 模型输出 :同时预测均值 μi\mu_iμi 和方差 σi2\sigma_i^2σi2(通常输出形状为 (B,L,2)(B, L, 2)(B,L,2))。
  • 含义:模型对某些样本 confident(方差小),对某些样本 uncertain(方差大)。

Lhetero=12N∑i=1N(log⁡(σi2)+(yi−μi)2σi2) \mathcal{L}{\text{hetero}} = \frac{1}{2N} \sum{i=1}^{N} \left( \log(\sigma_i^2) + \frac{(y_i - \mu_i)^2}{\sigma_i^2} \right) Lhetero=2N1i=1∑N(log(σi2)+σi2(yi−μi)2)

变量定义说明

若输入批次为 (B,L)(B, L)(B,L),则:

  • μ∈RB×L\boldsymbol{\mu} \in \mathbb{R}^{B \times L}μ∈RB×L:预测均值。
  • σ2∈RB×L\boldsymbol{\sigma}^2 \in \mathbb{R}^{B \times L}σ2∈RB×L:预测方差(不确定性)。
  • y∈RB×L\mathbf{y} \in \mathbb{R}^{B \times L}y∈RB×L:真实标签。
  • N=B×LN = B \times LN=B×L:样本总数。

5. 实际应用中需要注意的

5.1 数值稳定性

直接预测方差 σ2\sigma^2σ2 可能导致 σ2≤0\sigma^2 \le 0σ2≤0,从而使 log⁡(σ2)\log(\sigma^2)log(σ2) 报错。

  • 将模型输出进行变换 :模型输出层预测 log⁡(σ2)\log(\sigma^2)log(σ2) 记为 sis_isi。
  • exp变换 :σi2=exp⁡(si)\sigma_i^2 = \exp(s_i)σi2=exp(si)。
  • 稳定后的损失函数
    L=12N∑i=1N(si+(yi−μi)2exp⁡(si)) \mathcal{L} = \frac{1}{2N} \sum_{i=1}^{N} \left( s_i + \frac{(y_i - \mu_i)^2}{\exp(s_i)} \right) L=2N1i=1∑N(si+exp(si)(yi−μi)2)
    这样既保证了方差恒正,又避免了计算对数时的数值不稳定。

5.2 梯度平衡

  • 损失函数中的两项 log⁡(σ2)\log(\sigma^2)log(σ2) 和 (y−μ)2σ2\frac{(y-\mu)^2}{\sigma^2}σ2(y−μ)2 具有不同的量级。
  • 在训练初期,若方差预测不准确,可能导致梯度爆炸或消失。建议对方差输出进行适当的激活函数约束(如 Softplus)或直接预测对数方差。

5.3 与 MSE 的关系

  • 当 σ2\sigma^2σ2 固定为常数时,最小化 NLL 等价于最小化 MSE
  • 引入可学习的 σ2\sigma^2σ2 后,模型会自动根据数据的噪声水平调整损失权重:噪声大的样本权重低,噪声小的样本权重高。

6. 总结

特性 均方误差 (MSE) 负对数似然 (NLL)
假设分布 高斯分布,方差恒定 高斯分布,方差可变的
模型输出 仅均值 μ\muμ 均值 μ\muμ + 方差 σ2\sigma^2σ2 (或同方差参数)
不确定性 无法预测 可以预测 (Uncertainty Quantification)
适用场景 标准回归 噪声不均匀的回归、概率预测

核心结论:如果你只需要预测数值,MSE 足够;如果你需要知道模型"对这个预测有多大的把握"**,请使用对数似然损失。

相关推荐
慕诗客5 小时前
repo管理多仓库
大数据·elasticsearch·搜索引擎
China_Yanhy5 小时前
动手学大模型第一篇学习总结
人工智能
空间机器人5 小时前
自动驾驶 ADAS 器件选型:算力只是门票,系统才是生死线
人工智能·机器学习·自动驾驶
C+++Python5 小时前
提示词、Agent、MCP、Skill 到底是什么?
人工智能
小松要进步5 小时前
机器学习1
人工智能·机器学习
泰恒5 小时前
openclaw近期怎么样了?
人工智能·深度学习·机器学习
KaneLogger6 小时前
从传统笔记到 LLM 驱动的结构化 Wiki
人工智能·程序员·架构
tinygone6 小时前
OpenClaw之Memory配置成本地模式,Ubuntu+CUDA+cuDNN+llama.cpp
人工智能·ubuntu·llama
正在走向自律6 小时前
第二章-AIGC入门-AIGC工具全解析:技术控的效率神器,DeepSeek国产大模型的骄傲(8/36)
人工智能·chatgpt·aigc·可灵·deepseek·即梦·阿里通义千问
轩轩分享AI6 小时前
DeepSeek、Kimi、笔灵谁最好用?5款网文作者亲测的AI写作神器横评
人工智能·ai·ai写作·小说写作·小说·小说干货