代码实现:
# !/usr/bin/python3
# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix,classification_report # 混淆矩阵
from sklearn.metrics import precision_recall_curve,roc_auc_score,roc_curve,auc # ROC曲线等
import matplotlib.pyplot as plt
'''
1. 准备数据,划分训练集和测试集
'''
# 生成一个5000样本量,30个特征,3个分类的数据集
X,y = make_classification(n_samples=5000,n_features=30,n_classes=2,n_informative=2,random_state=10)
'''
make_classification 参数:
n_samples:生成的样本数量,默认为100。
n_features:特征数量,默认为20。
n_informative:信息性特征数量,默认为2。这些特征与输出类别有关。
n_redundant:冗余特征数量,默认为2。这些特征是信息性特征的线性组合。
n_repeated:重复特征数量,默认为0。这些特征是从其他特征中复制的。
n_classes:类别数量,默认为2。
n_clusters_per_class:每个类别中的簇数量,默认为2。
weights:每个类别的样本权重,默认为None。
flip_y:标签翻转概率,默认为0.01,用于增加噪声。
class_sep:类间分离因子,默认为1.0。值越大,类分离越明显。
hypercube:布尔值,指定特征是否在超立方体中生成,默认为True。
shift和scale:用于特征的偏移和平移。
shuffle:布尔值,指定生成数据后是否打乱数据,默认为True。
random_state:随机数生成器的状态或种子,用于确保数据可重复。
返回值包括两个数组:X(形状为[n_samples, n_features]的特征矩阵)和y(形状为[n_samples]的目标向量
'''
# 将数据集划分为训练集和测试集
Xtrain,Xtest,Ytrain,Ytest = train_test_split(X,y,test_size=0.3,random_state=1)
'''
2. 建立模型&评估模型
'''
k_values = [1,3,5,7]
for k in k_values:
clf = KNeighborsClassifier(n_neighbors=k) # 实例化模型
clf = clf.fit(Xtrain,Ytrain,) # 使用训练集训练模型
score = clf.score(Xtest,Ytest) # 看模型在新数据集(测试集)上的预测效果
print(score) # 准确率
# 看测试集上的获得的预测概率
y_prob_1 = clf.predict_proba(Xtest)[:,1]
# print(y_prob_1)
'''
predict_proba返回的是一个n行k列的数组,其中每一行代表一个测试样本,每一列代表一个类别。
例如,对于二分类问题,返回的数组有两列,第一列表示属于第一个类别的概率,第二列表示属于第二个类别的概率。
'''
# print(y_prob_0)
# print(y_prob_1)
# print(y_prob_2)
# print(y_prob)
'''
3. 绘制ROC曲线,PR曲线
'''
# 假正率(FPR)、真正率(TPR)和阈值(thresholds)
'''
假正率(FPR):假正率表示在实际为负例的样本中,被模型错误预测为正例的比例。
真正率(TPR):真正率也称为灵敏度(Sensitivity)或召回率(Recall),它表示在实际为正例的样本中,被模型正确预测为正例的比例。
'''
FPR, TPR, thresholds = roc_curve(Ytest,y_prob_1)
'''
ROC曲线:ROC曲线主要用于衡量二分类器的性能。它以假正率(FPR)为横坐标,真正率(TPR)为纵坐标,绘制出分类器的性能曲线。
ROC曲线越靠近左上角,表示分类器的性能越好
ROC曲线越靠近左上角(0, 1)点,说明分类器的性能越好。
AUC(Area Under the Curve)是ROC 曲线下方的面积,范围在0到1之间,可以理解为模型正确区分正例和反例的能力。
一个完美的分类器的AUC值为1,而一个随机猜测的分类器的AUC值为0.5。
'''
# 计算ROC曲线的参数
ROC_AUC = auc(FPR, TPR)
# 绘制ROC曲线
plt.subplot(1, 2, 2)
plt.plot(FPR, TPR, lw=2, label=f'k={k}, AUC = {ROC_AUC:.2f}')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.legend(loc="lower right")
plt.grid(True)
plt.title('ROC')
plt.xlabel('FPR')
plt.ylabel('TPR')
# plt.show()
'''
PR 曲线,全称为 Precision - Recall 曲线,是用于评估分类模型性能的重要工具
定义与原理:PR 曲线通过绘制精度(Precision)与召回率(Recall)之间的关系曲线,来展示模型在不同阈值下的表现。
精度表示在所有被预测为正类的样本中实际为正类的比例,召回率表示在所有实际为正类的样本中被正确预测为正类的比例。
通过改变分类阈值,可以得到一系列不同的精度和召回率值,将这些值绘制成曲线,就得到了 PR 曲线。
绘制方法:首先计算模型在不同阈值下的精度和召回率。然后,以召回率为横坐标,精度为纵坐标,将各个阈值下对应的点连接起来,形成 PR 曲线。
PR曲线越靠近右上角(1, 1)点,说明分类器的性能越好。
'''
# 绘制PR曲线
# 绘制不同K值的kNN分类器在测试集上的PR曲线,并计算对应的AUC值。
# precision_recall_curve 返回结果依次是precision(精确度)、recall(召回率)和 thresholds(阈值)
precision, recall, thresholds = precision_recall_curve(Ytest,y_prob_1)
pr_auc = auc(recall, precision)
# 绘制PR曲线
plt.subplot(1, 2, 1)
plt.plot(recall, precision, lw=2, label=f'k={k}, AUC = {pr_auc:.2f}')
plt.legend(loc="lower left")
plt.grid(True)
plt.title('PR')
plt.xlabel('recall')
plt.ylabel('precision')
plt.subplots_adjust(hspace=0.3, wspace=0.3) # 调整间距
plt.show()
'''
4. 确定K值,确定最终模型
'''
clf = KNeighborsClassifier(n_neighbors=7) # 实例化模型
clf = clf.fit(Xtrain, Ytrain, ) # 使用训练集训练模型
score = clf.score(Xtest, Ytest) # 看模型在新数据集(测试集)上的预测效果
print(score) # 准确率
print("训练集上的预测准确率为:", clf.score(Xtrain, Ytrain))
print("测试集上的预测准确率为:", clf.score(Xtest, Ytest))
print('混淆矩阵:',confusion_matrix(Ytest, clf.predict(Xtest)))
'''
混淆矩阵是机器学习中用于评估分类模型性能的重要工具。它通过表格形式直观展示模型预测结果与真实标签的对比,帮助分析分类错误的具体类型。
以下是一个二分类混淆矩阵的表格:
实际为正例 实际为反例
预测为正例 真正例(TP) 假正例(FP)
预测为反例 真反例(TN) 真反例(FN)
真正例(True Positive,TP):正确预测为正类的样本数。
假正例(False Positive,FP):实际为负类但被错误预测为正类的样本数。
假反例(False Negative,FN):实际为正类但被错误预测为负类的样本数。
真反例(True Negative,TN):正确预测为负类的样本数。
通过混淆矩阵,我们可以求得准确率、精确率、召回率等性能指标
'''
print('AUC值:',roc_auc_score(Ytest, clf.predict(Xtest)))
print('整体情况:',classification_report(Ytest, clf.predict(Xtest)))