第6集:支持向量机(SVM)------强大的非线性分类器
支持向量机(Support Vector Machine, SVM) 是一种功能强大的机器学习算法,广泛应用于分类和回归任务。它通过寻找一个最优的超平面来划分数据,并借助核函数实现非线性分类。今天我们将深入探讨 SVM 的基本原理,并通过实践部分使用 手写数字数据集(Digits) 进行分类任务。
SVM的基本思想
什么是支持向量机?
SVM 的目标是找到一个能够将不同类别样本分开的超平面,使得两类样本之间的间隔(margin)最大化。这些最靠近超平面的样本点被称为 支持向量,它们决定了超平面的位置。
图1:SVM 超平面与支持向量
(图片描述:二维平面上展示了两类样本点,一条直线将两类分开,最近的几个点被标记为支持向量,虚线表示间隔边界。)
- 硬间隔 SVM:要求所有样本点都严格满足分类条件。
- 软间隔 SVM:允许少量样本点违反分类条件,以提高模型的泛化能力。
核函数的作用
当数据无法通过线性超平面分离时,SVM 可以借助 核函数 将数据映射到高维空间,从而实现非线性分类。
常见核函数
-
线性核(Linear Kernel):
- 公式: K ( x i , x j ) = x i ⋅ x j K(x_i, x_j) = x_i \cdot x_j K(xi,xj)=xi⋅xj
- 适用于线性可分的数据。
-
多项式核(Polynomial Kernel):
- 公式: K ( x i , x j ) = ( x i ⋅ x j + c ) d K(x_i, x_j) = (x_i \cdot x_j + c)^d K(xi,xj)=(xi⋅xj+c)d
d 是多项式的阶数,适用于复杂但低维的数据。 d 是多项式的阶数,适用于复杂但低维的数据。 d是多项式的阶数,适用于复杂但低维的数据。
- 公式: K ( x i , x j ) = ( x i ⋅ x j + c ) d K(x_i, x_j) = (x_i \cdot x_j + c)^d K(xi,xj)=(xi⋅xj+c)d
-
RBF 核(Radial Basis Function Kernel):
- 公式: K ( x i , x j ) = exp ( − γ ∣ ∣ x i − x j ∣ ∣ 2 ) K(x_i, x_j) = \exp(-\gamma ||x_i - x_j||^2) K(xi,xj)=exp(−γ∣∣xi−xj∣∣2)
- 最常用的核函数,适用于非线性数据。
-
Sigmoid 核:
- 公式: K ( x i , x j ) = tanh ( α x i ⋅ x j + c ) K(x_i, x_j) = \tanh(\alpha x_i \cdot x_j + c) K(xi,xj)=tanh(αxi⋅xj+c)
- 类似于神经网络中的激活函数。
图2:核函数的效果对比
(图片描述:左侧为原始数据分布,右侧为经过 RBF 核映射后的高维空间,数据变得线性可分。)
SVM 在高维空间中的表现
SVM 的强大之处在于其能够在高维空间中找到复杂的决策边界。通过核函数,SVM 可以处理非线性问题,同时避免了显式计算高维特征的开销(通过核技巧实现)。
参数选择与正则化
SVM 的性能受以下参数影响:
-
C
:正则化参数,控制模型对误分类样本的容忍程度。- 较小的 C :更大的间隔,更高的偏差,更低的方差。 较小的 C :更大的间隔,更高的偏差,更低的方差。 较小的C:更大的间隔,更高的偏差,更低的方差。
- 较大的 C :更小的间隔,更低的偏差,更高的方差。 较大的 C :更小的间隔,更低的偏差,更高的方差。 较大的C:更小的间隔,更低的偏差,更高的方差。
-
gamma
(仅适用于 RBF 核):- 控制每个支持向量的影响范围。
- 较小的 γ :影响范围大,模型更平滑。 较小的 \gamma :影响范围大,模型更平滑。 较小的γ:影响范围大,模型更平滑。
- 较大的 γ :影响范围小,模型更复杂。 较大的 \gamma :影响范围小,模型更复杂。 较大的γ:影响范围小,模型更复杂。
-
核函数选择:根据数据特性选择合适的核函数。
实践部分:使用 SVM 对手写数字数据集进行分类
数据集简介
我们使用 Scikit-learn 提供的 Digits 数据集,包含 1797 张 8x8 像素的手写数字图像(0-9)。目标是对手写数字进行分类。
完整代码
python
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.preprocessing import StandardScaler
# 加载数据
digits = datasets.load_digits()
X = digits.data # 特征矩阵(每张图片展平为 64 维向量)
y = digits.target # 标签(0-9)
# 数据标准化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)
# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.3, random_state=42)
# 构建 SVM 模型
model = SVC()
# 超参数调优
param_grid = {
'C': [0.1, 1, 10],
'gamma': [0.01, 0.1, 1],
'kernel': ['linear', 'rbf']
}
grid_search = GridSearchCV(model, param_grid, cv=3, scoring='accuracy')
grid_search.fit(X_train, y_train)
# 最佳模型
best_model = grid_search.best_estimator_
# 预测
y_pred = best_model.predict(X_test)
# 评估模型性能
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)
print("最佳超参数:", grid_search.best_params_)
print(f"Accuracy: {accuracy:.2f}")
print("Confusion Matrix:")
print(conf_matrix)
print("Classification Report:")
print(class_report)
# 可视化部分测试结果
fig, axes = plt.subplots(2, 5, figsize=(10, 5))
for i, ax in enumerate(axes.ravel()):
ax.imshow(X_test[i].reshape(8, 8), cmap='gray')
ax.set_title(f"True: {y_test[i]}\nPred: {y_pred[i]}")
ax.axis('off')
plt.tight_layout()
plt.show()
运行结果
输出结果:
最佳超参数: {'C': 10, 'gamma': 0.01, 'kernel': 'rbf'}
Accuracy: 0.98
Confusion Matrix:
[[53 0 0 0 0 0 0 0 0 0]
[ 0 50 0 0 0 0 0 0 0 0]
[ 0 0 47 0 0 0 0 0 0 0]
[ 0 0 1 52 0 1 0 0 0 0]
[ 0 0 0 0 60 0 0 0 0 0]
[ 0 0 0 0 0 65 0 0 0 1]
[ 0 0 0 0 0 0 53 0 0 0]
[ 0 0 0 0 0 0 0 54 0 1]
[ 0 0 1 1 0 0 0 0 41 0]
[ 0 0 0 0 0 0 1 0 2 56]]
Classification Report:
precision recall f1-score support
0 1.00 1.00 1.00 53
1 1.00 1.00 1.00 50
2 0.96 1.00 0.98 47
3 0.98 0.96 0.97 54
4 1.00 1.00 1.00 60
5 0.98 0.98 0.98 66
6 0.98 1.00 0.99 53
7 1.00 0.98 0.99 55
8 0.95 0.95 0.95 43
9 0.97 0.95 0.96 59
accuracy 0.98 540
macro avg 0.98 0.98 0.98 540
weighted avg 0.98 0.98 0.98 540
图3:手写数字分类结果
(图片描述:10 张手写数字图片及其真实标签和预测标签,显示模型的分类效果。)
总结
本文介绍了 SVM 的基本原理、核函数的作用以及如何通过超参数调优提升模型性能。通过实践部分,我们成功使用 SVM 对手写数字数据集进行了分类任务,并取得了较高的准确率。