支持向量机(SVM)入门:超平面与核函数的通俗解释

一、引言

在机器学习的众多算法中,支持向量机(Support Vector Machine,SVM)是一种强大且广泛应用的分类和回归算法。它由 Vladimir Vapnik 和 Alexey Chervonenkis 在 20 世纪 60 年代末提出,经过多年的发展和改进,已经成为机器学习领域的经典算法之一。SVM 在处理小样本、高维度数据时表现出色,尤其在文本分类、图像识别、生物信息学等领域有着广泛的应用。

本文将以通俗易懂的方式,深入讲解支持向量机中的超平面和核函数这两个核心概念,帮助读者更好地理解 SVM 的工作原理。

二、线性可分问题与超平面

2.1 线性可分问题的引入

在机器学习的分类问题中,我们常常会遇到线性可分的数据。例如,有两类数据点,分别用红色和蓝色表示,我们希望找到一条直线(在二维空间中)或者一个平面(在三维空间中),能够将这两类数据完美地分开。这条直线或平面就被称为超平面。

假设我们有一个二维数据集,包含两类样本点,我们的目标是找到一个最优的超平面,使得两类样本点能够被清晰地分开,并且这个超平面到两类样本点的距离尽可能大。

2.2 超平面的数学定义

在二维空间中,超平面是一条直线,其方程可以表示为:

w_1x_1 + w_2x_2 + b = 0

其中,(w = [w_1, w_2]^T) 是超平面的法向量,(b) 是偏置项,((x_1, x_2)) 是二维空间中的点。

在更一般的 (n) 维空间中,超平面的方程可以表示为:

w\^T x + b = 0

其中,(w) 是 (n) 维的法向量,(x) 是 (n) 维空间中的点。

2.3 超平面的几何意义

超平面将 (n) 维空间划分为两个部分。对于空间中的任意一点 (x),如果 (w^T x + b > 0),则该点位于超平面的一侧;如果 (w^T x + b < 0),则该点位于超平面的另一侧。

我们可以通过一个简单的 Python 代码来可视化二维空间中的超平面:

python 复制代码
import numpy as np
import matplotlib.pyplot as plt

# 生成一些线性可分的数据
np.random.seed(0)
X1 = np.random.randn(20, 2) + [2, 2]
X2 = np.random.randn(20, 2) + [-2, -2]
X = np.vstack((X1, X2))
y = np.hstack((np.ones(20), -np.ones(20)))

# 定义超平面的参数
w = np.array([1, 1])
b = 0

# 绘制数据点
plt.scatter(X1[:, 0], X1[:, 1], c='r', label='Class 1')
plt.scatter(X2[:, 0], X2[:, 1], c='b', label='Class -1')

# 绘制超平面
x1 = np.linspace(-5, 5, 100)
x2 = -(w[0] * x1 + b) / w[1]
plt.plot(x1, x2, 'g-', label='Hyperplane')

plt.xlabel('x1')
plt.ylabel('x2')
plt.title('Linear Separable Data and Hyperplane')
plt.legend()
plt.show()

三、支持向量与间隔最大化

3.1 支持向量的概念

在寻找最优超平面的过程中,我们发现有一些样本点距离超平面最近,这些样本点被称为支持向量。支持向量决定了超平面的位置和方向,因为它们是最难以分类的样本点。

3.2 间隔的定义

间隔是指超平面到最近的支持向量的距离。我们希望找到一个超平面,使得这个间隔最大。间隔最大化的超平面能够提高模型的泛化能力,因为它能够在一定程度上抵御噪声和异常点的影响。

间隔的计算公式为:

\\gamma = \\frac{2}{\|w\|}

其中,(|w|) 是法向量 (w) 的模。

3.3 间隔最大化的优化问题

我们的目标是找到一个超平面 (w^T x + b = 0),使得间隔 (\gamma) 最大。这可以转化为一个优化问题:

\\begin{aligned} \\max_{w, b} \&\\quad \\frac{2}{\|w\|} \\ \\text{s.t.} \&\\quad y_i(w\^T x_i + b) \\geq 1, \\quad i = 1, 2, \\cdots, N \\end{aligned}

其中,(y_i) 是样本点 (x_i) 的标签,(N) 是样本的数量。

为了方便计算,我们通常将上述优化问题转化为其对偶问题:

\\begin{aligned} \\max_{\\alpha} \&\\quad \\sum_{i=1}\^{N} \\alpha_i - \\frac{1}{2} \\sum_{i=1}\^{N} \\sum_{j=1}\^{N} \\alpha_i \\alpha_j y_i y_j x_i\^T x_j \\ \\text{s.t.} \&\\quad \\sum_{i=1}\^{N} \\alpha_i y_i = 0 \\ \&\\quad 0 \\leq \\alpha_i \\leq C, \\quad i = 1, 2, \\cdots, N \\end{aligned}

其中,(\alpha_i) 是拉格朗日乘子,(C) 是惩罚参数,用于控制误分类的惩罚程度。

四、核函数的引入

4.1 线性不可分问题

在实际应用中,我们常常会遇到线性不可分的数据,即无法找到一个超平面将两类数据完美地分开。例如,在二维空间中,两类数据点可能呈现出非线性的分布。

4.2 核函数的作用

核函数的作用是将原始数据映射到一个更高维的特征空间,使得在这个高维空间中数据变得线性可分。核函数的巧妙之处在于,它不需要显式地计算数据在高维空间中的映射,而是直接计算高维空间中数据点的内积。

4.3 常见的核函数

常见的核函数有以下几种:

核函数名称 数学表达式 特点
线性核函数 (K(x_i, x_j) = x_i^T x_j) 适用于线性可分的数据,计算速度快
多项式核函数 (K(x_i, x_j) = (\gamma x_i^T x_j + r)^d) 可以处理非线性数据,(\gamma)、(r) 和 (d) 是参数
高斯核函数(径向基函数,RBF) (K(x_i, x_j) = \exp(-\gamma |x_i - x_j|^2)) 具有很强的非线性映射能力,是最常用的核函数之一
Sigmoid 核函数 (K(x_i, x_j) = \tanh(\gamma x_i^T x_j + r)) 可以处理非线性数据,但容易出现梯度消失的问题
4.4 核函数在 SVM 中的应用

在引入核函数后,SVM 的对偶问题可以改写为:

\\begin{aligned} \\max_{\\alpha} \&\\quad \\sum_{i=1}\^{N} \\alpha_i - \\frac{1}{2} \\sum_{i=1}\^{N} \\sum_{j=1}\^{N} \\alpha_i \\alpha_j y_i y_j K(x_i, x_j) \\ \\text{s.t.} \&\\quad \\sum_{i=1}\^{N} \\alpha_i y_i = 0 \\ \&\\quad 0 \\leq \\alpha_i \\leq C, \\quad i = 1, 2, \\cdots, N \\end{aligned}

五、使用 Python 实现 SVM

5.1 使用 Scikit-learn 库实现 SVM

Scikit-learn 是一个强大的 Python 机器学习库,提供了 SVM 的实现。以下是一个使用 SVM 进行分类的示例代码:

python 复制代码
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score

# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data[:, :2]  # 只取前两个特征
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建 SVM 分类器
clf = SVC(kernel='linear')

# 训练模型
clf.fit(X_train, y_train)

# 预测测试集
y_pred = clf.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"Accuracy: {accuracy}")
5.2 不同核函数的效果比较

我们可以使用不同的核函数来训练 SVM 模型,并比较它们的分类效果。以下是一个比较线性核函数和高斯核函数的示例代码:

python 复制代码
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score

# 加载鸢尾花数据集
iris = datasets.load_iris()
X = iris.data[:, :2]  # 只取前两个特征
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 创建线性核 SVM 分类器
clf_linear = SVC(kernel='linear')
clf_linear.fit(X_train, y_train)
y_pred_linear = clf_linear.predict(X_test)
accuracy_linear = accuracy_score(y_test, y_pred_linear)

# 创建高斯核 SVM 分类器
clf_rbf = SVC(kernel='rbf')
clf_rbf.fit(X_train, y_train)
y_pred_rbf = clf_rbf.predict(X_test)
accuracy_rbf = accuracy_score(y_test, y_pred_rbf)

print(f"Linear kernel accuracy: {accuracy_linear}")
print(f"RBF kernel accuracy: {accuracy_rbf}")

# 绘制决策边界
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
ax = plt.gca()
xlim = ax.get_xlim()
ylim = ax.get_ylim()

# 创建网格点
xx = np.linspace(xlim[0], xlim[1], 30)
yy = np.linspace(ylim[0], ylim[1], 30)
YY, XX = np.meshgrid(yy, xx)
xy = np.vstack([XX.ravel(), YY.ravel()]).T
Z = clf_linear.decision_function(xy).reshape(XX.shape)

# 绘制决策边界和间隔
ax.contour(XX, YY, Z, colors='k', levels=[-1, 0, 1], alpha=0.5, linestyles=['--', '-', '--'])
plt.title('Linear Kernel')

plt.subplot(1, 2, 2)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Paired)
ax = plt.gca()
xlim = ax.get_xlim()
ylim = ax.get_ylim()

# 创建网格点
xx = np.linspace(xlim[0], xlim[1], 30)
yy = np.linspace(ylim[0], ylim[1], 30)
YY, XX = np.meshgrid(yy, xx)
xy = np.vstack([XX.ravel(), YY.ravel()]).T
Z = clf_rbf.decision_function(xy).reshape(XX.shape)

# 绘制决策边界和间隔
ax.contour(XX, YY, Z, colors='k', levels=[-1, 0, 1], alpha=0.5, linestyles=['--', '-', '--'])
plt.title('RBF Kernel')

plt.show()

六、总结

本文详细介绍了支持向量机(SVM)中的超平面和核函数这两个核心概念。超平面是 SVM 用于分类的基础,通过间隔最大化的方法可以找到最优的超平面。而核函数则是 SVM 处理非线性数据的关键,它能够将原始数据映射到高维空间,使得数据变得线性可分。通过 Python 代码示例,我们展示了如何使用 SVM 进行分类,并比较了不同核函数的效果。希望本文能够帮助读者更好地理解 SVM 的工作原理,为进一步学习和应用 SVM 打下基础。

相关推荐
im_AMBER2 小时前
Leetcode 102 反转链表
数据结构·c++·学习·算法·leetcode·链表
今儿敲了吗3 小时前
01|多项式输出
c++·笔记·算法
Xの哲學3 小时前
深入剖析Linux文件系统数据结构实现机制
linux·运维·网络·数据结构·算法
AlenTech3 小时前
200. 岛屿数量 - 力扣(LeetCode)
算法·leetcode·职场和发展
C雨后彩虹3 小时前
竖直四子棋
java·数据结构·算法·华为·面试
不如自挂东南吱4 小时前
空间相关性 和 怎么捕捉空间相关性
人工智能·深度学习·算法·机器学习·时序数据库
洛生&4 小时前
Elevator Rides
算法
2501_933513044 小时前
关于一种计数的讨论、ARC212C Solution
算法
Wu_Dylan4 小时前
智能体系列(二):规划(Planning):从 CoT、ToT 到动态采样与搜索
人工智能·算法