sklearn基础--『分类模型评估』之评估报告

分类模型 评估时,scikit-learn提供了混淆矩阵分类报告 是两个非常实用且常用的工具。

它们为我们提供了详细的信息,帮助我们了解模型的优缺点,从而进一步优化模型。

这两个工具之所以单独出来介绍,是因为它们的输出内容特别适合用在模型的评估报告中。

1. 混淆矩阵

混淆矩阵Confusion Matrix)用于直观地展示模型预测结果与实际标签之间的对应关系。

它是一个表格,其 表示实际的类别标签,而表示模型预测的类别标签。

通过混淆矩阵,可以清晰地看到模型的哪些预测是正确的,哪些是错误的,以及错误预测的具体分布情况。

1.1. 使用示例

下面用手写数字识别 的示例,演示最后如何用混淆矩阵 来可视化的评估模型训练结果的。

首先,读取手写数字数据集(这个数据集是scikit-learn中自带的):

python 复制代码
import matplotlib.pyplot as plt
from sklearn import datasets

# 加载手写数据集
data = datasets.load_digits()

_, axes = plt.subplots(nrows=2, ncols=4, figsize=(10, 6))
for ax, image, label in zip(np.append(axes[0], axes[1]), data.images, data.target):
    ax.set_axis_off()
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
    ax.set_title("目标值: {}".format(label))

然后,用支持向量机来训练数据,得到一个分类模型(reg):

python 复制代码
from sklearn.svm import SVC

n_samples = len(data.images)
X = data.images.reshape((n_samples, -1))
y = data.target

# 定义
reg = SVC()

# 训练模型
reg.fit(X, y)

最后,用得到的分类模型 来预测数据,再用混淆矩阵 来分析预测值真实值

python 复制代码
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# 用训练好的模型进行预测
y_pred = reg.predict(X)

cm = confusion_matrix(y, y_pred)
g = ConfusionMatrixDisplay(confusion_matrix=cm)
g.plot()

plt.show()


混淆矩阵 中,横轴 是预测值,纵轴 是真实值。

对角线上预测值与真实值符合的情况,可以看出模型分类效果不错,大部分数据都能正确分类的。

也有极个别分类错误的情况,比如:

  • 8被识别成1的错误有2个
  • 5被识别成9的错误有1个
  • 9被识别成3的错误有1个
  • ... ... 等等

2. 分类报告

分类报告 提供了模型在各个类别上的详细性能指标。

通常包括准确率Precision)、召回率Recall)、F1分数F1-Score)等评估指标,这些指标能够帮助我们更全面地了解模型的性能。

2.1. 使用示例

基于上面训练的手写数字识别模型,看看模型的各项指标。

python 复制代码
from sklearn.metrics import classification_report

# 这里的y 和 y_pred 是上一节示例中的值
report = classification_report(y, y_pred)
print(report)

报告中列出了手写数字0~9的识别情况。

3. 总结

总的来说,分类报告混淆矩阵 一起使用,能够更全面地评估模型的性能,指导模型的优化和改进。

而且它们生成的评估表格和图形,也能够应用于我们的分析报告中。

关于分类模型的内容可参考之前的文章:

  1. sklearn基础--『监督学习』之K-近邻分类
  2. sklearn基础--『监督学习』之逻辑回归分类
  3. sklearn基础--『监督学习』之贝叶斯分类
  4. sklearn基础--『监督学习』之决策树分类
  5. sklearn基础--『监督学习』之随机森林分类
  6. sklearn基础--『监督学习』之支持向量机分类
相关推荐
彩云回2 小时前
多维尺度分析法(MDS)
人工智能·机器学习·1024程序员节
FL16238631293 小时前
智慧交通红绿灯检测数据集VOC+YOLO格式1215张3类别
深度学习·yolo·机器学习
兩尛5 小时前
神经网络补充知识
人工智能·神经网络·机器学习
电鱼智能的电小鱼6 小时前
基于电鱼 ARM 工控机的煤矿主控系统高可靠运行方案——让井下控制系统告别“死机与重启”
arm开发·人工智能·嵌入式硬件·深度学习·机器学习
长桥夜波7 小时前
机器学习日报09
人工智能·机器学习
TGITCIC7 小时前
通过神经网络手搓一个带finetune功能的手写数字识别来学习“深度神经网络”
人工智能·深度学习·机器学习·卷积神经网络·dnn·文字识别·识别数字
yumgpkpm10 小时前
CMP(类ClouderaCDP7.3(404次编译) )完全支持华为鲲鹏Aarch64(ARM)使用 AI 优化库存水平、配送路线的具体案例及说明
大数据·人工智能·hive·hadoop·机器学习·zookeeper·cloudera
Cathy Bryant10 小时前
智能模型对齐(一致性)alignment
笔记·神经网络·机器学习·数学建模·transformer
南汐汐月12 小时前
重生归来,我要成功 Python 高手--day31 线性回归
python·机器学习·线性回归