📌 主要步骤
- 安装必要的库
- 加载数据集(MNIST 手写数字)
- 数据预处理
- 划分训练集和测试集
- 训练 SVM 模型
- 评估模型
- 预测并可视化结果
1️⃣ 安装必要的库
在开始之前,请确保你的环境安装了以下库:
python
pip install numpy pandas scikit-learn matplotlib
2️⃣ 加载数据集
我们将使用 scikit-learn
自带的 digits 数据集,它包含 0-9 的手写数字,每张图片是 8x8 像素的灰度图。
python
from sklearn import datasets
import matplotlib.pyplot as plt
# 加载手写数字数据集
digits = datasets.load_digits()
# 显示数据集信息
print("数据集形状:", digits.data.shape)
print("标签类别:", digits.target_names)
# 显示前 5 条数据
print("前 5 个标签:", digits.target[:5])
# 可视化前 10 张手写数字
fig, axes = plt.subplots(1, 10, figsize=(10, 3))
for i, ax in enumerate(axes):
ax.imshow(digits.images[i], cmap='gray')
ax.set_title(f"Label: {digits.target[i]}")
ax.axis("off")
plt.show()
3️⃣ 数据预处理
将图片数据转换为一维数组 (从 8x8=64
变成 64 维
的特征向量),以便进行训练。
python
import pandas as pd
# 转换为 DataFrame 以便查看
df = pd.DataFrame(digits.data)
df['label'] = digits.target
# 显示前 5 行数据
print(df.head())
4️⃣ 划分训练集和测试集
将数据集划分为 80% 训练集 和 20% 测试集。
python
from sklearn.model_selection import train_test_split
# 划分数据
X_train, X_test, y_train, y_test = train_test_split(digits.data, digits.target, test_size=0.2, random_state=42)
# 打印数据集大小
print(f"训练集样本数: {len(X_train)}, 测试集样本数: {len(X_test)}")
5️⃣ 训练 SVM(支持向量机)模型
SVM(支持向量机)是一个非常强大的分类算法,在手写数字识别任务中表现优秀。
python
from sklearn.svm import SVC
# 创建 SVM 模型
svm_model = SVC(kernel='linear') # 使用线性核函数
# 训练模型
svm_model.fit(X_train, y_train)
print("SVM 模型训练完成!")
6️⃣ 评估模型
计算模型在测试集上的准确率。
python
from sklearn.metrics import accuracy_score
# 预测测试集
y_pred = svm_model.predict(X_test)
# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.2f}")
7️⃣ 预测并可视化结果
我们从测试集中选择 10 张手写数字,进行预测并可视化。
python
import numpy as np
# 选择前 10 个测试样本
sample_images = X_test[:10]
sample_labels = y_test[:10]
# 进行预测
predictions = svm_model.predict(sample_images)
# 可视化预测结果
fig, axes = plt.subplots(1, 10, figsize=(10, 3))
for i, ax in enumerate(axes):
ax.imshow(sample_images[i].reshape(8, 8), cmap='gray')
ax.set_title(f"P:{predictions[i]}\nT:{sample_labels[i]}")
ax.axis("off")
plt.show()
项目总结
通过这个项目,我们完成了一个 机器学习分类任务 : ✅ 加载 MNIST 数据集
✅ 数据预处理(转换 8x8 图片到 64 维特征向量)
✅ 划分数据集
✅ 训练 SVM 分类器
✅ 评估分类准确率
✅ 预测并可视化结果