支持向量机

一、SVM基本原理SVM(Support Vector Machine)

SVM是机器学习中常用的分类算法。SVM算法可以将𝑛维空间分割到不同的类中,以便将新数据点放在正确的类中,得到的最佳决策边界称为超平面。在超平面一侧的数据标签为1,而在另一侧的数据标签为-1。考虑数据集 ,其中 。超平面的公式表示为 ,其中

SVM分为硬间隔SVM,软间隔SVM以及核SVM。下面以硬间隔SVM为例,解释SVM的目标函数是如何推导。硬间隔SVM基于数据线性可分(线性可分指的是存在一个超平面能够严格划分正负样本)的假设。SVM的原理是要找到一个超平面,这个超平面能够分离正负样本的同时,使得正样本和负样本到超平面的间隔是最大的。因此SVM又称为最大间隔分类器。

如何将"最大间隔分类器"转化为数学语言呢?首先我们要定义"间隔",所谓"间隔",指的是样本到达超平面的距离,那么就是整个样本集中离超平面最近的数据点到超平面的距离,因此可以写为 ,那么这个distance我们可以利用点到平面的距离公式来计算,因此 。那么我们要最大化这个"间隔",因此目标函数为 。另外除了最大间隔,还要是一个分类器,即当 时, ,当 时, ,因此约束条件可以写为 。因此我们将最大间隔分类器转化为数学语言为

等价于

(感知机算法提供了理论保证,这里是可以取到的)。因此上式转化为

取𝛾=1,且将目标函数变为最小化,便可以得到

为什么可以取𝛾=1,这里面要理解函数间隔和几何间隔的概念。实际上上式是在优化几何间隔,几何间隔是函数间隔的一个正则化表示,因此它的取值一般不影响优化问题的解,因此通常取1。

对于软间隔SVM分类器,面对的是线性不可分的问题,允许分类器一点错误。对于错误的样本要给予一定的惩罚,因此目标函数变为 ,这个loss便是对错分样本的惩罚函数,在SVM下,这里的loss是hinge损失,即 。因此软间隔SVM的优化目标为

其中C是惩罚系数。引入松弛变量 。因此上式转化为

这便是软间隔SVM的优化目标。对于核SVM而言,如果数据在低维空间中不可分,但是我们可以通过一个核函数𝜙将x映射到高维空间中,数据变为可分的。

优化目标与软间隔SVM基本相同,唯一的不同在于x要通过一个非线性的核函数的映射。因此核SVM 分类器的目标函数为

上述公式中,为了防止 SVM 分类器与噪声数据过度拟合,引入松弛变量 以允许某些数据点位于边距内,常数𝐶>0保证最大化边距与该边距内训练数据点数(从而确定训练误差)之间的权衡。𝜙是一个定义好的核函数,其作用是将数据投影到高维空间中,创建非线性决策边界,使得数据点在高维空间下可分。通常使用拉格朗日乘子法以及对偶化后,可以将上式转化为

最终的决策函数如下所示

其中 为拉格朗日乘数, 为核函数。常用的核函数包括如下几种

  • 线性核:

  • 多项式核:

  • 高斯核:

  • 拉普拉斯核:

二、KKT条件

这一部分我们介绍在SVM的原理中最重要的KKT条件,首先通过最优化理论可知,凸优化+Slater条件可以推出强对偶关系。

假设一个凸优化问题

Slater条件: (相对内部),

  • 对于大多数凸优化问题,slater条件成立。
  • 放松的slater:若𝑀中有𝐾个放射函数,则只需关心其余𝑀−𝐾个是否满足

下面我们写出上述凸优化问题的largrange函数

,于是对偶问题的优化目标是

上式凸优化问题的对偶问题为

KKT条件:

  1. 可行性条件
    对于最优解 ,满足 .
  2. 互补松弛条件(最重要的条件)
  3. 梯度为0

让我们考虑在硬间隔SVM下的KKT条件(硬间隔SVM指的是样本完全线性可分,不存在松弛变量 )。目标函数为

通过largrange乘子法,令 可以转化为

因此KKT条件为

注意第2条互补松弛条件,满足 的样本点称为支持向量 ,此时 。当 时, 。也就是说只有支持向量的样本点的largrange系数 。注意到 可知, ,说明 是数据的线性组合,只与支持向量有关系。

三、 SVM实现多分类

SVM 本身是一个二值分类器,但是我们可以很方便的把它扩展到多分类的情况,这样就可以很好的应用到文本分类,图像识别等场景中。扩展 SVM 支持多分类,一般有两种方法,OVR(one versus rest),一对多法;OVO(one versus one),一对一法。

OVR:对于 k 个类别的情况,训练 k 个SVM,第 j 个 SVM 用于判断任意一条数据是属于类别 j 还是非 j。这种方法需要针对 k 个分类训练 k 个分类器,分类的速度比较快,但是训练的速度较慢。当新增一个分类时,需要重新对分类进行构造。

OVO:对于 k 个类别的情况,训练 k * (k-1)/2 个 SVM,每一个 SVM 用来判断任意一条数据是属于 k 中的特定两个类别中的哪一个。对于一个未知的样本,每一个分类器都会有一个分类结果,记票为1,最终得票最多的类别就是未知样本的类别。这样当新增类别时,不需要重新构造 SVM 模型,训练速度快。但是当 k 较大时,训练和测试时间都会比较慢。

四、SVM的优缺点

优点

  • 可用于线性/非线性分类,也可以用于回归,泛化错误率低,也就是说具有良好的学习能力,且学到的结果具有很好的推广性。
  • 可以解决小样本情况下的机器学习问题,可以解决高维问题,可以避免神经网络结构选择和局部极小点问题。

缺点

  • 支持向量机算法对大规模训练样本难以实施,这是因为支持向量算法借助二次规划求解支持向量,这其中会设计m阶矩阵的计算,所以矩阵阶数很大时将耗费大量的机器内存和运算时间。
  • 现在常用的SVM理论都是使用固定惩罚系数C,但是正负样本的两种错误造成的损失是不一样的。
  • 对于参数调节和核函数的选择很敏感。

五、 SVM的应用

这里使用sklearn自带的svm包,调用SVC函数用来做SVM的分类。具体的参数设置可以查看官方文档

python 复制代码
class sklearn.svm.SVC(self, C=1.0, kernel='rbf', degree=3, gamma='auto_deprecated',    
             coef0=0.0, shrinking=True, probability=False,    
             tol=1e-3, cache_size=200, class_weight=None,    
             verbose=False, max_iter=-1, decision_function_shape='ovr',    
             random_state=None)
  • C:C-SVC的惩罚参数C,默认值是1.0。C越大,相当于惩罚松弛变量,希望松弛变量接近0,即对误分类的惩罚增大,趋向于对训练集全分对的情况,这样对训练集测试时准确率很高,但泛化能力弱。C值小,对误分类的惩罚减小,允许容错,将他们当成噪声点,泛化能力较强。
  • kernel :核函数,默认是rbf,可以是'linear', 'poly', 'rbf', 'sigmoid', 'precomputed'
  • degree :多项式poly函数的维度,默认是3,选择其他核函数时会被忽略。
  • gamma : 'rbf','poly' 和'sigmoid'的核函数参数。默认是'auto',则会选择1/n_features
  • coef0 :核函数的常数项。对于'poly'和 'sigmoid'有用。
  • probability :是否输出概率,当评价指标为AUC时需要打开。
  • shrinking :是否采用shrinking heuristic方法,默认为true
  • tol :停止训练的误差值大小,默认为1e-3
  • cache_size :核函数cache缓存大小,默认为200
  • class_weight :类别的权重,字典形式传递。设置第几类的参数C为weight*C(C-SVC中的C)
  • verbose :输出设置。
  • max_iter :最大迭代次数。-1为无限制。
  • decision_function_shape :控制多分类时的决策函数模式,包含'ovo', 'ovr' or None,
  • random_state :随机数种子
python 复制代码
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

## 导入手写数字数据集
X=datasets.load_digits()['data']
Y=datasets.load_digits()['target']

Y[Y>=5] = 10
Y[Y<5] = 0
Y[Y==10] = 1


X_train,X_test,y_train,y_test=train_test_split(X, Y, test_size=0.4,stratify=Y)


from sklearn.svm import SVC

svm_model = SVC(kernel = 'linear', probability=True)
svm_model.fit(X_train,y_train)

y_pred = svm_model.predict(X_test)
y_proba = svm_model.predict_proba(X_test)


from sklearn.metrics import classification_report,accuracy_score,roc_auc_score,f1_score,roc_curve

print("分类报告:\n", classification_report(y_test, y_pred))
print("准确率:\n",accuracy_score(y_test, y_pred))
print("AUC:\n",roc_auc_score(y_test, y_proba[:, 1]))
print("micro F1:\n",f1_score(y_test, y_pred, average = 'micro'))
print("macro F1:\n",f1_score(y_test, y_pred, average = 'macro'))

'''
分类报告:
               precision    recall  f1-score   support

           0       0.88      0.90      0.89       361
           1       0.90      0.88      0.89       358

    accuracy                           0.89       719
   macro avg       0.89      0.89      0.89       719
weighted avg       0.89      0.89      0.89       719

准确率:
 0.8901251738525731
AUC:
 0.9466101301474799
micro F1:
 0.8901251738525731
macro F1:
 0.8901039157529781
'''


import matplotlib.pyplot as plt

auc = roc_auc_score(y_test,y_proba[:, 1])

fpr,tpr, thresholds = roc_curve(y_test,svm_model.decision_function(X_test))
plt.plot(fpr,tpr,color='darkorange',label='ROC curve (area = %0.2f)' % auc)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic')
plt.legend(loc="lower right")
plt.show()


import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split

## 导入手写数字数据集
X=datasets.load_digits()['data']
Y=datasets.load_digits()['target']


X_train,X_test,y_train,y_test=train_test_split(X, Y, test_size=0.4,stratify=Y)


from sklearn.svm import SVC

svm_model = SVC(kernel = 'rbf', decision_function_shape = 'ovr')
svm_model.fit(X_train,y_train)

y_pred = svm_model.predict(X_test)


from sklearn.metrics import classification_report,accuracy_score,roc_auc_score,f1_score

print("分类报告:\n", classification_report(y_test, y_pred))
print("准确率:\n",accuracy_score(y_test, y_pred))
print("micro F1:\n",f1_score(y_test, y_pred, average = 'micro'))
print("macro F1:\n",f1_score(y_test, y_pred, average = 'macro'))

'''
分类报告:
               precision    recall  f1-score   support

           0       1.00      0.49      0.66        71
           1       1.00      0.23      0.38        73
           2       1.00      0.38      0.55        71
           3       0.17      1.00      0.30        73
           4       1.00      0.44      0.62        72
           5       1.00      0.89      0.94        73
           6       1.00      0.64      0.78        72
           7       1.00      0.35      0.52        72
           8       1.00      0.06      0.11        70
           9       1.00      0.68      0.81        72

    accuracy                           0.52       719
   macro avg       0.92      0.52      0.57       719
weighted avg       0.92      0.52      0.57       719

准确率:
 0.5187760778859527
micro F1:
 0.5187760778859527
macro F1:
 0.5656487510758079
'''
相关推荐
2301_79716471几秒前
GC1277和灿瑞的OCH477优势分析 可以用于电脑散热风扇,视频监控和图像处理的图像信号处理器中
图像处理·人工智能·单片机·嵌入式硬件·电脑·制造·充电枪
不打灰的小刘5 分钟前
Cursor AI编辑器:开发效率提升利器
人工智能·python·chatgpt·编辑器·个人开发
AnalogElectronic14 分钟前
tensorflow入门案例手写数字识别人工智能界的helloworld项目落地1
人工智能·python·tensorflow
OpenGVLab17 分钟前
MVBench多模态大模型视频理解能力基准 | CVPR Highlight
人工智能·深度学习·音视频
知来者逆26 分钟前
探索Ultralytics YOLO11在视觉任务上的应用
人工智能·深度学习·yolo·机器学习·计算机视觉·yolo11
weixin_457340211 小时前
图像分割恢复方法
人工智能·opencv·计算机视觉
萤火架构2 小时前
‌ComfyUI 高级实战:实现华为手机的AI消除功能
人工智能·comfyui·ai消除·涂抹消除·完美重绘
sp_fyf_20246 小时前
【大语言模型-论文精读】用于医疗领域摘要任务的大型语言模型评估综述
人工智能·神经网络·算法·计算机视觉·语言模型·自然语言处理·健康医疗
李元中7 小时前
2024下半年软考中级软件设计师,这100题,必做!
java·开发语言·javascript·人工智能·算法·ecmascript