深度学习常用损失函数介绍

均方差损失(Mean Square Error,MSE)

均方误差损失又称为二次损失、L2损失,常用于回归预测任务中。均方误差函数通过计算预测值和实际值之间距离(即误差)的平方来衡量模型优劣。即预测值和真实值越接近,两者的均方差就越小。

计算方式

假设有 n n n 个训练数据 x i x_i xi,每个训练数据 x i x_i xi 的真实输出为 y i y_i yi,模型对 x i x_i xi 的预测值为 y ^ i \hat{y}_i y^i。该模型在 n n n 个训练数据下所产生的均方误差损失可定义如下:

M S E = 1 n ∑ i = 1 n ( y i − y ^ i ) 2 MSE=\frac{1}{n}\sum_{i=1}^n{\left( y_i-\hat{y}_i \right) ^2} MSE=n1i=1∑n(yi−y^i)2

假设真实目标值为100,预测值在-10000到10000之间,我们绘制MSE函数曲线如 图1 所示。可以看到,当预测值越接近100时,MSE损失值越小。MSE损失的范围为0到 ∞ \infty ∞ 。

图1 MSE损失示意图

CTC算法

算法背景

CTC 算法主要用来解决神经网络中标签和预测值无法对齐的情况,通常用于文字识别以及语音等序列学习领域。举例来说,在语音识别任务中,我们希望语音片段可以与对应的文本内容一一对应,这样才能方便我们后续的模型训练。但是对齐音频与文本是一件很困难的事,如 图2 所示,每个人的语速都不同,有人说话快,有人说话慢,我们很难按照时序信息将语音序列切分成一个个的字符片段。而手动对齐音频与字符又是一件非常耗时耗力的任务。

图2 语音识别任务中音频与文本无法对齐

在文本识别领域,由于字符间隔、图像变形等问题,相同的字符也会得到不同的预测结果,所以同样会会遇到标签和预测值无法对齐的情况。如 图3 所示。

图3 不同表现形式的相同字符示意图

总结来说,假设我们有个输入(如字幅图片或音频信号) X X X ,对应的输出是 Y Y Y ,在序列学习领域,通常会碰到如下难点:

  1. X X X 和 Y Y Y 都是变长的;
  2. X X X 和 Y Y Y 的长度比也是变化的;
  3. X X X 和 Y Y Y 相应的元素之间无法严格对齐。

算法概述

引入CTC主要就是要解决上述问题。这里以文本识别算法CRNN为例,分析CTC的计算方式及作用。CRNN中,整体流程如 图4 所示。

图4 CRNN整体流程

CRNN中,首先使用CNN提取图片特征,特征图的维度为 m × T m\times T m×T ,特征图 x x x 可以定义为:

x = ( x 1 , x 2 , . . . , x T ) x = (x^1,x^2,...,x^T) x=(x1,x2,...,xT)

然后,将特征图的每一列作为一个时间片送入LSTM中。令 t t t 为代表时间维度的值,且满足 1 < t < T 1<t<T 1<t<T ,则每个时间片可以表示为:

x t = ( x 1 t , x 2 t , . . . , x m t ) x^t = (x_1^t,x_2^t,...,x_m^t) xt=(x1t,x2t,...,xmt)

经过LSTM的计算后,使用softmax获取概率矩阵 y y y ,定义为:

y = ( y 1 , y 2 , . . . , y T ) y = (y^1,y^2,...,y^T) y=(y1,y2,...,yT)

其中,矩阵的每一列 y t y^t yt 定义为:

y t = ( y 1 t , y 2 t , . . . , y n t ) y^t = (y_1^t,y_2^t,...,y_n^t) yt=(y1t,y2t,...,ynt)
n n n 为字符字典的长度,由于 y i t y_i^t yit 是概率,所以 Σ i y i t = 1 \Sigma_i{y_i^t}=1 Σiyit=1 。对每一列 y t y^t yt 求 a r g m a x ( ) argmax() argmax() ,就可以获取每个类别的概率。

考虑到文本区域中字符之间存在间隔,也就是有的位置是没有字符的,所以这里定义分隔符 − - − 来表示当前列的对应位置在图像中没有出现字符。用 L L L 代表原始的字符字典,则此时新的字符字典 L ′ L' L′ 为:

L ′ = L ∪ { − } L' = L \cup \{-\} L′=L∪{−}

此时,就回到了我们上文提到的问题上了,由于字符间隔、图像变形等问题,相同的字符可能会得到不同的预测结果。在CTC算法中,定义了 B B B 变换来解决这个问题。 B B B 变换简单来说就是将模型的预测结果去掉分割符以及重复字符(如果同个字符连续出现,则表示只有1个字符,如果中间有分割符,则表示该字符出现多次),使得不同表现形式的相同字符得到统一的结果。如 图5 所示。

图5 CTC示意图

这里举几个简单的例子便于理解,这里令T为10:

B ( − s − t − a a t t e ) = s t a t e B(-s-t-aatte)=state B(−s−t−aatte)=state

B ( s s − t − a − t − e ) = s t a t e B(ss-t-a-t-e)=state B(ss−t−a−t−e)=state

B ( s s t t − a a t − e ) = s t a t e B(sstt-aat-e)=state B(sstt−aat−e)=state

对于字符中间有分隔符的重复字符则不进行合并:

B ( − s − t − t a t t e ) = s t t a t e B(-s-t-tatte)=sttate B(−s−t−tatte)=sttate

当获得LSTM输出后,进行 B B B 变换就可以得到最终结果。由于 B B B 变换并不是一对一的映射,例如上边的3个不同的字符都可以变换为state,所以在LSTM的输入为 x x x 的前提下,CTC的输出为 l l l 的概率应该为:

p ( l ∣ x ) = Σ π ∈ B − 1 ( l ) p ( π ∣ x ) p(l|x) = \Sigma_{\pi\in B^{-1}(l)}p(\pi|x) p(l∣x)=Σπ∈B−1(l)p(π∣x)

其中, π \pi π 为LSTM的输出向量, π ∈ B − 1 ( l ) \pi\in B^{-1}(l) π∈B−1(l) 代表所有能通过 B B B 变换得到 l l l 的 π \pi π 的集合。

而对于任意一个 π \pi π ,又有:

p ( π ∣ x ) = Π t = 1 T y π t t p(\pi|x) = \Pi_{t=1}^Ty^t_{\pi_t} p(π∣x)=Πt=1Tyπtt

其中, y π t t y^t_{\pi_t} yπtt 代表 t t t 时刻 π \pi π 为对应值的概率,这里举一个例子进行说明:

π = − s − t − a a t t e \pi = -s-t-aatte π=−s−t−aatte

y π t t = y − 1 ∗ y s 2 ∗ y − 3 ∗ y t 4 ∗ y − 5 ∗ y a 6 ∗ y a 7 ∗ y t 8 ∗ y t 9 ∗ y e 1 0 y^t_{\pi_t} = y_-^1*y_s^2*y_-^3*y_t^4*y_-^5*y_a^6*y_a^7*y_t^8*y_t^9*y_e^10 yπtt=y−1∗ys2∗y−3∗yt4∗y−5∗ya6∗ya7∗yt8∗yt9∗ye10

不难理解,使用CTC进行模型训练,本质上就是希望调整参数,使得 p ( π ∣ x ) p(\pi|x) p(π∣x) 取最大。

交叉熵损失函数

在物理学中,"熵"被用来表示热力学系统所呈现的无序程度。香农将这一概念引入信息论领域,提出了"信息熵"概念,通过对数函数来测量信息的不确定性。

交叉熵(cross entropy)是信息论中的重要概念,主要用来度量两个概率分布间的差异。假定 p p p 和 q q q 是数据 x x x 的两个概率分布,通过 q q q 来表示 p p p 的交叉熵可如下计算:

H ( p , q ) = − ∑ x p ( x ) log ⁡ q ( x ) H\left( p,q \right) =-\sum_x{p\left( x \right) \log q\left( x \right)} H(p,q)=−x∑p(x)logq(x)

交叉熵刻画了两个概率分布之间的距离,旨在描绘通过概率分布 q q q 来表达概率分布 p p p 的困难程度。根据公式不难理解,交叉熵越小,两个概率分布 p p p 和 q q q 越接近。

这里仍然以三类分类问题为例,假设数据 x x x 属于类别 1 1 1。记数据x的类别分布概率为 y y y,显然 y = ( 1 , 0 , 0 ) y=(1,0,0) y=(1,0,0)代表数据 x x x 的实际类别分布概率。记 y ^ \hat{y} y^ 代表模型预测所得类别分布概率。

那么对于数据 x x x 而言,其实际类别分布概率 y y y 和模型预测类别分布概率 y ^ \hat{y} y^ 的交叉熵损失函数定义为:

c r o s s e n t r o p y = − y × log ⁡ ( y ^ ) cross\ entropy=-y\times \log \left( \hat{y} \right) cross entropy=−y×log(y^)

很显然,一个良好的神经网络要尽量保证对于每一个输入数据,神经网络所预测类别分布概率与实际类别分布概率之间的差距越小越好,即交叉熵越小越好。于是,可将交叉熵作为损失函数来训练神经网络。

图6 三类分类问题中输入x的交叉熵损失示意图(x 属于第一类)

图6 给出了一个三个类别分类的例子。由于输入数据 x x x 属于类别 1 1 1,因此其实际类别概率分布值为 y = ( y 1 , y 2 , y 3 ) = ( 1 , 0 , 0 ) y=(y_1,y_2,y_3)=(1,0,0) y=(y1,y2,y3)=(1,0,0)。经过神经网络的变换,得到了输入数据 x x x 相对于三个类别的预测中间值 ( z 1 , z 2 , z 3 ) (z1,z2,z3) (z1,z2,z3)。然后,经过 S o f t m a x Softmax Softmax 函数映射,得到神经网络所预测的输入数据 x x x 的类别分布概率 y ^ = ( y ^ 1 , y ^ 2 , y ^ 3 ) \hat{y}=\left( \hat{y}_1,\hat{y}_2,\hat{y}_3 \right) y^=(y^1,y^2,y^3)。根据前面的介绍, y ^ 1 \hat{y}_1 y^1、 y ^ 2 \hat{y}_2 y^2 和 y ^ 3 \hat{y}_3 y^3 为 ( 0 , 1 ) (0,1) (0,1) 范围之间的一个概率值。由于样本 x x x 属于第一个类别,因此希望神经网络所预测得到的 y ^ 1 \hat{y}_1 y^1取值要远远大于 y ^ 2 \hat{y}_2 y^2 和 y ^ 3 \hat{y}_3 y^3 的取值。为了得到这样的神经网络,在训练中可利用如下交叉熵损失函数来对模型参数进行优化:
c r o s s e n t r o p y = − ( y 1 × log ⁡ ( y ^ 1 ) + y 2 × log ⁡ ( y ^ 2 ) + y 3 × log ⁡ ( y ^ 3 ) ) cross\ entropy=-\left( y_1\times \log \left( \hat{y}_1 \right) +y_2\times \log \left( \hat{y}_2 \right) +y_3\times \log \left( \hat{y}_3 \right) \right) cross entropy=−(y1×log(y^1)+y2×log(y^2)+y3×log(y^3))

在上式中, y 2 y_2 y2 和 y 3 y_3 y3 均为 0 0 0、 y 1 y_1 y1 为 1 1 1,因此交叉熵损失函数简化为:
− y 1 × log ⁡ ( y ^ 1 ) = − log ⁡ ( y ^ 1 ) -y_1\times \log \left( \hat{y}_1 \right) =-\log \left( \hat{y}_1 \right) −y1×log(y^1)=−log(y^1)

在神经网络训练中,要将输入数据实际的类别概率分布与模型预测的类别概率分布之间的误差(即损失)从输出端向输入端传递,以便来优化模型参数。下面简单介绍根据交叉熵计算得到的误差从 y ^ 1 \hat{y}_1 y^1 传递给 z 1 z_1 z1 和 z 2 z_2 z2( z 3 z_3 z3 的推导与 z 2 z_2 z2 相同)的情况。

∂ y ^ 1 ∂ z 1 = ∂ ( e z 1 ∑ k e z k ) ∂ z 1 = ( e z 1 ) ′ × ∑ k e z k − e z 1 × e z 1 ( ∑ k e z k ) 2 = e z 1 ∑ k e z k − e z 1 ∑ k e z k × e z 1 ∑ k e z k = y ^ 1 ( 1 − y ^ 1 ) \frac{\partial \hat{y}_1}{\partial z_1}=\frac{\partial \left( \frac{e^{z_1}}{\sum_k{e^{z_k}}} \right)}{\partial z_1}=\frac{\left( e^{z_1} \right) ^{'}\times \sum_k{e^{z_k}-e^{z_1}\times e^{z_1}}}{\left( \sum_k{e^{z_k}} \right) ^2}=\frac{e^{z_1}}{\sum_k{e^{z_k}}}-\frac{e^{z_1}}{\sum_k{e^{z_k}}}\times \frac{e^{z_1}}{\sum_k{e^{z_k}}}=\hat{y}_1\left( 1-\hat{y}_1 \right) ∂z1∂y^1=∂z1∂(∑kezkez1)=(∑kezk)2(ez1)′×∑kezk−ez1×ez1=∑kezkez1−∑kezkez1×∑kezkez1=y^1(1−y^1)

由于交叉熵损失函数 − log ⁡ ( y ^ 1 ) -\log \left( \hat{y}_1 \right) −log(y^1) 对 y ^ 1 \hat{y}_1 y^1 求导的结果为 − 1 y ^ 1 -\frac{1}{\hat{y}_1} −y^11, y ^ 1 ( 1 − y ^ 1 ) \hat{y}_1\left( 1-\hat{y}_1 \right) y^1(1−y^1) 与 − 1 y ^ 1 -\frac{1}{\hat{y}_1} −y^11 相乘为 y ^ 1 − 1 \hat{y}_1-1 y^1−1。这说明一旦得到模型预测输出 y ^ 1 \hat{y}_1 y^1,将该输出减去1就是交叉损失函数相对于 z 1 z_1 z1 的偏导结果。

∂ y ^ 1 ∂ z 2 = ∂ ( e z 1 ∑ k e z k ) ∂ z 2 = 0 × ∑ k e z k − e z 1 × e z 2 ( ∑ k e z k ) 2 = − e z 1 ∑ k e z k × e z 2 ∑ k e z k = − y ^ 1 y ^ 2 \frac{\partial \hat{y}_1}{\partial z_2}=\frac{\partial \left( \frac{e^{z_1}}{\sum_k{e^{z_k}}} \right)}{\partial z_2}=\frac{0\times \sum_k{e^{z_k}-e^{z_1}\times e^{z_2}}}{\left( \sum_k{e^{z_k}} \right) ^2}=-\frac{e^{z_1}}{\sum_k{e^{z_k}}}\times \frac{e^{z_2}}{\sum_k{e^{z_k}}}=-\hat{y}_1\hat{y}_2 ∂z2∂y^1=∂z2∂(∑kezkez1)=(∑kezk)20×∑kezk−ez1×ez2=−∑kezkez1×∑kezkez2=−y^1y^2

同理,交叉熵损失函数导数为 − 1 y ^ 1 -\frac{1}{\hat{y}_1} −y^11, − y ^ 1 y ^ 2 -\hat{y}_1\hat{y}_2 −y^1y^2 与 − 1 y ^ 1 -\frac{1}{\hat{y}_1} −y^11 相乘结果为 y ^ 2 \hat{y}_2 y^2。这意味对于除第一个输出节点以外的节点进行偏导,在得到模型预测输出后,只要将其保存,就是交叉损失函数相对于其他节点的偏导结果。在 z 1 z_1 z1、 z 2 z_2 z2 和 z 3 z_3 z3得到偏导结果后,再通过链式法则(后续介绍)将损失误差继续往输入端传递即可。

在上面的例子中,假设所预测中间值 ( z 1 , z 2 , z 3 ) (z_1,z_2,z_3) (z1,z2,z3) 经过 S o f t m a x Softmax Softmax 映射后所得结果为 ( 0.34 , 0.46 , 0.20 ) (0.34,0.46,0.20) (0.34,0.46,0.20)。由于已知输入数据 x x x 属于第一类,显然这个输出不理想而需要对模型参数进行优化。如果选择交叉熵损失函数来优化模型,则 ( z 1 , z 2 , z 3 ) (z_1,z_2,z_3) (z1,z2,z3) 这一层的偏导值为 ( 0.34 − 1 , 0.46 , 0.20 ) = ( − 0.66 , 0.46 , 0.20 ) (0.34-1,0.46,0.20)= (-0.66,0.46,0.20) (0.34−1,0.46,0.20)=(−0.66,0.46,0.20)。

可以看出, S o f t m a x Softmax Softmax 和交叉熵损失函数相互结合,为偏导计算带来了极大便利。偏导计算使得损失误差从输出端向输入端传递,来对模型参数进行优化。在这里,交叉熵与 S o f t m a x Softmax Softmax 函数结合在一起,因此也叫 S o f t m a x Softmax Softmax 损失(Softmax with cross-entropy loss)。

相关推荐
算家云1 分钟前
TangoFlux 本地部署实用教程:开启无限音频创意脑洞
人工智能·aigc·模型搭建·算家云、·应用社区·tangoflux
AI街潜水的八角1 小时前
工业缺陷检测实战——基于深度学习YOLOv10神经网络PCB缺陷检测系统
pytorch·深度学习·yolo
叫我:松哥2 小时前
基于Python django的音乐用户偏好分析及可视化系统设计与实现
人工智能·后端·python·mysql·数据分析·django
熊文豪2 小时前
深入解析人工智能中的协同过滤算法及其在推荐系统中的应用与优化
人工智能·算法
Vol火山2 小时前
AI引领工业制造智能化革命:机器视觉与时序数据预测的双重驱动
人工智能·制造
tuan_zhang3 小时前
第17章 安全培训筑牢梦想根基
人工智能·安全·工业软件·太空探索·战略欺骗·算法攻坚
Antonio9154 小时前
【opencv】第10章 角点检测
人工智能·opencv·计算机视觉
互联网资讯4 小时前
详解共享WiFi小程序怎么弄!
大数据·运维·网络·人工智能·小程序·生活
helianying554 小时前
AI赋能零售:ScriptEcho如何提升效率,优化用户体验
前端·人工智能·ux·零售
积鼎科技-多相流在线5 小时前
探索国产多相流仿真技术应用,积鼎科技助力石油化工工程数字化交付
人工智能·科技·cfd·流体仿真·多相流·virtualflow