机器学习实战(6):支持向量机(SVM)——强大的非线性分类器

第6集:支持向量机(SVM)------强大的非线性分类器

支持向量机(Support Vector Machine, SVM) 是一种功能强大的机器学习算法,广泛应用于分类和回归任务。它通过寻找一个最优的超平面来划分数据,并借助核函数实现非线性分类。今天我们将深入探讨 SVM 的基本原理,并通过实践部分使用 手写数字数据集(Digits) 进行分类任务。


SVM的基本思想

什么是支持向量机?

SVM 的目标是找到一个能够将不同类别样本分开的超平面,使得两类样本之间的间隔(margin)最大化。这些最靠近超平面的样本点被称为 支持向量,它们决定了超平面的位置。

图1:SVM 超平面与支持向量

(图片描述:二维平面上展示了两类样本点,一条直线将两类分开,最近的几个点被标记为支持向量,虚线表示间隔边界。)

  • 硬间隔 SVM:要求所有样本点都严格满足分类条件。
  • 软间隔 SVM:允许少量样本点违反分类条件,以提高模型的泛化能力。

核函数的作用

当数据无法通过线性超平面分离时,SVM 可以借助 核函数 将数据映射到高维空间,从而实现非线性分类。

常见核函数

  1. 线性核(Linear Kernel)

    • 公式: K ( x i , x j ) = x i ⋅ x j K(x_i, x_j) = x_i \cdot x_j K(xi,xj)=xi⋅xj
    • 适用于线性可分的数据。
  2. 多项式核(Polynomial Kernel)

    • 公式: K ( x i , x j ) = ( x i ⋅ x j + c ) d K(x_i, x_j) = (x_i \cdot x_j + c)^d K(xi,xj)=(xi⋅xj+c)d
      d 是多项式的阶数,适用于复杂但低维的数据。 d 是多项式的阶数,适用于复杂但低维的数据。 d是多项式的阶数,适用于复杂但低维的数据。
  3. RBF 核(Radial Basis Function Kernel)

    • 公式: K ( x i , x j ) = exp ⁡ ( − γ ∣ ∣ x i − x j ∣ ∣ 2 ) K(x_i, x_j) = \exp(-\gamma ||x_i - x_j||^2) K(xi,xj)=exp(−γ∣∣xi−xj∣∣2)
    • 最常用的核函数,适用于非线性数据。
  4. Sigmoid 核

    • 公式: K ( x i , x j ) = tanh ⁡ ( α x i ⋅ x j + c ) K(x_i, x_j) = \tanh(\alpha x_i \cdot x_j + c) K(xi,xj)=tanh(αxi⋅xj+c)
    • 类似于神经网络中的激活函数。

图2:核函数的效果对比

(图片描述:左侧为原始数据分布,右侧为经过 RBF 核映射后的高维空间,数据变得线性可分。)


SVM 在高维空间中的表现

SVM 的强大之处在于其能够在高维空间中找到复杂的决策边界。通过核函数,SVM 可以处理非线性问题,同时避免了显式计算高维特征的开销(通过核技巧实现)。


参数选择与正则化

SVM 的性能受以下参数影响:

  1. C:正则化参数,控制模型对误分类样本的容忍程度。

    • 较小的 C :更大的间隔,更高的偏差,更低的方差。 较小的 C :更大的间隔,更高的偏差,更低的方差。 较小的C:更大的间隔,更高的偏差,更低的方差。
    • 较大的 C :更小的间隔,更低的偏差,更高的方差。 较大的 C :更小的间隔,更低的偏差,更高的方差。 较大的C:更小的间隔,更低的偏差,更高的方差。
  2. gamma(仅适用于 RBF 核):

    • 控制每个支持向量的影响范围。
    • 较小的 γ :影响范围大,模型更平滑。 较小的 \gamma :影响范围大,模型更平滑。 较小的γ:影响范围大,模型更平滑。
    • 较大的 γ :影响范围小,模型更复杂。 较大的 \gamma :影响范围小,模型更复杂。 较大的γ:影响范围小,模型更复杂。
  3. 核函数选择:根据数据特性选择合适的核函数。


实践部分:使用 SVM 对手写数字数据集进行分类

数据集简介

我们使用 Scikit-learn 提供的 Digits 数据集,包含 1797 张 8x8 像素的手写数字图像(0-9)。目标是对手写数字进行分类。

完整代码

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.svm import SVC
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from sklearn.preprocessing import StandardScaler

# 加载数据
digits = datasets.load_digits()
X = digits.data  # 特征矩阵(每张图片展平为 64 维向量)
y = digits.target  # 标签(0-9)

# 数据标准化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

# 分割数据集
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.3, random_state=42)

# 构建 SVM 模型
model = SVC()

# 超参数调优
param_grid = {
    'C': [0.1, 1, 10],
    'gamma': [0.01, 0.1, 1],
    'kernel': ['linear', 'rbf']
}
grid_search = GridSearchCV(model, param_grid, cv=3, scoring='accuracy')
grid_search.fit(X_train, y_train)

# 最佳模型
best_model = grid_search.best_estimator_

# 预测
y_pred = best_model.predict(X_test)

# 评估模型性能
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
class_report = classification_report(y_test, y_pred)

print("最佳超参数:", grid_search.best_params_)
print(f"Accuracy: {accuracy:.2f}")
print("Confusion Matrix:")
print(conf_matrix)
print("Classification Report:")
print(class_report)

# 可视化部分测试结果
fig, axes = plt.subplots(2, 5, figsize=(10, 5))
for i, ax in enumerate(axes.ravel()):
    ax.imshow(X_test[i].reshape(8, 8), cmap='gray')
    ax.set_title(f"True: {y_test[i]}\nPred: {y_pred[i]}")
    ax.axis('off')
plt.tight_layout()
plt.show()

运行结果

输出结果:
最佳超参数: {'C': 10, 'gamma': 0.01, 'kernel': 'rbf'}
Accuracy: 0.98
Confusion Matrix:
[[53  0  0  0  0  0  0  0  0  0]
 [ 0 50  0  0  0  0  0  0  0  0]
 [ 0  0 47  0  0  0  0  0  0  0]
 [ 0  0  1 52  0  1  0  0  0  0]
 [ 0  0  0  0 60  0  0  0  0  0]
 [ 0  0  0  0  0 65  0  0  0  1]
 [ 0  0  0  0  0  0 53  0  0  0]
 [ 0  0  0  0  0  0  0 54  0  1]
 [ 0  0  1  1  0  0  0  0 41  0]
 [ 0  0  0  0  0  0  1  0  2 56]]
Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        53
           1       1.00      1.00      1.00        50
           2       0.96      1.00      0.98        47
           3       0.98      0.96      0.97        54
           4       1.00      1.00      1.00        60
           5       0.98      0.98      0.98        66
           6       0.98      1.00      0.99        53
           7       1.00      0.98      0.99        55
           8       0.95      0.95      0.95        43
           9       0.97      0.95      0.96        59

    accuracy                           0.98       540
   macro avg       0.98      0.98      0.98       540
weighted avg       0.98      0.98      0.98       540

图3:手写数字分类结果

(图片描述:10 张手写数字图片及其真实标签和预测标签,显示模型的分类效果。)


总结

本文介绍了 SVM 的基本原理、核函数的作用以及如何通过超参数调优提升模型性能。通过实践部分,我们成功使用 SVM 对手写数字数据集进行了分类任务,并取得了较高的准确率。

下期预告:第7集:聚类算法------发现数据中的隐藏模式*


参考资料

相关推荐
查理零世20 分钟前
【蓝桥杯集训·每日一题2025】 AcWing 6134. 哞叫时间II python
python·算法·蓝桥杯
悠然的笔记本21 分钟前
机器学习,我们主要学习什么?
机器学习
仟濹21 分钟前
【二分搜索 C/C++】洛谷 P1873 EKO / 砍树
c语言·c++·算法
紫雾凌寒29 分钟前
解锁机器学习核心算法|神经网络:AI 领域的 “超级引擎”
人工智能·python·神经网络·算法·机器学习·卷积神经网络
京东零售技术1 小时前
AI Agent实战:打造京东广告主的超级助手 | 京东零售技术实践
算法
无极工作室(网络安全)1 小时前
机器学习小项目之鸢尾花分类
人工智能·机器学习·分类
MiyamiKK572 小时前
leetcode_位运算 190.颠倒二进制位
python·算法·leetcode
C137的本贾尼2 小时前
解决 LeetCode 串联所有单词的子串问题
算法·leetcode·c#
青橘MATLAB学习2 小时前
时间序列预测实战:指数平滑法详解与MATLAB实现
人工智能·算法·机器学习·matlab
lingllllove2 小时前
matlab二维艾里光束,阵列艾里光束,可改变光束直径以及距离
开发语言·算法·matlab