机器学习:P-R曲线(Python)

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, label_binarize
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis


def data_preproce(noise_ratio=10, test_size=0.3, random_state=0):
    """Load the digits dataset, append noise features, binarize labels and split.

    The scaler is fitted on the training split only and then applied to both
    splits, so no statistics from the test set leak into training.

    Args:
        noise_ratio: number of Gaussian noise columns to append, as a multiple
            of the original feature count (the original script used 10).
        test_size: fraction of samples held out for testing.
        random_state: seed for both the noise generator and the split.

    Returns:
        ``(X_train, X_test, y_train, y_test)``. The y arrays are the
        ``label_binarize`` indicator matrices (one column per class).
    """
    # Load the data
    digits = datasets.load_digits()
    X, y = digits.data, digits.target
    rng = np.random.RandomState(random_state)
    n_samples, n_features = X.shape
    # Append pure-noise columns to make the classification task harder.
    X = np.c_[X, rng.randn(n_samples, noise_ratio * n_features)]
    y = label_binarize(y, classes=np.unique(y))  # one-hot indicator matrix
    # Split BEFORE scaling. Each one-hot row identifies one class, so stratifying
    # on the rows keeps every class's proportion equal in both splits.
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=test_size, random_state=random_state,
        shuffle=True, stratify=y)
    # Fit the standardizer on training data only to avoid test-set leakage.
    scaler = StandardScaler().fit(X_train)
    return scaler.transform(X_train), scaler.transform(X_test), y_train, y_test


def model_trian(model, *, X_fit=None, y_fit=None, X_eval=None):
    """Train ``model`` in a one-vs-rest scheme and score the evaluation set.

    Args:
        model: an unfitted scikit-learn estimator. It must provide
            ``decision_function``, because the scores are read from it.
        X_fit: training features. Defaults to the module-level ``X_train``.
        y_fit: binarized training labels. Defaults to the module-level ``y_train``.
        X_eval: features to score. Defaults to the module-level ``X_test``.

    Returns:
        The output of ``OneVsRestClassifier.decision_function`` on ``X_eval``.
        Callers in this script index it as one column per class.
    """
    # The defaults keep the original call style ``model_trian(model)`` working.
    # Passing the data explicitly removes the hidden dependency on globals;
    # a global is only read when its argument is omitted.
    X_fit = X_train if X_fit is None else X_fit
    y_fit = y_train if y_fit is None else y_fit
    X_eval = X_test if X_eval is None else X_eval
    classifier = OneVsRestClassifier(model)
    classifier.fit(X_fit, y_fit)
    return classifier.decision_function(X_eval)


def micor_PR(y_test, y_score):
    """Compute per-class and micro-averaged precision-recall data.

    Args:
        y_test: indicator matrix of true labels, indexed as one column per
            class (the script passes the ``label_binarize`` output).
        y_score: decision scores indexed the same way as ``y_test``.

    Returns:
        ``(precision, recall, average_precision)``: three dicts keyed by each
        class index ``0 .. n_classes-1`` followed by the key ``"micro"``.
    """
    precision, recall, average_precision = {}, {}, {}
    for cls in range(y_score.shape[1]):
        truth, score = y_test[:, cls], y_score[:, cls]
        # Curve points over every distinct score threshold (thresholds discarded).
        precision[cls], recall[cls], _ = precision_recall_curve(truth, score)
        # AP = sum_n (R_n - R_{n-1}) * P_n
        average_precision[cls] = average_precision_score(truth, score)
    # Micro-average: every (sample, class) cell counts as one binary decision.
    flat_truth, flat_score = y_test.ravel(), y_score.ravel()
    precision["micro"], recall["micro"], _ = precision_recall_curve(flat_truth, flat_score)
    average_precision["micro"] = average_precision_score(y_test, y_score, average="micro")
    return precision, recall, average_precision

def plt_PR_curve(precision, recall, average_precision, label):
    """Draw the micro-averaged P-R curve as a step plot on the current axes.

    The legend entry is ``label`` followed by the micro-averaged AP rounded
    to two decimals.
    """
    suffix = ': AP={0:0.2f}'.format(average_precision["micro"])
    # where='post': each precision value is held until the next recall point.
    plt.step(recall['micro'], precision['micro'],
             lw=2, where='post', label=label + suffix)

X_train, X_test, y_train, y_test = data_preproce()

# Compare the micro-averaged P-R curves of several one-vs-rest classifiers
# on a single figure; one loop instead of three copy-pasted sections.
plt.figure(figsize=(8, 6))
for model, name in ((LogisticRegression(), "LogisticRegression"),
                    (SVC(), "svm.SVC"),
                    (LinearDiscriminantAnalysis(), "LinearDiscriminantAnalysis")):
    y_score = model_trian(model)
    precision, recall, average_precision = micor_PR(y_test, y_score)
    plt_PR_curve(precision, recall, average_precision, name)

# Dashed diagonal where precision equals recall; a P-R curve crosses it at
# its break-even point.
plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
plt.xlabel('Recall', fontsize=12)
plt.ylabel('Precision', fontsize=12)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.grid()
plt.title('Average precision score, micro_averaged over all classes', fontsize=14)
plt.legend(fontsize=12)
plt.show()


from itertools import cycle
# setup plot details

# Per-class and micro-averaged P-R curves of LogisticRegression, drawn over
# iso-F1 contour lines.
y_score = model_trian(LogisticRegression())
precision, recall, average_precision = micor_PR(y_test, y_score)
plt.figure(figsize=(9, 8))
f_scores = np.linspace(0.2, 0.8, num=4)  # contour levels 0.2, 0.4, 0.6, 0.8
lines, labels = [], []  # legend handles and their labels

# Iso-F1 contours: solving F1 = 2PR / (P + R) for P gives P = F1 * R / (2R - F1).
annot_recall = 0.9  # recall at which each contour's label is placed
for f_score in f_scores:
    x = np.linspace(0.01, 1)
    y = f_score * x / (2 * x - f_score)
    # Where 2R < F1 the formula yields negative precision; drop those points.
    # alpha sets transparency in [0, 1]; smaller is more transparent.
    line, = plt.plot(x[y >= 0], y[y >= 0], color='gray', ls='--', alpha=0.5)
    # Evaluate the contour exactly at annot_recall rather than indexing into
    # the sampled x, so the label sits just above the curve at that recall.
    annot_precision = f_score * annot_recall / (2 * annot_recall - f_score)
    plt.annotate('f1={0:0.1f}'.format(f_score), xy=(annot_recall, annot_precision + 0.02))

# All iso-F1 contours share one style, so a single legend entry (the last
# contour's handle) stands for all of them.
lines.append(line)
labels.append('iso-f1 curves')
line, = plt.plot(recall["micro"], precision["micro"], ls="-.", lw=2)
lines.append(line)
labels.append('micro-average(area = {0:0.2f})'.format(average_precision["micro"]))

for i in range(y_score.shape[1]):
    line, = plt.plot(recall[i], precision[i], lw=1.5)
    lines.append(line)
    labels.append('class {0} (area = {1:0.2f})'.format(i, average_precision[i]))

# Dashed diagonal where precision equals recall; a P-R curve crosses it at
# its break-even point.
plt.plot([0, 1], [0, 1], color='navy', linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall', fontsize=12)
plt.ylabel('Precision', fontsize=12)
plt.title('Extension of Precision-Recall curve to multi-class', fontsize=14)
# Legend anchored to the right of the axes at axes coordinates (1.02, 0).
plt.legend(lines, labels, loc=(1.02, 0), prop=dict(size=12))

# tight_layout recomputes every subplot margin, so a manual subplots_adjust
# beforehand would have no effect.
plt.tight_layout()
plt.show()
相关推荐
江_小_白40 分钟前
自动驾驶之激光雷达
人工智能·机器学习·自动驾驶
湫ccc3 小时前
《Python基础》之字符串格式化输出
开发语言·python
mqiqe4 小时前
Python MySQL通过Binlog 获取变更记录 恢复数据
开发语言·python·mysql
AttackingLin4 小时前
2024强网杯--babyheap house of apple2解法
linux·开发语言·python
哭泣的眼泪4084 小时前
解析粗糙度仪在工业制造及材料科学和建筑工程领域的重要性
python·算法·django·virtualenv·pygame
湫ccc5 小时前
《Python基础》之基本数据类型
开发语言·python
IT古董5 小时前
【机器学习】机器学习中用到的高等数学知识-8. 图论 (Graph Theory)
人工智能·机器学习·图论
drebander6 小时前
使用 Java Stream 优雅实现List 转化为Map<key,Map<key,value>>
java·python·list
莫叫石榴姐6 小时前
数据科学与SQL:组距分组分析 | 区间分布问题
大数据·人工智能·sql·深度学习·算法·机器学习·数据挖掘
威威猫的栗子6 小时前
Python Turtle召唤童年:喜羊羊与灰太狼之懒羊羊绘画
开发语言·python