基于 SVM(支持向量机)的手写数字识别

📌 主要步骤

  1. 安装必要的库
  2. 加载数据集(MNIST 手写数字)
  3. 数据预处理
  4. 划分训练集和测试集
  5. 训练 SVM 模型
  6. 评估模型
  7. 预测并可视化结果

1️⃣ 安装必要的库

在开始之前,请确保你的环境安装了以下库:

python 复制代码
pip install numpy pandas scikit-learn matplotlib

2️⃣ 加载数据集

我们将使用 scikit-learn 自带的 digits 数据集,它包含 0-9 的手写数字,每张图片是 8x8 像素的灰度图

python 复制代码
from sklearn import datasets
import matplotlib.pyplot as plt

# 加载手写数字数据集
digits = datasets.load_digits()

# 显示数据集信息
print("数据集形状:", digits.data.shape)
print("标签类别:", digits.target_names)

# 显示前 5 条数据
print("前 5 个标签:", digits.target[:5])

# 可视化前 10 张手写数字
fig, axes = plt.subplots(1, 10, figsize=(10, 3))
for i, ax in enumerate(axes):
    ax.imshow(digits.images[i], cmap='gray')
    ax.set_title(f"Label: {digits.target[i]}")
    ax.axis("off")
plt.show()

3️⃣ 数据预处理

将图片数据转换为一维数组 (从 8x8=64 变成 64 维 的特征向量),以便进行训练。

python 复制代码
import pandas as pd

# 转换为 DataFrame 以便查看
df = pd.DataFrame(digits.data)
df['label'] = digits.target

# 显示前 5 行数据
print(df.head())

4️⃣ 划分训练集和测试集

将数据集划分为 80% 训练集20% 测试集

python 复制代码
from sklearn.model_selection import train_test_split

# 划分数据
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2, random_state=42)

# 打印数据集大小
print(f"训练集样本数: {len(X_train)}, 测试集样本数: {len(X_test)}")

5️⃣ 训练 SVM(支持向量机)模型

SVM(支持向量机)是一个非常强大的分类算法,在手写数字识别任务中表现优秀。

python 复制代码
from sklearn.svm import SVC

# 创建 SVM 模型
svm_model = SVC(kernel='linear')  # 使用线性核函数

# 训练模型
svm_model.fit(X_train, y_train)

print("SVM 模型训练完成!")

6️⃣ 评估模型

计算模型在测试集上的准确率。

python 复制代码
from sklearn.metrics import accuracy_score

# 预测测试集
y_pred = svm_model.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.2f}")

7️⃣ 预测并可视化结果

我们从测试集中选择 10 张手写数字,进行预测并可视化。

python 复制代码
import numpy as np

# 选择前 10 个测试样本
sample_images = X_test[:10]
sample_labels = y_test[:10]

# 进行预测
predictions = svm_model.predict(sample_images)

# 可视化预测结果
fig, axes = plt.subplots(1, 10, figsize=(10, 3))
for i, ax in enumerate(axes):
    ax.imshow(sample_images[i].reshape(8, 8), cmap='gray')
    ax.set_title(f"P:{predictions[i]}\nT:{sample_labels[i]}")
    ax.axis("off")

plt.show()

项目总结

通过这个项目,我们完成了一个 机器学习分类任务 : ✅ 加载 MNIST 数据集

✅ 数据预处理(转换 8x8 图片到 64 维特征向量)

✅ 划分数据集

✅ 训练 SVM 分类器

✅ 评估分类准确率

✅ 预测并可视化结果

相关推荐
浪九天1 小时前
人工智能直通车系列14【机器学习基础】(逻辑回归原理逻辑回归模型实现)
人工智能·深度学习·神经网络·机器学习·自然语言处理
紫雾凌寒4 小时前
计算机视觉应用|自动驾驶的感知革命:多传感器融合架构的技术演进与落地实践
人工智能·机器学习·计算机视觉·架构·自动驾驶·多传感器融合·waymo
安忘4 小时前
LeetCode 热题 -189. 轮转数组
算法·leetcode·职场和发展
Y1nhl4 小时前
力扣hot100_二叉树(4)_python版本
开发语言·pytorch·python·算法·leetcode·机器学习
龚大龙5 小时前
机器学习(李宏毅)——Auto-Encoder
人工智能·机器学习
曼诺尔雷迪亚兹5 小时前
2025年四川烟草工业计算机岗位备考详细内容
数据结构·数据库·计算机网络·算法
蜡笔小新..5 小时前
某些网站访问很卡 or 力扣网站经常进不去(2025/3/10)
算法·leetcode·职场和发展
IT猿手6 小时前
2025最新群智能优化算法:基于RRT的优化器(RRT-based Optimizer,RRTO)求解23个经典函数测试集,MATLAB
开发语言·人工智能·算法·机器学习·matlab
刘大猫266 小时前
五、MyBatis的增删改查模板(参数形式包括:String、对象、集合、数组、Map)
人工智能·算法·智能合约
修己xj6 小时前
算法系列之深度/广度优先搜索解决水桶分水的最优解及全部解
算法