支持向量机

支持向量机(Support Vector Machine,SVM)是一种经典的监督学习算法,在分类问题中表现出色。它通过寻找最优超平面来实现对不同类别样本的划分,具有良好的泛化能力和对高维数据的处理能力。本文将详细介绍 SVM 的原理,并结合代码实现展示其应用。

一、SVM 基本原理

1. 核心思想

SVM 的基本需求是在样本空间中找到一个划分超平面,将不同类别的样本分开。理想的超平面应具备对训练样本局部扰动的 "容忍性" 最好这一特性,而实现这一目标的关键是最大化间隔(margin)

间隔指的是超平面到最近的样本点之间的距离的两倍。那些距离超平面最近的样本点被称为支持向量,它们决定了超平面的位置和方向。

2. 超平面与距离计算

  • 超平面定义 :在 n 维空间中,超平面是 n-1 维的子空间,由方程表示,其中w是 n 维向量,b是实数。例如,三维空间中的超平面是二维平面,二维空间中的超平面是一维直线。
  • 点到超平面的距离 :对于 n 维空间中的点x,到超平面的距离为

3. 优化目标

SVM 的优化目标是最大化间隔,也就是最大化支持向量到超平面的距离。通过一系列推导,可将其转化为求解带约束的极小值问题:

  • 目标函数:
  • 约束条件:(其中y_i为样本标签,Φ(xi)是对数据的变换)

该问题可通过拉格朗日乘子法求解,转化为对偶问题后进行计算。

二、SVM 代码实现

下面使用 Python 的 scikit-learn 库实现 SVM 分类,并分别展示线性核和高斯核的效果。

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score

# 1. 准备数据
# 使用鸢尾花数据集(简化为二分类问题)
iris = datasets.load_iris()
X = iris.data[:100, :2]  # 取前100个样本,前2个特征
y = iris.target[:100]    # 标签

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 2. 构建SVM模型
# 线性核SVM
svm_linear = SVC(kernel='linear', C=1.0)
# 高斯核SVM
svm_rbf = SVC(kernel='rbf', C=1.0, gamma='scale')

# 3. 训练模型
svm_linear.fit(X_train, y_train)
svm_rbf.fit(X_train, y_train)

# 4. 预测与评估
y_pred_linear = svm_linear.predict(X_test)
y_pred_rbf = svm_rbf.predict(X_test)

accuracy_linear = accuracy_score(y_test, y_pred_linear)
accuracy_rbf = accuracy_score(y_test, y_pred_rbf)

print(f"线性核SVM准确率:{accuracy_linear:.4f}")
print(f"高斯核SVM准确率:{accuracy_rbf:.4f}")

# 5. 可视化决策边界
def plot_decision_boundary(model, X, y, title):
    h = 0.02  # 网格步长
    x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
    
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()])
    Z = Z.reshape(xx.shape)
    
    plt.contourf(xx, yy, Z, alpha=0.8)
    plt.scatter(X[:, 0], X[:, 1], c=y, edgecolors='k')
    plt.title(title)
    plt.xlabel('特征1')
    plt.ylabel('特征2')
    plt.show()

plot_decision_boundary(svm_linear, X, y, '线性核SVM决策边界')
plot_decision_boundary(svm_rbf, X, y, '高斯核SVM决策边界')

三、代码解析

  1. 数据准备:使用鸢尾花数据集的前 100 个样本(两类),并选取前两个特征以便可视化,然后将数据划分为训练集和测试集。
  2. 模型构建 :分别创建了线性核和高斯核的 SVM 模型。C为惩罚参数,gamma是高斯核的参数,影响核函数的宽度。
  3. 训练与评估:用训练集训练模型,再用测试集评估模型性能,通过准确率来衡量。
  4. 决策边界可视化:通过生成网格点,预测每个网格点的类别,绘制出决策边界,直观展示 SVM 的分类效果。

运行上述代码,可以看到不同核函数的 SVM 在鸢尾花数据集上的分类表现。一般来说,对于线性可分或近似线性可分的数据,线性核即可取得较好效果;对于非线性数据,高斯核等非线性核函数往往能表现更优。