sklearn包中对于分类问题,如何计算accuracy和roc_auc_score?

1. 基础条件

python 复制代码
import numpy as np
from sklearn import metrics

y_true = np.array([1, 7, 4, 6, 3])
y_prediction = np.array([3, 7, 4, 6, 3])

2. accuracy_score计算

python 复制代码
acc = metrics.accuracy_score(y_true, y_prediction)

这个没问题

3. roc_auc_score计算

The binary and multiclass cases expect labels with shape (n_samples,) while the multilabel case expects binary label indicators with shape (n_samples, n_classes).

因此metrics.roc_auc_score对于multiclasses类的roc_auc_score计算,需要一个二维array,每一列是表示分的每一类,每一行是表示是否为此类。

python 复制代码
from sklearn.preprocessing import OneHotEncoder
enc = OneHotEncoder(sparse=False)
enc.fit(y_true.reshape(-1, 1))
y_true_onehot = enc.transform(y_true.reshape(-1, 1))
y_predictions_onehot = \
    enc.transform(y_prediction.reshape(-1, 1))
bash 复制代码
In [201]: y_true_onehot
Out[201]: 
array([[1., 0., 0., 0., 0.],
       [0., 0., 0., 0., 1.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 1., 0., 0., 0.]])

In [202]: y_predictions_onehot
Out[202]: 
array([[0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 1.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 1., 0., 0., 0.]])
bash 复制代码
In [204]: enc.categories_
Out[204]: [array([1, 3, 4, 6, 7])]

所以结合enc.categories_y_true_onehoty_truey_true_onehot的对应关系如下:

Class 1 3 4 6 7
true value: 1 1
true value: 7 1
true value: 4 1
true value: 6 1
true value: 3 1

因此,对于y_predictiony_prediction_onehot的对应关系就是如下:

Class 1 3 4 6 7
Prediction value: 3 1
Prediction value: 7 1
Prediction value: 4 1
Prediction value: 6 1
Prediction value: 3 1

这就解释了上述y_true_onehoty_prediction_onehot的返回结果。

python 复制代码
ensemble_auc = metrics.roc_auc_score(y_true_onehot,
                                     y_predictions_onehot)
bash 复制代码
In [200]: ensemble_auc
Out[200]: 0.875
相关推荐
搞笑的秀儿8 分钟前
信息新技术
大数据·人工智能·物联网·云计算·区块链
阿里云大数据AI技术26 分钟前
OpenSearch 视频 RAG 实践
数据库·人工智能·llm
遇雪长安31 分钟前
差分定位技术:原理、分类与应用场景
算法·分类·数据挖掘·rtk·差分定位
XMAIPC_Robot39 分钟前
基于ARM+FPGA的光栅尺精密位移加速度测试解决方案
arm开发·人工智能·fpga开发·自动化·边缘计算
加油吧zkf1 小时前
YOLO目标检测数据集类别:分类与应用
人工智能·计算机视觉·目标跟踪
是Dream呀1 小时前
基于连接感知的实时困倦分类图神经网络
神经网络·分类·数据挖掘
Blossom.1181 小时前
机器学习在智能制造业中的应用:质量检测与设备故障预测
人工智能·深度学习·神经网络·机器学习·机器人·tensorflow·sklearn
天天扭码1 小时前
AI时代,前端如何处理大模型返回的多模态数据?
前端·人工智能·面试
难受啊马飞2.01 小时前
如何判断 AI 将优先自动化哪些任务?
运维·人工智能·ai·语言模型·程序员·大模型·大模型学习
顺丰同城前端技术团队1 小时前
掌握未来:构建专属领域的大模型与私有知识库——从部署到微调的全面指南
人工智能·deepseek