深度学习笔记10-多分类

多分类和softmax回归

在多分类问题中,一个样本会被划分到三个或更多的类别中,可以使用多个二分类模型或一个多分类模型,这两种方式解决多分类问题。

1.基于二分类模型的多分类

直接基于二分类模型解决多分类任务,对于多分类中的每个目标类别都要训练一个二分类模型。在训练时,将需要识别出的类别的数据作为正例,其余数据为反例,这种训练方式被称为1-VS-Rest,也就是1对其余的策略。

如果有N个目标类别,训练N个二分类模型,N个模型互相独立,互不干扰,同一个数据每个模型都需要计算一遍。另外对于不同的类别,也可以使用不同的二分类模型进行训练。例如可以使用逻辑回归、SVM、决策树这三种模型应用在同一个分类系统中,来识别不同的类别。

基于1-VS-Rest策略的多分类,优势是可维护性高、随时可以增加新的类别模型,或者修改升级其中某个模型,都不会对其他已有模型产生影响。另外,分类结果是相互独立的,可以自由选择这些模型的组合方式,进而更有针对性的调试和优化。

2.softmax解决多分类问题

构建softmax回归模型同时对所有类别进行识别,在softmax回归中包括两步。步骤一:输入一个样本的特征向量,输出多个线性预测结果。步骤二:将这个结果输入到softmax函数,softmax函数会将多个线性输出转换为每个类别的概率。softmax回归会基于输入x,计算三个线性输出。可以将softmax回归看作是一个具有多个输出的单层神经网络。

三个目标类型: 四个输入特征:

基于矩阵,计算线性输出o:

o=Wx+b

通过softmax函数计算类别的概率

计算出线性输出o后,将o输入到softmax函数,从而将线性输出o转换为每个类别的预测概率y. 设有n个输出-,第k个输出是,它对应的类别概率是

由此通过softmax函数将所有线性输出都转换为0-1之间的实数:y1,y2,...yn[0,1],输出的总和y1+y2+...+yn=1.

softmax函数不会改变线性输出o之间的大小顺序,只会为每个类别分配相应的概率。它的优势在于模型简洁高效,只需要一次训练就可以同时识别所有类别的多分类模型。此外softmax回归也可以很好地处理类别之间的互斥问题,softmax函数可以确保预测结果总和为1。它也存在一些问题,在需要优化模型中的某个类别或者增加新的类别时,会影响到其他所有的类别,产生较高的评估与维护成本。

3.多分类中的交叉熵损失函数

交叉熵误差:评估模型输出的概率分布和真实概率分布的差异,它有两种形式分别对应二分类与多分类问题,

二分类问题:

多分类问题:

多分类问题中,如果每个类别之间的定义是互斥的,那么任何样本都只能被标记为一种类别。使用向量y表示样本的标记值,如果有n个类别,那么y就是一个n*1的列向量。向量中只有1个元素是1,其余元素都是0。多分类问题的交叉熵损失,只与真实类别对应的模型预测参数概率有关,因此第i个样本的误差为

m个样本、n个类别的交叉熵误差:

表示第i个样本第k个类别的真实标记、表示第k个类别的模型预测概率

如果某样本被标记为第2个类别,那么第二个元素标记为1,其余为0.

,

在交叉熵损失函数中,只有真实类别对应的那一项会被计算在内,其他类别的项在计算求和中均为0,因此,即便模型对其他类别的预测概率不准确,但只要对真实类别的预测概率较高,损失函数的值仍然较低。

4.softmax回归的数学原理

softmax回归也被称为多项的逻辑回归,它可以看作是逻辑回归在多分类问题上的推广。

类别个数:n

类别标签:y{0,1,...n}

x:样本特征向量

:第k个类别的权重

某样本属于类别k的概率:

softmax回归和逻辑回归的关系

逻辑回归中使用sigmoid函数,,将线性输出z转化为一个概率,这个概率表示样本属于正例的可能性。在softmax中使用softmax函数将输出值同样转化为概率,这个概率表示样本属于第k个类别的可能性。当类别数为2时,逻辑回归和softmax回归的输出时等价的。

softmax回归中:

类别为0的概率:

类别为1的概率:

将z0-z1看作一个整体-z,就得到逻辑回归的形式,所以在处理二分类问题时,softmax回归和逻辑回归实际上是完全等价的模型

梯度下降法求解softmax回归

softmax回归模型的代价函数即交叉熵损失函数:

求E关于第k个类别中第j个特征权重偏导数:

梯度下降算法:

最终迭代:

相关推荐
Light609 分钟前
数据模型全解:从架构之心到AI时代的智慧表达
人工智能·架构·数据模型·三层架构·数仓建模·ai辅助·业务翻译
链上日记3 小时前
WEEX出席迪拜区块链生活2025,担任白金赞助商
人工智能·区块链·生活
灵途科技6 小时前
灵途科技亮相NEPCON ASIA 2025 以光电感知点亮具身智能未来
人工智能·科技·机器人
文火冰糖的硅基工坊7 小时前
[人工智能-大模型-125]:模型层 - RNN的隐藏层是什么网络,全连接?还是卷积?RNN如何实现状态记忆?
人工智能·rnn·lstm
IT90907 小时前
c#+ visionpro汽车行业,机器视觉通用检测程序源码 产品尺寸检测,机械手引导定位等
人工智能·计算机视觉·视觉检测
Small___ming7 小时前
【人工智能数学基础】多元高斯分布
人工智能·机器学习·概率论
Ro Jace7 小时前
机器学习、深度学习、信号处理领域常用符号速查表
深度学习·机器学习·信号处理
渔舟渡简7 小时前
机器学习-回归分析概述
人工智能·机器学习
王哈哈^_^7 小时前
【数据集】【YOLO】目标检测游泳数据集 4481 张,溺水数据集,YOLO河道、海滩游泳识别算法实战训练教程。
人工智能·算法·yolo·目标检测·计算机视觉·分类·视觉检测
桂花饼7 小时前
Sora 2:从视频生成到世界模拟,OpenAI的“终极游戏”
人工智能·aigc·openai·sora 2