深度学习中的对数似然损失函数

对数似然损失函数 (Log-Likelihood Loss)

1. 适用场景

对数似然损失(通常表现为负对数似然损失,Negative Log-Likelihood Loss, NLL)是统计学与深度学习中的核心损失函数。其主要适用场景包括:

  • 分类问题(Classification)
    • 原理 :最小化负对数似然损失等价于最小化交叉熵损失(Cross-Entropy Loss)
    • 二分类:假设服从伯努利分布(Bernoulli),配合 Sigmoid 使用。
    • 多分类:假设服从类别分布(Categorical),配合 Softmax 使用。
  • 带不确定性的回归问题(Regression with Uncertainty)
    • 原理 :传统回归使用均方误差(MSE),假设误差方差恒定。若需模型同时预测数值 及其不确定性(方差),则使用高斯负对数似然。
    • 优势:能够处理**异方差(Heteroscedasticity)**噪声,即不同输入样本的噪声水平不同。
  • 计数数据建模(Count Data)
    • 原理:假设数据服从泊松分布(Poisson Distribution)。
    • 场景:预测事件发生次数(如客流量、事故数)。
  • 生存分析(Survival Analysis)
    • 场景:预测事件发生的时间间隔,常用 Cox 比例风险模型或参数化生存模型。

2. 核心概念

似然(Likelihood)

在已知观测数据 yyy 的情况下,模型参数 θ\thetaθ(或预测分布)"产生"这组数据的概率密度。

  • 视角 :它是关于参数 θ\thetaθ 的函数,衡量参数对数据的解释能力(可信程度)。
  • 记号 :L(θ;y)=p(y∣θ)L(\theta; y) = p(y \mid \theta)L(θ;y)=p(y∣θ)。

对数似然(Log-Likelihood)

将似然函数取自然对数(ln⁡\lnln 或 log⁡\loglog)。
ℓ(θ;y)=log⁡L(θ;y) \ell(\theta; y) = \log L(\theta; y) ℓ(θ;y)=logL(θ;y)

  • 作用:将连乘运算转化为加法运算,便于计算梯度,且能防止数值下溢。

同方差 vs 异方差(Homoscedasticity vs Heteroscedasticity)

  • 同方差 :误差方差在所有样本上相同,即 Var⁡(ϵi)=σ2\operatorname{Var}(\epsilon_i) = \sigma^2Var(ϵi)=σ2(常数)。
  • 异方差 :误差方差随输入 xi\mathbf{x}_ixi 变化,即 Var⁡(ϵi∣xi)=σi2\operatorname{Var}(\epsilon_i \mid \mathbf{x}_i) = \sigma_i^2Var(ϵi∣xi)=σi2(函数)。

最大似然估计(MLE)

寻找参数 θ\thetaθ,使得观测数据出现的似然最大:
θ^MLE=arg⁡max⁡θℓ(θ;y) \hat{\theta}{\text{MLE}} = \arg\max{\theta} \ell(\theta; y) θ^MLE=argθmaxℓ(θ;y)

  • 深度学习转化 :优化器通常用于最小化 损失,因此目标转化为最小化负对数似然(NLL)
    LNLL=−1N∑i=1Nlog⁡p(yi∣xi) \mathcal{L}{\text{NLL}} = -\frac{1}{N} \sum{i=1}^{N} \log p(y_i \mid \mathbf{x}_i) LNLL=−N1i=1∑Nlogp(yi∣xi)

3. 高斯负对数似然推导

假设预测值 yiy_iyi 来自条件高斯分布:
yi∼N(μi,σi2) y_i \sim \mathcal{N}(\mu_i, \sigma_i^2) yi∼N(μi,σi2)

其中 μi=fθ(xi)\mu_i = f_\theta(\mathbf{x}_i)μi=fθ(xi) 为预测均值,σi2\sigma_i^2σi2 为预测方差。

单样本概率密度

p(yi∣xi)=12πσi2exp⁡(−(yi−μi)22σi2) p(y_i \mid \mathbf{x}_i) = \frac{1}{\sqrt{2\pi\sigma_i^2}} \exp\left( -\frac{(y_i - \mu_i)^2}{2\sigma_i^2} \right) p(yi∣xi)=2πσi2 1exp(−2σi2(yi−μi)2)

单样本对数似然

取对数后:
log⁡p(yi∣xi)=−12log⁡(2π)−12log⁡(σi2)−(yi−μi)22σi2 \log p(y_i \mid \mathbf{x}_i) = -\frac{1}{2}\log(2\pi) - \frac{1}{2}\log(\sigma_i^2) - \frac{(y_i - \mu_i)^2}{2\sigma_i^2} logp(yi∣xi)=−21log(2π)−21log(σi2)−2σi2(yi−μi)2

损失函数(Negative Log-Likelihood)

取负号并忽略常数项 −12log⁡(2π)-\frac{1}{2}\log(2\pi)−21log(2π)(不影响梯度),得到单个样本的损失:
Li=12log⁡(σi2)+(yi−μi)22σi2 \mathcal{L}_i = \frac{1}{2}\log(\sigma_i^2) + \frac{(y_i - \mu_i)^2}{2\sigma_i^2} Li=21log(σi2)+2σi2(yi−μi)2


4. 两种形式的损失函数

根据方差 σ2\sigma^2σ2 是否随输入变化,分为两种形式。

4.1 同方差损失 (Homoscedastic NLL)

方差 σ2\sigma^2σ2 是一个全局可学习标量参数 ,不随输入 xi\mathbf{x}_ixi 变化。

  • 模型输出 :仅预测均值 μ\muμ。
  • 参数 :log⁡(σ2)\log(\sigma^2)log(σ2) 作为独立参数在训练中被优化。

Lhomo=12(log⁡(σ2)+1N∑i=1N(yi−μi)2σ2) \mathcal{L}{\text{homo}} = \frac{1}{2} \left( \log(\sigma^2) + \frac{1}{N}\sum{i=1}^{N} \frac{(y_i - \mu_i)^2}{\sigma^2} \right) Lhomo=21(log(σ2)+N1i=1∑Nσ2(yi−μi)2)

4.2 异方差损失 (Heteroscedastic NLL)

方差 σi2\sigma_i^2σi2 随输入变化,由模型动态预测。

  • 模型输出 :同时预测均值 μi\mu_iμi 和方差 σi2\sigma_i^2σi2(通常输出形状为 (B,L,2)(B, L, 2)(B,L,2))。
  • 含义:模型对某些样本 confident(方差小),对某些样本 uncertain(方差大)。

Lhetero=12N∑i=1N(log⁡(σi2)+(yi−μi)2σi2) \mathcal{L}{\text{hetero}} = \frac{1}{2N} \sum{i=1}^{N} \left( \log(\sigma_i^2) + \frac{(y_i - \mu_i)^2}{\sigma_i^2} \right) Lhetero=2N1i=1∑N(log(σi2)+σi2(yi−μi)2)

变量定义说明

若输入批次为 (B,L)(B, L)(B,L),则:

  • μ∈RB×L\boldsymbol{\mu} \in \mathbb{R}^{B \times L}μ∈RB×L:预测均值。
  • σ2∈RB×L\boldsymbol{\sigma}^2 \in \mathbb{R}^{B \times L}σ2∈RB×L:预测方差(不确定性)。
  • y∈RB×L\mathbf{y} \in \mathbb{R}^{B \times L}y∈RB×L:真实标签。
  • N=B×LN = B \times LN=B×L:样本总数。

5. 实际应用中需要注意的

5.1 数值稳定性

直接预测方差 σ2\sigma^2σ2 可能导致 σ2≤0\sigma^2 \le 0σ2≤0,从而使 log⁡(σ2)\log(\sigma^2)log(σ2) 报错。

  • 将模型输出进行变换 :模型输出层预测 log⁡(σ2)\log(\sigma^2)log(σ2) 记为 sis_isi。
  • exp变换 :σi2=exp⁡(si)\sigma_i^2 = \exp(s_i)σi2=exp(si)。
  • 稳定后的损失函数
    L=12N∑i=1N(si+(yi−μi)2exp⁡(si)) \mathcal{L} = \frac{1}{2N} \sum_{i=1}^{N} \left( s_i + \frac{(y_i - \mu_i)^2}{\exp(s_i)} \right) L=2N1i=1∑N(si+exp(si)(yi−μi)2)
    这样既保证了方差恒正,又避免了计算对数时的数值不稳定。

5.2 梯度平衡

  • 损失函数中的两项 log⁡(σ2)\log(\sigma^2)log(σ2) 和 (y−μ)2σ2\frac{(y-\mu)^2}{\sigma^2}σ2(y−μ)2 具有不同的量级。
  • 在训练初期,若方差预测不准确,可能导致梯度爆炸或消失。建议对方差输出进行适当的激活函数约束(如 Softplus)或直接预测对数方差。

5.3 与 MSE 的关系

  • 当 σ2\sigma^2σ2 固定为常数时,最小化 NLL 等价于最小化 MSE
  • 引入可学习的 σ2\sigma^2σ2 后,模型会自动根据数据的噪声水平调整损失权重:噪声大的样本权重低,噪声小的样本权重高。

6. 总结

特性 均方误差 (MSE) 负对数似然 (NLL)
假设分布 高斯分布,方差恒定 高斯分布,方差可变的
模型输出 仅均值 μ\muμ 均值 μ\muμ + 方差 σ2\sigma^2σ2 (或同方差参数)
不确定性 无法预测 可以预测 (Uncertainty Quantification)
适用场景 标准回归 噪声不均匀的回归、概率预测

核心结论:如果你只需要预测数值,MSE 足够;如果你需要知道模型"对这个预测有多大的把握"**,请使用对数似然损失。

相关推荐
Percent_bigdata2 小时前
百分点科技亮相MWC 2026:以数据智能深耕全球治理
人工智能·科技
arvin_xiaoting2 小时前
AI Agent 实战:用飞书任务卡片让后台任务「可见」
人工智能·自动化·llm·飞书·ai agent·openclaw·任务卡片
sheyuDemo2 小时前
torch中的rand()和randn()函数的区别
人工智能·pytorch·深度学习
rainbow7242442 小时前
如何科学选型:AI人才技术水平评估的多元方法对比与深度分析
大数据·人工智能
70asunflower2 小时前
CUDA基础知识巩固检验练习题【附有参考答案】(7)
c++·人工智能·cuda
大傻^2 小时前
LangChain4j 记忆架构:ChatMemory、持久化与跨会话状态
java·人工智能·windows·架构·langchain4j
生活予甜2 小时前
广柔扁平电缆在机器人AI技术创新应用中的前景探索
人工智能·机器人
找藉口是失败者的习惯2 小时前
从LLM到Agent:大语言模型核心概念指南
人工智能·语言模型·自然语言处理
接着奏乐接着舞。2 小时前
5分钟本地跑起大模型
人工智能·llama