Digits 手写数字识别:随机森林多分类 + 像素级特征热力图


项目 内容
summary 基于 scikit-learn Digits 数据集,使用随机森林训练 10 分类手写数字识别模型,并将 Permutation Importance 映射为 8×8 像素热力图,直观展示模型「看」的是哪里。
description 本文基于 scikit-learn Digits 手写数字数据集,使用随机森林进行 0-9 多分类识别,涵盖混淆矩阵分析、像素级特征重要性热力图和误分类样本可视化,提供完整可复现代码。

本项目由 星枢 支持

星枢官网:https://claudeaihub.cloud/

Digits 手写数字识别:随机森林多分类 + 像素级特征热力图

前两个项目分别做了二分类和回归,这次试试多分类。多分类比二分类更复杂的地方在于:错误不再只是「对或错」,而是「把 A 错认成了 B」。哪些数字容易互相混淆?模型「看」的是图像的哪个区域?这些问题比单纯追求准确率更有意思。

这次用的是 scikit-learn 内置的 Digits 数据集------1,797 张 8×8 的手写数字图片,每张就是 64 个像素值。任务很简单:让模型认出这是 0-9 中的哪个数字。

完整代码:GitHub 仓库

核心结论:

  • Accuracy:96.67%
  • Macro F1:96.61%
  • 误分类:12 / 360(3.33%)
  • 最难认的数字:8(Recall 只有 0.86)
  • 模型最关注的像素:第 21 号像素(8×8 网格中心偏上位置)

1. 数据集

Digits 数据集是 MNIST 的简化版。每张图只有 8×8 = 64 个像素,但已经足以区分 0-9 十个数字。

python 复制代码
from sklearn.datasets import load_digits
digits = load_digits()
# digits.data: (1797, 64) --- 1797 张图,每张 64 个像素
# digits.target: (1797,) --- 0-9 的标签

各类别分布:

数字 样本数
0 178
1 182
2 177
3 183
4 181
5 182
6 181
7 179
8 174
9 180

分布很均衡,不需要处理类别不均衡。训练集 1,437 张,测试集 360 张。

2. 环境准备

多了一个 seaborn 用于画更漂亮的混淆矩阵:

text 复制代码
pandas
numpy
scikit-learn
matplotlib
seaborn
bash 复制代码
git clone https://github.com/coderWang404/xingshuProjects.git
cd xingshuProjects/2026-06-08-digits-multiclass

python -m venv venv
source venv/bin/activate
pip install -r requirements.txt

3. 运行实验

bash 复制代码
python experiments/digits-multiclass/run_experiment.py

4. 建模思路

随机森林天然适合这种结构化特征(64 个像素值就是 64 个数值特征)。参数:

python 复制代码
model = RandomForestClassifier(
    n_estimators=200,
    max_depth=12,
    min_samples_leaf=2,
    random_state=42,
    n_jobs=-1,
)

min_samples_leaf=2 设得很小,因为单张图片的特征空间很紧凑(只有 64 维),稍微深一点的树也不会过拟合。max_depth=12 对于 64 维数据来说绰绰有余------随机森林不需要像神经网络那样几十层。

5. 结果:Accuracy 96.67%,但数字 8 是「困难户」

测试集 360 张图片,只错了 12 张。核心指标:

指标 数值
Accuracy 0.9667
Macro Precision 0.9674
Macro Recall 0.9662
Macro F1 0.9661

按数字拆开看:

数字 Precision Recall F1 Support
0 0.97 0.97 0.97 36
1 0.92 0.97 0.95 36
2 1.00 0.97 0.99 35
3 0.97 1.00 0.99 37
4 0.97 1.00 0.99 36
5 0.97 1.00 0.99 37
6 1.00 0.97 0.99 36
7 0.92 1.00 0.96 36
8 0.94 0.86 0.90 35
9 1.00 0.92 0.96 36

数字 8 是唯一的「困难户」------Recall 只有 0.86,意味着 35 张 8 里面有 5 张被认错了。看混淆矩阵会更清楚:

从图里能读出几个有趣的模式:

  1. 8 最容易被错认成 1(3 次)。其实不难理解:8×8 的分辨率下,一个写得瘦长的 8 确实容易看起来像 1。
  2. 8 还会被错认成 7(2 次)。8 的上半圈和 7 的横折在低分辨率下很容易混淆。
  3. 9 的 Recall 也不高(0.92),有 2 张被错认成 7 和 8。
  4. 数字 2、3、4、5、6 几乎没有错误------它们在 8×8 分辨率下的特征非常鲜明。

上面是部分错误案例的实拍。肉眼也能看出:有些 8 确实写得太像 1 或 7 了,别说是模型,人眼也得仔细看才能分清楚。

6. 像素级特征重要性:模型「看」的是哪里

这是这次实验最有意思的部分。我把 Permutation Importance 映射回了 8×8 的像素网格:

颜色越亮(黄绿色),说明这个像素对模型判断越重要。能看出非常明显的规律:

模型最关注的是图像的中心区域。第 21 号像素(第 2 行第 5 列,中心偏上)的重要性高达 0.024,是第二名的 2.6 倍。这个数字区域恰好是大多数手写数字的「笔画交汇点」------0 的中心空洞、8 的上下连接处、6 的圆环中心都落在这个区域。

边缘像素(最外一圈)的重要性普遍很低。这也符合直觉:手写数字的有效笔画很少延伸到最边缘。

但有一个反直觉的发现:角上的像素(比如第 0、7、56、63 号)虽然重要性低,但几乎没有为零的。这说明即使是边缘,模型也会用到------比如 1 的竖线可能偏左或偏右,7 的横折可能偏上。随机森林通过组合这些「弱信号」来做出最终判断。

7. 输出文件

运行后 experiments/digits-multiclass/outputs/ 会生成:

text 复制代码
metrics.json                  # 完整指标 JSON
classification_report.txt     # 每类分类报告
dataset_profile.csv           # 像素值统计
feature_importance.csv        # 全部像素重要性
class_distribution.png        # 类别分布图
confusion_matrix.png          # 混淆矩阵热图
feature_importance.png        # Top 20 像素条形图
pixel_importance_heatmap.png  # 8×8 像素热力图 ⭐
misclassified_samples.png     # 错误案例实拍
summary.md                    # 实验摘要

8. 总结

这个实验让我对「多分类」和「图像特征的模型解释」有了更直观的理解:

  • Accuracy 96.67% 在多分类里已经很高了,但拆到每类就会发现 8 是明显的短板。
  • 混淆矩阵比 Accuracy 更有信息量------它告诉你模型到底在「哪里犯错」,而不是简单告诉你「对了多少」。
  • 像素热力图是最有 insight 的可视化------它把抽象的「特征重要性」变成了肉眼可见的「模型在看哪里」。
  • 低分辨率图像(8×8)对随机森林已经很友好了。如果换成 28×28 的 MNIST,特征数会变成 784,随机森林的性能会下降,这时候就该 CNN 出场了。

如果想拿它当学习模板,建议试试这几个改动:

  • max_depth 从 12 改成 5 或 20,观察 Accuracy 和混淆矩阵的变化
  • 用 Support Vector Machine(SVM)跑一遍同样的数据,对比两种模型的错误模式
  • 把 8×8 图像可视化出来,对比模型认为「重要」的像素和实际笔画的重合度

完整代码与实验脚本:GitHub 仓库

相关推荐
liulilittle2 小时前
拥塞控制:排水终止的两种决策:OR 与 AND
网络·tcp/ip·计算机网络·算法·信息与通信·tcp·通信
weixin_307779132 小时前
从脚本执行到智能体协作:AI辅助测试能力的范式重构
运维·开发语言·人工智能·算法·测试用例
量化君也2 小时前
从回测到全自动实盘交易,全天候策略需要经历哪些改造?
大数据·人工智能·python·算法·金融
fox_lht3 小时前
第十五章 函数式语言:迭代器和闭包
开发语言·后端·学习·算法·rust
zhengzhouliuhaha4 小时前
智能医疗设备控费系统:以全院一体化管控,筑牢医疗资源“安全阀”
大数据·数据结构·人工智能·算法·安全·机器学习·软件需求
June`4 小时前
CUDA程序效率如何计算以及工具如何使用
算法·cuda
兰令水4 小时前
leecodecode【树形DP】【2026.6.11打卡-java版本】
java·算法·深度优先
装不满的克莱因瓶4 小时前
RLHF中的PPO算法——大语言模型对齐优化的核心引擎
人工智能·python·深度学习·算法·机器学习·语言模型·自然语言处理
SWAGGY..4 小时前
Linux系统编程:(十三)环境变量
java·linux·算法