【Machine Learning】【Probability Theory】【Entropy Loss】【KL Divergence】Computing Information Content, Shannon Entropy, and KL Divergence

1. Information Content (Amount of Information)

For a single event:

  • Low probability --> high information content

  • High probability --> low information content

  • The information content of independent events is additive

$$I(x)=\log_2\left(\frac{1}{p(x)}\right)=-\log_2(p(x))$$

    E.g.:

  • A fair coin:
    $p(h)=0.5$, $I_p(h)=\log_2\left(\frac{1}{0.5}\right)=1$
    $p(t)=0.5$, $I_p(t)=\log_2\left(\frac{1}{0.5}\right)=1$

  • A biased coin:
    $q(h)=0.2$, $I_q(h)=\log_2\left(\frac{1}{0.2}\right)\approx 2.32$
    $q(t)=0.8$, $I_q(t)=\log_2\left(\frac{1}{0.8}\right)\approx 0.32$
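A minimal Python sketch of this definition, just to reproduce the numbers above (the function name `information_content` is my own choice):

```python
import math

def information_content(p: float) -> float:
    """Information content I(x) = -log2(p(x)), in bits."""
    return -math.log2(p)

print(information_content(0.5))  # 1.0   (fair coin: each outcome carries 1 bit)
print(information_content(0.2))  # ~2.32 (rare head: more informative)
print(information_content(0.8))  # ~0.32 (common tail: less informative)
```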

2. Shannon Entropy

Entropy: the expected information content of a probability distribution. It is also a measure of uncertainty.

Assume a discrete distribution, e.g. a Bernoulli distribution.

For a continuous distribution, replace the sum with an integral.
$$H(p)=\sum p_i I^p_i=\sum p_i\log_2\left(\frac{1}{p_i}\right)=-\sum p_i\log_2(p_i)$$

Example: coin probabilities $p(h)=0.5$, $p(t)=0.5$
$$H(p)=p(h)\times \log_2\left(\frac{1}{p(h)}\right)+p(t)\times \log_2\left(\frac{1}{p(t)}\right)=0.5\times 1+0.5\times 1=1$$

Example: coin probabilities $p(h)=0.2$, $p(t)=0.8$
$$H(p)=p(h)\times \log_2\left(\frac{1}{p(h)}\right)+p(t)\times \log_2\left(\frac{1}{p(t)}\right)\approx 0.2\times 2.32+0.8\times 0.32\approx 0.72$$
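A short sketch of the entropy formula (the helper name `entropy` is mine; `scipy.stats.entropy` with `base=2` would give the same result):

```python
import math

def entropy(dist: list[float]) -> float:
    """Shannon entropy H(p) = -sum(p_i * log2(p_i)), in bits."""
    return -sum(p * math.log2(p) for p in dist if p > 0)

print(entropy([0.5, 0.5]))  # 1.0   (a fair coin is maximally uncertain)
print(entropy([0.2, 0.8]))  # ~0.72 (a biased coin is more predictable)
```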

3. Cross Entropy

The ground-truth probabilities of a coin: $p(h)=0.5$, $p(t)=0.5$

The estimated (observed) probabilities: $q(h)=0.2$, $q(t)=0.8$

Given an estimated distribution $q$, the expected information content over the true distribution $p$:

$$H(p,q)=\sum p_i I^q_i=\sum p_i\log_2\left(\frac{1}{q_i}\right)=-\sum p_i \log_2(q_i)$$

  • The expectation is taken over the true distribution, because the data are always generated according to the true distribution
  • The information content uses the estimated distribution, because that is the distribution we estimated

$q(h)=0.2$, $q(t)=0.8$:
$$H(p,q)=p(h)\times \log_2\left(\frac{1}{q(h)}\right)+p(t)\times \log_2\left(\frac{1}{q(t)}\right)\approx 0.5\times 2.32+0.5\times 0.32\approx 1.32$$

$q(h)=0.4$, $q(t)=0.6$:
$$H(p,q)=p(h)\times \log_2\left(\frac{1}{q(h)}\right)+p(t)\times \log_2\left(\frac{1}{q(t)}\right)\approx 0.5\times 1.32+0.5\times 0.74\approx 1.03$$
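The same two evaluations as a Python sketch (the function name `cross_entropy` is my own):

```python
import math

def cross_entropy(p: list[float], q: list[float]) -> float:
    """H(p, q) = -sum(p_i * log2(q_i)): expectation over the true
    distribution p of the information content under the estimate q."""
    return -sum(pi * math.log2(qi) for pi, qi in zip(p, q))

p = [0.5, 0.5]                       # ground truth
print(cross_entropy(p, [0.2, 0.8]))  # ~1.32
print(cross_entropy(p, [0.4, 0.6]))  # ~1.03 (a better estimate costs less)
print(cross_entropy(p, p))           # 1.0   (equals H(p) when q = p)
```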

4. KL Divergence (Kullback-Leibler Divergence, Relative Entropy)

KL divergence measures how much two probability distributions differ.

4.1 A quantitative view: the gap between cross entropy and entropy

$$\begin{aligned}D(p\Vert q)=H(p,q)-H(p)&=\sum p_i I^q_i-\sum p_i I^p_i\\ &=\sum p_i \log_2\left(\frac{1}{q_i}\right)-\sum p_i \log_2\left(\frac{1}{p_i}\right)\\ &=\sum p_i\log_2\left(\frac{p_i}{q_i}\right)\end{aligned}$$

$D(p\Vert q)\ge 0$ (Gibbs' inequality), with equality if and only if the two distributions are identical.
$D(p\Vert q)\ne D(q\Vert p)$ in general, so KL divergence is not a distance metric.

Minimizing the KL divergence is equivalent to minimizing the cross entropy whenever the target distribution is fixed:

$q_\theta$ is the predicted distribution and $p$ is the distribution we want. Taking the gradient with respect to $\theta$: $H(p)$ does not depend on $\theta$, so $\nabla_\theta H(p)=0$.
$$\nabla_\theta D(p\Vert q_\theta)=\nabla_\theta H(p,q_\theta)-\nabla_\theta H(p)=\nabla_\theta H(p,q_\theta)$$
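A sketch that checks $D(p\Vert q)=H(p,q)-H(p)$ and the two properties above numerically (the function name `kl_divergence` is my own):

```python
import math

def kl_divergence(p: list[float], q: list[float]) -> float:
    """D(p || q) = sum(p_i * log2(p_i / q_i)), in bits."""
    return sum(pi * math.log2(pi / qi) for pi, qi in zip(p, q) if pi > 0)

p, q = [0.5, 0.5], [0.2, 0.8]
print(kl_divergence(p, q))  # ~0.32 = H(p, q) - H(p) = 1.32 - 1.0
print(kl_divergence(q, p))  # ~0.28: D(p||q) != D(q||p), so not a metric
print(kl_divergence(p, p))  # 0.0:   Gibbs' inequality, equality iff p = q
```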

4.2 Another view of KL divergence

If $q$ is a good estimate of $p$, a sequence drawn from $p$ should be roughly as likely under $q$ as under $p$:

Ground truth for the coin:

  • $p(h)=0.5$
  • $p(t)=0.5$

Observed (estimated) probabilities for the coin:

  • $q(h)=0.2$
  • $q(t)=0.8$

Now flip the coin $N$ times, getting $N_h$ heads and $N_t$ tails; call the resulting sequence seq.
When $N$ is large enough, $\frac{N_h}{N}$ approaches $p(h)$ and $\frac{N_t}{N}$ approaches $p(t)$.

$$\begin{aligned}\log\left(\left(\frac{P(seq\mid p)}{P(seq\mid q)}\right)^{\frac{1}{N}}\right)&=\frac{1}{N}\log\left(\frac{p(h)^{N_h}\,p(t)^{N_t}}{q(h)^{N_h}\,q(t)^{N_t}}\right)\\ &=\frac{N_h}{N}\log(p(h))+\frac{N_t}{N}\log(p(t))-\frac{N_h}{N}\log(q(h))-\frac{N_t}{N}\log(q(t))\\ &=p(h)\log(p(h))+p(t)\log(p(t))-p(h)\log(q(h))-p(t)\log(q(t))\\ &=p(h)\log\left(\frac{p(h)}{q(h)}\right)+p(t)\log\left(\frac{p(t)}{q(t)}\right)\end{aligned}$$

$$D(p\Vert q)=\sum p_i \log\left(\frac{p_i}{q_i}\right)=\log\left(\left(\frac{P(\text{sequence from distribution } p\mid \text{distribution } p)}{P(\text{sequence from distribution } p\mid \text{distribution } q)}\right)^{\frac{1}{N}}\right)$$
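A small simulation of this sequence view, sketched with natural logs to match the derivation (the seed and the choice of $N$ are arbitrary assumptions of mine):

```python
import numpy as np

p = {"h": 0.5, "t": 0.5}  # true distribution
q = {"h": 0.2, "t": 0.8}  # estimated distribution

# Draw a long sequence from the true distribution p.
rng = np.random.default_rng(0)
N = 100_000
seq = rng.choice(["h", "t"], size=N, p=[p["h"], p["t"]])

# (1/N) * log(P(seq|p) / P(seq|q)), the per-flip log-likelihood ratio.
log_p = np.where(seq == "h", np.log(p["h"]), np.log(p["t"]))
log_q = np.where(seq == "h", np.log(q["h"]), np.log(q["t"]))
per_flip_ratio = (log_p - log_q).mean()

kl = sum(p[s] * np.log(p[s] / q[s]) for s in ("h", "t"))
print(per_flip_ratio, kl)  # both ~0.22 nats; they converge as N grows
```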
