一、SVM基本原理SVM(Support Vector Machine)
SVM是机器学习中常用的分类算法。SVM算法可以将𝑛维空间分割到不同的类中,以便将新数据点放在正确的类中,得到的最佳决策边界称为超平面。在超平面一侧的数据标签为1,而在另一侧的数据标签为-1。考虑数据集 ,其中 。超平面的公式表示为 ,其中 和 。
SVM分为硬间隔SVM,软间隔SVM以及核SVM。下面以硬间隔SVM为例,解释SVM的目标函数是如何推导。硬间隔SVM基于数据线性可分(线性可分指的是存在一个超平面能够严格划分正负样本)的假设。SVM的原理是要找到一个超平面,这个超平面能够分离正负样本的同时,使得正样本和负样本到超平面的间隔是最大的。因此SVM又称为最大间隔分类器。
如何将"最大间隔分类器"转化为数学语言呢?首先我们要定义"间隔",所谓"间隔",指的是样本到达超平面的距离,那么就是整个样本集中离超平面最近的数据点到超平面的距离,因此可以写为 ,那么这个distance我们可以利用点到平面的距离公式来计算,因此 。那么我们要最大化这个"间隔",因此目标函数为 。另外除了最大间隔,还要是一个分类器,即当 时, ,当 时, ,因此约束条件可以写为 。因此我们将最大间隔分类器转化为数学语言为
等价于
取 (感知机算法提供了理论保证,这里是可以取到的)。因此上式转化为
取𝛾=1,且将目标函数变为最小化,便可以得到
为什么可以取𝛾=1,这里面要理解函数间隔和几何间隔的概念。实际上上式是在优化几何间隔,几何间隔是函数间隔的一个正则化表示,因此它的取值一般不影响优化问题的解,因此通常取1。
对于软间隔SVM分类器,面对的是线性不可分的问题,允许分类器一点错误。对于错误的样本要给予一定的惩罚,因此目标函数变为 ,这个loss便是对错分样本的惩罚函数,在SVM下,这里的loss是hinge损失,即 。因此软间隔SVM的优化目标为
其中C是惩罚系数。引入松弛变量 。因此上式转化为
这便是软间隔SVM的优化目标。对于核SVM而言,如果数据在低维空间中不可分,但是我们可以通过一个核函数𝜙将x映射到高维空间中,数据变为可分的。
优化目标与软间隔SVM基本相同,唯一的不同在于x要通过一个非线性的核函数的映射。因此核SVM 分类器的目标函数为
上述公式中,为了防止 SVM 分类器与噪声数据过度拟合,引入松弛变量 以允许某些数据点位于边距内,常数𝐶>0保证最大化边距与该边距内训练数据点数(从而确定训练误差)之间的权衡。𝜙是一个定义好的核函数,其作用是将数据投影到高维空间中,创建非线性决策边界,使得数据点在高维空间下可分。通常使用拉格朗日乘子法以及对偶化后,可以将上式转化为
最终的决策函数如下所示
其中 为拉格朗日乘数, 为核函数。常用的核函数包括如下几种
-
线性核: ;
-
多项式核: ;
-
高斯核: ;
-
拉普拉斯核: 。
二、KKT条件
这一部分我们介绍在SVM的原理中最重要的KKT条件,首先通过最优化理论可知,凸优化+Slater条件可以推出强对偶关系。
假设一个凸优化问题
Slater条件: (相对内部), 。
- 对于大多数凸优化问题,slater条件成立。
- 放松的slater:若𝑀中有𝐾个放射函数,则只需关心其余𝑀−𝐾个是否满足 。
下面我们写出上述凸优化问题的largrange函数
,于是对偶问题的优化目标是
。
上式凸优化问题的对偶问题为
KKT条件:
- 可行性条件
对于最优解 ,满足 . - 互补松弛条件(最重要的条件)
- 梯度为0
让我们考虑在硬间隔SVM下的KKT条件(硬间隔SVM指的是样本完全线性可分,不存在松弛变量 )。目标函数为
通过largrange乘子法,令 可以转化为
因此KKT条件为
注意第2条互补松弛条件,满足 的样本点称为支持向量 ,此时 。当 时, 。也就是说只有支持向量的样本点的largrange系数 。注意到 可知, ,说明 是数据的线性组合,只与支持向量有关系。
三、 SVM实现多分类
SVM 本身是一个二值分类器,但是我们可以很方便的把它扩展到多分类的情况,这样就可以很好的应用到文本分类,图像识别等场景中。扩展 SVM 支持多分类,一般有两种方法,OVR(one versus rest),一对多法;OVO(one versus one),一对一法。
OVR:对于 k 个类别的情况,训练 k 个SVM,第 j 个 SVM 用于判断任意一条数据是属于类别 j 还是非 j。这种方法需要针对 k 个分类训练 k 个分类器,分类的速度比较快,但是训练的速度较慢。当新增一个分类时,需要重新对分类进行构造。
OVO:对于 k 个类别的情况,训练 k * (k-1)/2 个 SVM,每一个 SVM 用来判断任意一条数据是属于 k 中的特定两个类别中的哪一个。对于一个未知的样本,每一个分类器都会有一个分类结果,记票为1,最终得票最多的类别就是未知样本的类别。这样当新增类别时,不需要重新构造 SVM 模型,训练速度快。但是当 k 较大时,训练和测试时间都会比较慢。
四、SVM的优缺点
优点
- 可用于线性/非线性分类,也可以用于回归,泛化错误率低,也就是说具有良好的学习能力,且学到的结果具有很好的推广性。
- 可以解决小样本情况下的机器学习问题,可以解决高维问题,可以避免神经网络结构选择和局部极小点问题。
缺点
- 支持向量机算法对大规模训练样本难以实施,这是因为支持向量算法借助二次规划求解支持向量,这其中会设计m阶矩阵的计算,所以矩阵阶数很大时将耗费大量的机器内存和运算时间。
- 现在常用的SVM理论都是使用固定惩罚系数C,但是正负样本的两种错误造成的损失是不一样的。
- 对于参数调节和核函数的选择很敏感。
五、 SVM的应用
这里使用sklearn自带的svm包,调用SVC函数用来做SVM的分类。具体的参数设置可以查看官方文档。
python
class sklearn.svm.SVC(self, C=1.0, kernel='rbf', degree=3, gamma='auto_deprecated',
coef0=0.0, shrinking=True, probability=False,
tol=1e-3, cache_size=200, class_weight=None,
verbose=False, max_iter=-1, decision_function_shape='ovr',
random_state=None)
- C:C-SVC的惩罚参数C,默认值是1.0。C越大,相当于惩罚松弛变量,希望松弛变量接近0,即对误分类的惩罚增大,趋向于对训练集全分对的情况,这样对训练集测试时准确率很高,但泛化能力弱。C值小,对误分类的惩罚减小,允许容错,将他们当成噪声点,泛化能力较强。
- kernel :核函数,默认是rbf,可以是'linear', 'poly', 'rbf', 'sigmoid', 'precomputed'
- degree :多项式poly函数的维度,默认是3,选择其他核函数时会被忽略。
- gamma : 'rbf','poly' 和'sigmoid'的核函数参数。默认是'auto',则会选择1/n_features
- coef0 :核函数的常数项。对于'poly'和 'sigmoid'有用。
- probability :是否输出概率,当评价指标为AUC时需要打开。
- shrinking :是否采用shrinking heuristic方法,默认为true
- tol :停止训练的误差值大小,默认为1e-3
- cache_size :核函数cache缓存大小,默认为200
- class_weight :类别的权重,字典形式传递。设置第几类的参数C为weight*C(C-SVC中的C)
- verbose :输出设置。
- max_iter :最大迭代次数。-1为无限制。
- decision_function_shape :控制多分类时的决策函数模式,包含'ovo', 'ovr' or None,
- random_state :随机数种子
python
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
## 导入手写数字数据集
X=datasets.load_digits()['data']
Y=datasets.load_digits()['target']
Y[Y>=5] = 10
Y[Y<5] = 0
Y[Y==10] = 1
X_train,X_test,y_train,y_test=train_test_split(X, Y, test_size=0.4,stratify=Y)
from sklearn.svm import SVC
svm_model = SVC(kernel = 'linear', probability=True)
svm_model.fit(X_train,y_train)
y_pred = svm_model.predict(X_test)
y_proba = svm_model.predict_proba(X_test)
from sklearn.metrics import classification_report,accuracy_score,roc_auc_score,f1_score,roc_curve
print("分类报告:\n", classification_report(y_test, y_pred))
print("准确率:\n",accuracy_score(y_test, y_pred))
print("AUC:\n",roc_auc_score(y_test, y_proba[:, 1]))
print("micro F1:\n",f1_score(y_test, y_pred, average = 'micro'))
print("macro F1:\n",f1_score(y_test, y_pred, average = 'macro'))
'''
分类报告:
precision recall f1-score support
0 0.88 0.90 0.89 361
1 0.90 0.88 0.89 358
accuracy 0.89 719
macro avg 0.89 0.89 0.89 719
weighted avg 0.89 0.89 0.89 719
准确率:
0.8901251738525731
AUC:
0.9466101301474799
micro F1:
0.8901251738525731
macro F1:
0.8901039157529781
'''
import matplotlib.pyplot as plt
auc = roc_auc_score(y_test,y_proba[:, 1])
fpr,tpr, thresholds = roc_curve(y_test,svm_model.decision_function(X_test))
plt.plot(fpr,tpr,color='darkorange',label='ROC curve (area = %0.2f)' % auc)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic')
plt.legend(loc="lower right")
plt.show()
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
## 导入手写数字数据集
X=datasets.load_digits()['data']
Y=datasets.load_digits()['target']
X_train,X_test,y_train,y_test=train_test_split(X, Y, test_size=0.4,stratify=Y)
from sklearn.svm import SVC
svm_model = SVC(kernel = 'rbf', decision_function_shape = 'ovr')
svm_model.fit(X_train,y_train)
y_pred = svm_model.predict(X_test)
from sklearn.metrics import classification_report,accuracy_score,roc_auc_score,f1_score
print("分类报告:\n", classification_report(y_test, y_pred))
print("准确率:\n",accuracy_score(y_test, y_pred))
print("micro F1:\n",f1_score(y_test, y_pred, average = 'micro'))
print("macro F1:\n",f1_score(y_test, y_pred, average = 'macro'))
'''
分类报告:
precision recall f1-score support
0 1.00 0.49 0.66 71
1 1.00 0.23 0.38 73
2 1.00 0.38 0.55 71
3 0.17 1.00 0.30 73
4 1.00 0.44 0.62 72
5 1.00 0.89 0.94 73
6 1.00 0.64 0.78 72
7 1.00 0.35 0.52 72
8 1.00 0.06 0.11 70
9 1.00 0.68 0.81 72
accuracy 0.52 719
macro avg 0.92 0.52 0.57 719
weighted avg 0.92 0.52 0.57 719
准确率:
0.5187760778859527
micro F1:
0.5187760778859527
macro F1:
0.5656487510758079
'''