零基础入门机器学习 -- 第六章支持向量机SVM

6.1 什么是支持向量机?

在机器学习中,我们常常遇到分类问题,比如:

  • 垃圾邮件分类(邮件是垃圾邮件还是正常邮件?)
  • 肿瘤检测(肿瘤是良性还是恶性?)
  • 图像识别(一张图片是猫还是狗?)

支持向量机(Support Vector Machine, SVM)是一种非常强大的分类算法 ,它的核心思想是:
找到一条"最优的分割线"(决策边界),把不同类别的数据分开。


6.2 直观理解:画一条最佳分割线

6.2.1 你是学校的班主任

假设你是学校的一位班主任,你需要组织一次班级活动,你的班级有两种学生:

  1. "学霸"(成绩好)
  2. "学渣"(成绩一般)

你要把这两类学生分开,分别安排不同的座位:

  • 学霸坐左边
  • 学渣坐右边

你可以用一根绳子(相当于SVM中的"分割线")把他们分开。

6.2.2 多种分割方法

你发现有多种方法可以分开他们:

  1. 随便画一条线(可能有点歪,不是最优的)
  2. 尽量让两类人群距离线远一些(更稳健)

SVM的目标是找到那条"最优的线",使得它:

  • 最大限度地远离两边的点(保证分类的稳定性)
  • 尽量正确地分类所有数据点

这条最佳分割线就是SVM 的核心思想


6.3 关键概念

6.3.1 重要的三条线

  1. 分割线(决策边界,Decision Boundary)

    • 这是一条将不同类别分开的线。
  2. 支持向量(Support Vectors)

    • 离分割线最近的点,这些点对分割线的决定影响最大。
  3. 间隔(Margin)

    • 这是分割线到最近数据点的距离 ,SVM 试图最大化这个间隔,以保证分类的稳定性。

📌 示例:看下图,SVM 找到了最佳的分割线,并且保证"最靠近它的点"(支持向量)与分割线的距离最大化。

  O = 学霸
  X = 学渣

  O  O  O     |     X  X  X
  O  O  O  ---|---  X  X  X  ←  这就是 SVM 找到的最佳分割线
  O  O  O     |     X  X  X

6.4 代码示例:手写数字分类

现在,我们用 Python 和 SVM 来进行手写数字识别,这个任务就是要让计算机区分 0~9 之间的手写数字。

6.4.1 准备数据

我们使用 scikit-learn 提供的 手写数字数据集

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC  # SVM 分类器
from sklearn.metrics import accuracy_score

# 加载手写数字数据集
digits = datasets.load_digits()

# 显示前5个手写数字
fig, axes = plt.subplots(1, 5, figsize=(10, 3))
for i, ax in enumerate(axes):
    ax.imshow(digits.images[i], cmap=plt.cm.gray_r)
    ax.axis('off')

plt.show()

📌 运行后,你会看到 5 个手写数字的图片!

示例输出:

解释

  • digits = datasets.load_digits() 加载手写数字数据(每个数字是 8x8 的灰度图像)
  • plt.imshow(digits.images[i], cmap=plt.cm.gray_r) 用于显示手写数字图片。

6.4.2 数据预处理

SVM 需要将图片转换为数值格式(每个 8×8 图片变成一个 64 维的向量)。

python 复制代码
# 把图片数据展平成一维
X = digits.images.reshape((len(digits.images), -1))  # 64 维特征
y = digits.target  # 标签(数字0-9)

# 划分训练集和测试集(80% 训练,20% 测试)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

解释

  • X = digits.images.reshape((len(digits.images), -1)) 把 8×8 图像展平成 64 维特征向量。
  • y = digits.target 获取对应的标签(即数字 0~9)。
  • train_test_split() 将数据分成训练集测试集

6.4.3 训练 SVM 模型

python 复制代码
# 创建 SVM 分类器
svm_model = SVC(kernel='linear')  # 使用线性核
svm_model.fit(X_train, y_train)  # 训练模型

# 在测试集上进行预测
y_pred = svm_model.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.2f}")

解释

  • SVC(kernel='linear') 创建一个SVM 分类器(使用线性核)。
  • fit(X_train, y_train) 让 SVM 进行训练。
  • predict(X_test) 在测试集上进行预测。
  • accuracy_score(y_test, y_pred) 计算准确率。

📌 运行后,你会看到 SVM 在手写数字分类任务上的准确率!


6.5 重要难点总结

难点 直观解释
为什么 SVM 选择最大间隔? 让分类器更稳健,不容易过拟合
什么是支持向量? 决定分割线位置的关键数据点
为什么有不同的 SVM 核函数? 线性核适用于简单数据,RBF 核适用于复杂数据
SVM 比普通逻辑回归好在哪? 在小数据集上表现更好,能找到最佳边界

6.6 课后练习

  1. 更换不同的 SVM 核函数 ,例如 kernel='rbf' 试试看效果如何?
  2. 调整 C 参数,观察它对 SVM 分类效果的影响。
  3. 尝试用 SVM 进行其他分类任务 (比如鸢尾花数据集 datasets.load_iris())。

6.7 课后讲解

📌 练习 1:更换核函数

python 复制代码
svm_model = SVC(kernel='rbf')  # 试试 RBF 核
svm_model.fit(X_train, y_train)
y_pred = svm_model.predict(X_test)
print(f"RBF 核 SVM 准确率: {accuracy_score(y_test, y_pred):.2f}")

📌 练习 2:调整 C 参数

python 复制代码
svm_model = SVC(kernel='linear', C=0.1)  # C 越小,分类边界越平滑
svm_model.fit(X_train, y_train)

📌 练习 3:鸢尾花分类

python 复制代码
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2)
svm_model.fit(X_train, y_train)

你现在掌握了 SVM 的基本概念和实战!如果你有任何问题,可以留言讨论!🚀🚀🚀

相关推荐
带娃的IT创业者17 分钟前
机器学习实战(8):降维技术——主成分分析(PCA)
人工智能·机器学习·分类·聚类
鸡鸭扣40 分钟前
Docker:3、在VSCode上安装并运行python程序或JavaScript程序
运维·vscode·python·docker·容器·js
调皮的芋头41 分钟前
iOS各个证书生成细节
人工智能·ios·app·aigc
paterWang1 小时前
基于 Python 和 OpenCV 的酒店客房入侵检测系统设计与实现
开发语言·python·opencv
东方佑1 小时前
使用Python和OpenCV实现图像像素压缩与解压
开发语言·python·opencv
饮长安千年月2 小时前
Linksys WRT54G路由器溢出漏洞分析–运行环境修复
网络·物联网·学习·安全·机器学习
神秘_博士2 小时前
自制AirTag,支持安卓/鸿蒙/PC/Home Assistant,无需拥有iPhone
arm开发·python·物联网·flutter·docker·gitee
flying robot3 小时前
人工智能基础之数学基础:01高等数学基础
人工智能·机器学习
Moutai码农3 小时前
机器学习-生命周期
人工智能·python·机器学习·数据挖掘
188_djh3 小时前
# 10分钟了解DeepSeek,保姆级部署DeepSeek到WPS,实现AI赋能
人工智能·大语言模型·wps·ai技术·ai应用·deepseek·ai知识