MNIST手写数字识别------KNN算法实战
问题背景
手写数字识别是机器学习领域的 "Hello World"。MNIST 数据集包含从零到九的手绘数字灰度图像,共42000张训练图片和28000张测试图片,每张是28×28像素的灰度图,每个像素取值0到255。把这些像素展平成一维向量,就得到784个特征。任务很简单:给一张图,判断它是0到9中的哪个数字。
相关数据集可从kaggle下载- Digit Recognizer
如果不知道如何下载可以跳转至--如何使用Kaggle下载数据集?
更具体的算法原理思路讲解请跳转至--KNN算法详解:从原理到实践入门
数据集说明
| 特征名 | 类型 | 说明 |
|---|---|---|
label |
整数 | 数字标签(0-9),目标变量 |
pixel0~pixel783 |
整数 | 28×28像素的灰度值(0-255) |
代码精讲流程
导入库函数
python
import numpy as np
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt
import time
| 代码行 | 功能说明 | 技术要点 |
|---|---|---|
import numpy as np |
导入数值计算库 | numpy能在C层面执行向量化运算,处理42000×784的矩阵时比原生Python列表快两个数量级 |
import pandas as pd |
导入数据分析核心库 | pandas负责CSV的读取和表格操作,DataFrame可以按列名索引,比numpy更适合处理带表头的数据 |
from sklearn.neighbors import KNeighborsClassifier |
导入KNN分类器 | scikit-learn中KNN算法的封装实现 |
from sklearn.model_selection import GridSearchCV |
导入网格搜索工具 | 自动遍历参数网格并用交叉验证评估每种组合 |
from sklearn.preprocessing import StandardScaler |
导入标准化工具 | 负责把每个特征的均值拉到0、标准差拉到1 |
from sklearn.decomposition import PCA |
导入主成分分析 | 通过线性变换把高维数据投影到低维空间,同时保留尽可能多的方差信息 |
from sklearn.metrics import classification_report, accuracy_score |
导入评估指标 | 用于计算准确率、召回率、F1分数等 |
import matplotlib.pyplot as plt |
导入可视化库 | 把784维向量还原成肉眼可辨的图像 |
import time |
导入计时模块 | 记录各阶段耗时,KNN的预测速度往往是瓶颈,没有计时就无法判断优化是否有效 |
加载数据集
python
train_df = pd.read_csv("train.csv")
test_df = pd.read_csv("test.csv")
y_train = train_df["label"].values
X_train = train_df.drop("label", axis=1).values
X_test = test_df.values
| 代码行 | 功能说明 | 技术要点 |
|---|---|---|
train_df = pd.read_csv("train.csv") |
读取训练集数据 | 训练集包含42000条记录,785列(1列标签+784列像素) |
test_df = pd.read_csv("test.csv") |
读取测试集数据 | 测试集包含28000条记录,784列(无标签) |
y_train = train_df["label"].values |
提取目标变量 | 将标签列转换为numpy数组 |
X_train = train_df.drop("label", axis=1).values |
提取特征矩阵 | 删除标签列,将剩余列转换为numpy数组(42000×784) |
X_test = test_df.values |
测试集特征矩阵 | 测试集本身无标签,直接转换为numpy数组(28000×784) |
标准化处理
python
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
| 代码行 | 功能说明 | 技术要点 |
|---|---|---|
scaler = StandardScaler() |
创建标准化器实例 | 计算每个特征的均值(μ)和标准差(σ) |
scaler.fit_transform(X_train) |
训练集:计算μ和σ,并转换 | 先拟合(fit)学习参数,再转换(transform)数据 |
scaler.transform(X_test) |
测试集:用训练集的μ和σ转换 | 重要:只用训练集的统计量,避免数据泄漏 |
为什么需要标准化?
原始像素值范围是0-255。直接用这个范围做KNN的距离计算会有问题:有些像素(如图片中心)方差很大,值在0-255之间大幅波动;有些像素(如四角)几乎永远是0,方差接近0。标准化后每个特征的均值为0、标准差为1,使所有特征在距离计算中具有相同的权重。
PCA降维
python
pca = PCA(n_components=0.95, random_state=42)
X_train_pca = pca.fit_transform(X_train_scaled)
X_test_pca = pca.transform(X_test_scaled)
| 代码行 | 功能说明 | 技术要点 |
|---|---|---|
PCA(n_components=0.95) |
创建PCA实例 | 保留95%的方差信息,PCA会自动计算需要多少个主成分 |
random_state=42 |
设置随机种子 | 保证实验可复现 |
pca.fit_transform(X_train_scaled) |
训练集:拟合并转换 | 学习主成分方向,同时转换数据 |
pca.transform(X_test_scaled) |
测试集:用同一PCA转换 | 使用训练集学习到的主成分方向 |
为什么需要PCA降维?
784个像素中,大量维度是无效噪声------图片四角几乎永远是0,这些维度对区分数字没有贡献,反而干扰距离计算。PCA通过线性变换,把数据投影到方差最大的若干方向上,用更少的维度保留绝大部分信息。
n_components=0.95的含义
这是最关键的参数!它表示保留95%的方差信息,PCA会自动计算需要多少个主成分才能达到95%的方差解释率,剩下5%被当作噪声丢弃。
自动搜索最优参数
python
knn_base = KNeighborsClassifier(n_jobs=-1)
param_grid = {
'n_neighbors': [5, 7, 9],
'weights': ['uniform', 'distance'],
'metric': ['minkowski'],
'p': [1, 2]
}
grid_search = GridSearchCV(
estimator=knn_base,
param_grid=param_grid,
cv=5,
scoring='accuracy',
n_jobs=1,
verbose=2
)
grid_search.fit(X_train_pca, y_train)
| 代码行 | 功能说明 | 技术要点 |
|---|---|---|
KNeighborsClassifier(n_jobs=-1) |
创建KNN基础模型 | n_jobs=-1表示预测时使用所有CPU核心并行加速 |
param_grid = {...} |
定义参数网格 | 包含K值、权重方式、距离度量三类参数 |
'n_neighbors': [5, 7, 9] |
K值候选 | 3个K值:5、7、9 |
'weights': ['uniform', 'distance'] |
权重方式 | uniform均匀投票,distance距离加权投票 |
'metric': ['minkowski'] |
距离度量 | 闵可夫斯基距离,通过p参数控制具体类型 |
'p': [1, 2] |
距离参数 | p=1曼哈顿距离,p=2欧氏距离 |
GridSearchCV(...) |
创建网格搜索实例 | 自动遍历所有参数组合 |
cv=5 |
5折交叉验证 | 评估模型稳定性 |
scoring='accuracy' |
使用准确率作为评估指标 | 分类问题最常用的指标 |
n_jobs=1 |
串行运行 | 重要:避免内存溢出,详见下方说明 |
verbose=2 |
打印详细进度 | 方便观察搜索过程 |
为什么设置
n_jobs=1?最初设为
n_jobs=-1(全核并行),结果报了TerminatedWorkerError------每个并行任务都要复制一份数据,内存不足导致任务被系统杀死。改为n_jobs=1后串行搜索,内存只保留一份数据。注意KNN模型内部的n_jobs=-1仍然保留------它在预测时做数据并行(共享训练数据,只读不复制),不会导致内存问题。
最优模型
python
knn_final = grid_search.best_estimator_
| 代码行 | 功能说明 | 技术要点 |
|---|---|---|
grid_search.best_estimator_ |
获取最优模型 | GridSearchCV在找到最优参数后,会自动用这组参数在整个训练集上重新训练一个模型 |
测试集预测
python
y_test_pred = knn_final.predict(X_test_pca)
submission = pd.DataFrame({
"ImageId": list(range(1, len(y_test_pred) + 1)),
"Label": y_test_pred
})
submission.to_csv("submission.csv", index=False)
| 代码行 | 功能说明 | 技术要点 |
|---|---|---|
knn_final.predict(X_test_pca) |
对测试集进行预测 | 输入的是经过标准化和PCA降维后的测试数据 |
"ImageId": list(range(1, ...)) |
生成图片ID列 | Kaggle约定ImageId从1开始(1-indexed) |
submission.to_csv("submission.csv", index=False) |
保存提交文件 | index=False防止pandas把行索引写进CSV |
完整代码汇总
python
# ============================================================
# MNIST 手写数字识别 ------ KNN
# ============================================================
import numpy as np
import pandas as pd
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.metrics import classification_report, accuracy_score
import matplotlib.pyplot as plt
import time
# 加载数据
print("=" * 50)
print("正在加载数据...")
start_time = time.time()
train_df = pd.read_csv("train.csv")
test_df = pd.read_csv("test.csv")
print(f"训练集大小: {train_df.shape}")
print(f"测试集大小: {test_df.shape}")
print(f"数据加载耗时: {time.time() - start_time:.2f} 秒")
# 分离特征和标签
y_train = train_df["label"].values
X_train = train_df.drop("label", axis=1).values
X_test = test_df.values
# 标准化处理
print("=" * 50)
print("正在进行标准化处理...")
start_time = time.time()
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
print(f"标准化后均值: {X_train_scaled.mean():.6f}")
print(f"标准化后标准差: {X_train_scaled.std():.4f}")
print(f"标准化耗时: {time.time() - start_time:.2f} 秒")
# PCA降维
print("=" * 50)
print("正在进行 PCA 降维...")
start_time = time.time()
pca = PCA(n_components=0.95, random_state=42)
X_train_pca = pca.fit_transform(X_train_scaled)
X_test_pca = pca.transform(X_test_scaled)
print(f"降维后维度: {X_train_pca.shape[1]} (原始: 784)")
print(f"方差保留比例: {pca.explained_variance_ratio_.sum():.4f}")
print(f"PCA 耗时: {time.time() - start_time:.2f} 秒")
# GridSearchCV寻找最优参数
print("=" * 50)
print("正在使用 GridSearchCV 搜索最优参数...")
start_time = time.time()
knn_base = KNeighborsClassifier(n_jobs=-1)
param_grid = {
'n_neighbors': [5, 7, 9],
'weights': ['uniform', 'distance'],
'metric': ['minkowski'],
'p': [1, 2]
}
grid_search = GridSearchCV(
estimator=knn_base,
param_grid=param_grid,
cv=5,
scoring='accuracy',
n_jobs=1,
verbose=2
)
grid_search.fit(X_train_pca, y_train)
print(f"\nGridSearchCV 耗时: {time.time() - start_time:.2f} 秒")
print(f"最优参数: {grid_search.best_params_}")
print(f"最优交叉验证准确率: {grid_search.best_score_:.4f}")
# 获取最优模型
knn_final = grid_search.best_estimator_
# 测试集预测
print("=" * 50)
print("正在对测试集进行预测...")
start_time = time.time()
y_test_pred = knn_final.predict(X_test_pca)
print(f"测试集预测耗时: {time.time() - start_time:.2f} 秒")
# 保存预测结果
submission = pd.DataFrame({
"ImageId": list(range(1, len(y_test_pred) + 1)),
"Label": y_test_pred
})
submission.to_csv("submission.csv", index=False)
print(f"预测结果已保存为 submission.csv")
# 可视化
best_k = grid_search.best_params_['n_neighbors']
best_weights = grid_search.best_params_['weights']
best_p = grid_search.best_params_['p']
fig, axes = plt.subplots(2, 5, figsize=(14, 6))
fig.suptitle(
f"测试集预测结果 (k={best_k}, weights={best_weights}, p={best_p})",
fontsize=14
)
for idx, ax in enumerate(axes.flat):
image = X_test[idx].reshape(28, 28) / 255.0
ax.imshow(image, cmap="gray")
ax.set_title(f"预测: {y_test_pred[idx]}")
ax.axis("off")
plt.tight_layout()
plt.savefig("test_predictions.png", dpi=150)
plt.close()
print("测试集预测图像已保存为 test_predictions.png")
print("=" * 50)
print("全部完成!")
结语:本文详细讲解了使用KNN算法解决MNIST手写数字识别问题的完整流程。KNN是最直观的分类算法,其核心在于"距离"的定义和计算。通过标准化处理和PCA降维,可以有效提升模型的预测性能和速度。希望这篇文章能帮助你掌握机器学习实战的核心技能!
觉得有帮助?点赞收藏关注,后续持续更新机器学习系列文章。
有问题或建议?欢迎在评论区留言讨论。