6.1 什么是支持向量机?
在机器学习中,我们常常遇到分类问题,比如:
- 垃圾邮件分类(邮件是垃圾邮件还是正常邮件?)
- 肿瘤检测(肿瘤是良性还是恶性?)
- 图像识别(一张图片是猫还是狗?)
支持向量机(Support Vector Machine, SVM)是一种非常强大的分类算法 ,它的核心思想是:
找到一条"最优的分割线"(决策边界),把不同类别的数据分开。
6.2 直观理解:画一条最佳分割线
6.2.1 你是学校的班主任
假设你是学校的一位班主任,你需要组织一次班级活动,你的班级有两种学生:
- "学霸"(成绩好)
- "学渣"(成绩一般)
你要把这两类学生分开,分别安排不同的座位:
- 学霸坐左边
- 学渣坐右边
你可以用一根绳子(相当于SVM中的"分割线")把他们分开。
6.2.2 多种分割方法
你发现有多种方法可以分开他们:
- 随便画一条线(可能有点歪,不是最优的)
- 尽量让两类人群距离线远一些(更稳健)
SVM的目标是找到那条"最优的线",使得它:
- 最大限度地远离两边的点(保证分类的稳定性)
- 尽量正确地分类所有数据点
这条最佳分割线就是SVM 的核心思想。
6.3 关键概念
6.3.1 重要的三条线
-
分割线(决策边界,Decision Boundary)
- 这是一条将不同类别分开的线。
-
支持向量(Support Vectors)
- 离分割线最近的点,这些点对分割线的决定影响最大。
-
间隔(Margin)
- 这是分割线到最近数据点的距离 ,SVM 试图最大化这个间隔,以保证分类的稳定性。
📌 示例:看下图,SVM 找到了最佳的分割线,并且保证"最靠近它的点"(支持向量)与分割线的距离最大化。
O = 学霸
X = 学渣
O O O | X X X
O O O ---|--- X X X ← 这就是 SVM 找到的最佳分割线
O O O | X X X
6.4 代码示例:手写数字分类
现在,我们用 Python 和 SVM 来进行手写数字识别,这个任务就是要让计算机区分 0~9 之间的手写数字。
6.4.1 准备数据
我们使用 scikit-learn
提供的 手写数字数据集。
python
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC # SVM 分类器
from sklearn.metrics import accuracy_score
# 加载手写数字数据集
digits = datasets.load_digits()
# 显示前5个手写数字
fig, axes = plt.subplots(1, 5, figsize=(10, 3))
for i, ax in enumerate(axes):
ax.imshow(digits.images[i], cmap=plt.cm.gray_r)
ax.axis('off')
plt.show()
📌 运行后,你会看到 5 个手写数字的图片!
示例输出:
✅ 解释:
digits = datasets.load_digits()
加载手写数字数据(每个数字是 8x8 的灰度图像)plt.imshow(digits.images[i], cmap=plt.cm.gray_r)
用于显示手写数字图片。
6.4.2 数据预处理
SVM 需要将图片转换为数值格式(每个 8×8 图片变成一个 64 维的向量)。
python
# 把图片数据展平成一维
X = digits.images.reshape((len(digits.images), -1)) # 64 维特征
y = digits.target # 标签(数字0-9)
# 划分训练集和测试集(80% 训练,20% 测试)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
✅ 解释:
X = digits.images.reshape((len(digits.images), -1))
把 8×8 图像展平成 64 维特征向量。y = digits.target
获取对应的标签(即数字 0~9)。train_test_split()
将数据分成训练集 和测试集。
6.4.3 训练 SVM 模型
python
# 创建 SVM 分类器
svm_model = SVC(kernel='linear') # 使用线性核
svm_model.fit(X_train, y_train) # 训练模型
# 在测试集上进行预测
y_pred = svm_model.predict(X_test)
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.2f}")
✅ 解释:
SVC(kernel='linear')
创建一个SVM 分类器(使用线性核)。fit(X_train, y_train)
让 SVM 进行训练。predict(X_test)
在测试集上进行预测。accuracy_score(y_test, y_pred)
计算准确率。
📌 运行后,你会看到 SVM 在手写数字分类任务上的准确率!
6.5 重要难点总结
难点 | 直观解释 |
---|---|
为什么 SVM 选择最大间隔? | 让分类器更稳健,不容易过拟合 |
什么是支持向量? | 决定分割线位置的关键数据点 |
为什么有不同的 SVM 核函数? | 线性核适用于简单数据,RBF 核适用于复杂数据 |
SVM 比普通逻辑回归好在哪? | 在小数据集上表现更好,能找到最佳边界 |
6.6 课后练习
- 更换不同的 SVM 核函数 ,例如
kernel='rbf'
试试看效果如何? - 调整
C
参数,观察它对 SVM 分类效果的影响。 - 尝试用 SVM 进行其他分类任务 (比如鸢尾花数据集
datasets.load_iris()
)。
6.7 课后讲解
📌 练习 1:更换核函数
python
svm_model = SVC(kernel='rbf') # 试试 RBF 核
svm_model.fit(X_train, y_train)
y_pred = svm_model.predict(X_test)
print(f"RBF 核 SVM 准确率: {accuracy_score(y_test, y_pred):.2f}")
📌 练习 2:调整 C 参数
python
svm_model = SVC(kernel='linear', C=0.1) # C 越小,分类边界越平滑
svm_model.fit(X_train, y_train)
📌 练习 3:鸢尾花分类
python
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)
svm_model.fit(X_train, y_train)
你现在掌握了 SVM 的基本概念和实战!如果你有任何问题,可以留言讨论!🚀🚀🚀