机器学习:P-R曲线(Python)

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, label_binarize
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis


def data_preproce():
    """Load the digits dataset, add noise features, and build a train/test split.

    Steps:
      1. Load ``datasets.load_digits()`` and append ``10 * n_features`` columns
         of Gaussian noise drawn from ``np.random.RandomState(0)``.
      2. Binarize the targets with ``label_binarize`` (one indicator column
         per distinct label).
      3. Split 70% / 30% with ``random_state=0``, stratified on the binarized
         targets.
      4. Standardize the features with a ``StandardScaler`` fitted on the
         training partition only, then apply it to both partitions.

    Returns:
        tuple: ``(X_train, X_test, y_train, y_test)``.
    """
    digits = datasets.load_digits()
    X, y = digits.data, digits.target
    random_state = np.random.RandomState(0)
    n_example, n_features = X.shape
    # Append noise features so the classification task is harder.
    X = np.c_[X, random_state.randn(n_example, 10 * n_features)]
    # One indicator column per class, as required by the one-vs-rest scoring.
    y = label_binarize(y, classes=np.unique(y))
    # Split BEFORE scaling: fitting the scaler on the full data would leak
    # test-set mean/variance into the training features.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0,
                                                        shuffle=True, stratify=y)
    scaler = StandardScaler().fit(X_train)
    X_train = scaler.transform(X_train)
    X_test = scaler.transform(X_test)
    return X_train, X_test, y_train, y_test


def model_trian(model):
    """Fit ``model`` in a one-vs-rest wrapper and score the test features.

    The training and test data are NOT parameters: this function reads the
    module-level names ``X_train``, ``y_train`` and ``X_test``, which must be
    defined before it is called.

    Args:
        model: Base estimator handed to ``OneVsRestClassifier``.

    Returns:
        The result of ``decision_function(X_test)`` when the fitted wrapper
        exposes it, otherwise the result of ``predict_proba(X_test)``.

    Raises:
        AttributeError: If the wrapped estimator offers neither
            ``decision_function`` nor ``predict_proba``.
    """
    classifier = OneVsRestClassifier(model)
    classifier.fit(X_train, y_train)
    if hasattr(classifier, "decision_function"):
        return classifier.decision_function(X_test)
    # Estimators without a decision function (e.g. tree-based or naive Bayes
    # models) can still be ranked by their predicted class probabilities.
    return classifier.predict_proba(X_test)


def micor_PR(y_test, y_score):
    """Compute per-class and micro-averaged precision-recall data.

    Args:
        y_test: 2-D indicator array of true labels, indexed ``[:, class]``.
        y_score: 2-D array of scores with the same column layout; its second
            dimension gives the number of classes.

    Returns:
        tuple: ``(precision, recall, average_precision)`` — three dicts keyed
        by the integer class index plus the string key ``"micro"``.
        ``precision`` and ``recall`` hold the curves returned by
        ``precision_recall_curve``; ``average_precision`` holds the values
        returned by ``average_precision_score``
        (AP = sum_n (R_n - R_{n-1}) * P_n).
    """
    precision, recall, average_precision = {}, {}, {}

    # One curve and one AP value per class column.
    for k in range(y_score.shape[1]):
        truth_k = y_test[:, k]
        score_k = y_score[:, k]
        precision[k], recall[k], _ = precision_recall_curve(truth_k, score_k)
        average_precision[k] = average_precision_score(truth_k, score_k)

    # Micro-average: flatten every (sample, class) pair into one binary problem.
    micro_curve = precision_recall_curve(y_test.ravel(), y_score.ravel())
    precision["micro"], recall["micro"] = micro_curve[0], micro_curve[1]
    average_precision["micro"] = average_precision_score(y_test, y_score, average="micro")

    return precision, recall, average_precision

def plt_PR_curve(precision, recall, average_precision, label):
    """Draw the micro-averaged P-R curve as a step plot on the current axes.

    Args:
        precision: Dict holding the micro-averaged precision under ``"micro"``.
        recall: Dict holding the micro-averaged recall under ``"micro"``.
        average_precision: Dict holding the micro-averaged AP under ``"micro"``.
        label: Legend text prefix; the AP value (two decimals) is appended.
    """
    ap_suffix = ': AP={0:0.2f}'.format(average_precision["micro"])
    legend_text = label + ap_suffix
    plt.step(
        recall['micro'],
        precision['micro'],
        where='post',
        lw=2,
        label=legend_text,
    )

# The partitions are kept at module level because model_trian reads
# X_train, y_train and X_test as globals.
X_train, X_test, y_train, y_test = data_preproce()

# (legend label, base estimator) pairs, listed in plotting order.
candidates = (
    ("LogisticRegression", LogisticRegression()),
    ("svm.SVC", SVC()),
    ("LinearDiscriminantAnalysis", LinearDiscriminantAnalysis()),
)

plt.figure(figsize=(8, 6))
for name, estimator in candidates:
    # Train, score the test set, and add this model's micro-averaged curve.
    y_score = model_trian(estimator)
    precision, recall, average_precision = micor_PR(y_test, y_score)
    plt_PR_curve(precision, recall, average_precision, name)

# Decorate the shared axes: dashed diagonal, labels, limits, grid, legend.
ax = plt.gca()
ax.plot([0, 1], [0, 1], color='navy', linestyle='--')
ax.set_xlabel('Recall', fontsize=12)
ax.set_ylabel('Precision', fontsize=12)
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.grid()
ax.set_title('Average precision score, micro_averaged over all classes', fontsize=14)
ax.legend(fontsize=12)
plt.show()


from itertools import cycle
# setup plot details

# Retrain logistic regression and recompute the per-class / micro P-R data.
y_score = model_trian(LogisticRegression())
precision, recall, average_precision = micor_PR(y_test, y_score)
fig = plt.figure(figsize=(9, 8))

# (line handle, legend text) pairs, collected in drawing order.
legend_entries = []

# Iso-F1 contours: for a fixed F1 value f and recall x, precision is
# f * x / (2 * x - f). The recall grid is the same for every contour.
recall_grid = np.linspace(0.01, 1)
for f_score in np.linspace(0.2, 0.8, num=4):
    iso_precision = f_score * recall_grid / (2 * recall_grid - f_score)
    valid = iso_precision >= 0  # drop the branch where precision is negative
    iso_handle, = plt.plot(recall_grid[valid], iso_precision[valid],
                           color='gray', ls='--', alpha=0.5)
    # Label each contour near the right edge, just above the curve.
    plt.annotate('f1={0:0.1f}'.format(f_score), xy=(0.9, iso_precision[45] + 0.02))

# A single legend entry (the last contour drawn) stands for all iso-F1 lines.
legend_entries.append((iso_handle, 'iso-f1 curves'))

# Micro-averaged curve.
micro_handle, = plt.plot(recall["micro"], precision["micro"], ls="-.", lw=2)
legend_entries.append(
    (micro_handle, 'micro-average(area = {0:0.2f})'.format(average_precision["micro"]))
)

# One curve per class.
for cls in range(y_score.shape[1]):
    cls_handle, = plt.plot(recall[cls], precision[cls], lw=1.5)
    legend_entries.append(
        (cls_handle, 'class {0} (area = {1:0.2f})'.format(cls, average_precision[cls]))
    )

fig.subplots_adjust(bottom=0.25)
plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall', fontsize=12)
plt.ylabel('Precision', fontsize=12)
plt.title('Extension of Precision-Recall curve to multi-class', fontsize=14)

# Legend is anchored just outside the right edge of the axes.
lines = [handle for handle, _ in legend_entries]
labels = [text for _, text in legend_entries]
plt.legend(lines, labels, loc=(1.02, 0), prop=dict(size=12))

plt.tight_layout()
plt.show()
相关推荐
无垠的广袤 6 分钟前
【工业树莓派 CM0 NANO 单板计算机】本地部署 EMQX
linux·python·嵌入式硬件·物联网·树莓派·emqx·工业物联网
m0_37225702 1 小时前
ID3 算法为什么可以用来优化决策树
算法·决策树·机器学习
艾莉丝努力练剑 1 小时前
【Python基础:语法第一课】Python 基础语法详解:变量、类型、动态特性与运算符实战,构建完整的编程基础认知体系
大数据·人工智能·爬虫·python·pycharm·编辑器
gCode Teacher 格码致知 1 小时前
Python基础教学:如何拼接字符串?-由Deepseek产生
python
还债大湿兄 1 小时前
阿里通义千问调用图像大模型生成轮动漫风格 python调用
开发语言·前端·python
blank@l 1 小时前
python测开小工具--日志查询分析工具
python·python接口自动化测试基础·python测试开发·日志查询分析·日志分析统计查询·软件测试工具·argparse模块
Together_CZ 1 小时前
Cambrian-S: Towards Spatial Supersensing in Video——迈向视频中的空间超感知
人工智能·机器学习·音视频·spatial·cambrian-s·迈向视频中的空间超感知·supersensing
hu_nil 2 小时前
LLMOps-第十三周
python·vllm
空影星 2 小时前
轻量日记神器RedNotebook,高效记录每一天
python·数据挖掘·数据分析·音视频
搬砖ing换来金砖 2 小时前
Python入门-Task02
开发语言·python