【Python机器学习】分类器的不确定估计——决策函数

scikit-learn接口的分类器能够给出预测的不确定度估计,一般来说,分类器会预测一个测试点属于哪个类别,还包括它对这个预测的置信程度。

scikit-learn中有两个函数可以用于获取分类器的不确定度估计:decidion_function和predict_proba。

以一个二维数据集为例:

python 复制代码
import mglearn.tools
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.datasets import make_circles
import numpy as np
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt

X,y=make_circles(noise=0.25,factor=0.5,random_state=1)

y_named=np.array(['type0','type1'])[y]
#所有数组的划分方式都是一致的
X_train,X_test,y_train_named,y_test_named,y_train,y_test=train_test_split(
    X,y_named,y,random_state=0
)
#梯度提升模型
gbrt=GradientBoostingClassifier(random_state=0)
gbrt.fit(X_train,y_train_named)

对于二分类的情况,decidion_function返回值的形状是(n_samples,),为每个样本都返回一个浮点数:

python 复制代码
print('X_test形状:{}'.format(X_test.shape))
print('Decision_function 形状:{}'.format(gbrt.decision_function(X_test).shape))

对于类别1来说,值代表模型对数据点属于"正"类的置信程度。正值代表对正类的偏好,负值代表对反类的偏好,还可以通过查看决策值的正负号来展示预测值:

python 复制代码
print('Decision_function:{}'.format(gbrt.decision_function(X_test)[:10]))
print('正负-Decision_function:{}'.format(gbrt.decision_function(X_test)>0))
print('分类:{}'.format(gbrt.predict(X_test)))

对于二分类问题,反类始终是classes_属性的第一个元素,正类是第二个元素,因此,如果想要完全再现predict的输出,需要利用classes_属性:

python 复制代码
greater_zore=(gbrt.decision_function(X_test)>0).astype(int)
pred=gbrt.classes_[greater_zore]
print('索引是否与输出相同:{}'.format(np.all(pred==gbrt.predict(X_test))))

decidion_function可以在任意范围取值,取决于数据和参数模型:

python 复制代码
decision_function=gbrt.decision_function(X_test)
print('decision_function结果的最大值和最小值:{:.3f}、{:.3f}'.format(np.max(decision_function),np.min(decision_function)))

利用颜色编码画出所有点的decidion_function,还有决策边界:

python 复制代码
plt.rcParams['font.sans-serif']=['SimHei']
plt.rcParams['axes.unicode_minus'] = False
fig,axes=plt.subplots(1,2,figsize=(13,5))
mglearn.tools.plot_2d_separator(gbrt,X,ax=axes[0],alpha=.4,fill=True,cm=mglearn.cm2)
scores_image=mglearn.tools.plot_2d_scores(gbrt,X,ax=axes[1],alpha=.4,cm=mglearn.ReBl)
for ax in axes:
    mglearn.discrete_scatter(X_test[:, 0], X_test[:, 1], y_test, markers='^', ax=ax)
    mglearn.discrete_scatter(X_train[:, 0], X_train[:, 1], y_train, markers='o', ax=ax)
    ax.set_xlabel('特征0')
    ax.set_ylabel('特征1')
cbar=plt.colorbar(scores_image,ax=axes.tolist())
axes[0].legend(['测试分类0','测试分类1','训练分类0','训练分类1'],ncol=4,loc=(.1,1.1))
plt.show()
相关推荐
火兮明兮1 分钟前
Python训练第三十天
开发语言·python
是店小二呀2 分钟前
GPUGeek云平台实战:DeepSeek-R1-70B大语言模型一站式部署
人工智能·语言模型·自然语言处理·gpugeek平台
霜羽68922 分钟前
【数据结构篇】排序1(插入排序与选择排序)
数据结构·算法·排序算法
啊我不会诶4 分钟前
CF每日4题(1300-1400)
开发语言·c++·算法
jllllyuz7 分钟前
基于支持向量机(SVM)的P300检测分类
机器学习·支持向量机·分类
JK0x0711 分钟前
代码随想录算法训练营 Day51 图论Ⅱ岛屿问题Ⅰ
算法·深度优先·图论
freyazzr14 分钟前
Leetcode刷题 | Day64_图论09_dijkstra算法
数据结构·c++·算法·leetcode·图论
烦恼归林18 分钟前
永磁同步电机高性能控制算法(22)——基于神经网络的转矩脉动抑制算法&为什么低速时的转速波动大?
人工智能·神经网络·电机·电力电子·电机控制·simulink仿真
珊瑚里的鱼23 分钟前
【滑动窗口】LeetCode 1004题解 | 最大连续1的个数 Ⅲ
开发语言·c++·笔记·算法·leetcode
L_cl35 分钟前
【Python 算法零基础 4.排序 ② 冒泡排序】
数据结构·python·算法