用sklearn运行分类模型,选择AUC最高的模型保存模型权重并绘制AUCROC曲线(以逻辑回归、随机森林、梯度提升、MLP为例)

诸神缄默不语-个人CSDN博文目录

文章目录

  • [1. 导入包](#1. 导入包)
  • [2. 初始化分类模型](#2. 初始化分类模型)
  • [3. 训练、测试模型,绘图,保存指标](#3. 训练、测试模型,绘图,保存指标)

1. 导入包

python 复制代码
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.neural_network import MLPClassifier
from sklearn.metrics import roc_auc_score,accuracy_score,roc_curve,auc
import joblib
import matplotlib.pyplot as plt

2. 初始化分类模型

python 复制代码
classifiers = {
    "Logistic Regression": LogisticRegression(),
    "Random Forest": RandomForestClassifier(),
    "GBDT": GradientBoostingClassifier(),
    "MLP": MLPClassifier(max_iter=1000)
}

3. 训练、测试模型,绘图,保存指标

在这里省略了数据处理部分,总之X/Y都是np.ndarray对象。f反正你创建一个可写的文件流就行,如果连这个都不会的话参考我写的这篇博文:Python3对象序列化,即处理JSON、XML和文件(持续更新ing...)
f.close()没写,根据你的需要如果想加就加。

这个逻辑是每次得到AUC最高的模型就画图,其实感觉把模型权重储存下来然后再joblib.load()再画图会更合适......

如果想对每个模型画ROC曲线叠在一张图上的话,在最前面新建画布(plt.figure()),每个模型运行完后都运行一次plt.plot(),不close()就行。

python 复制代码
max_auc = 0
max_acc = 0
best_classifier = ""
# 训练模型
for lr_name, lr in classifiers.items():
    lr.fit(X_train, y_train)

    # 预测
    y_pred = lr.predict(X_test)
    y_pred_proba = lr.predict_proba(X_test)[:, 1]

    # 评估
    auc_score = roc_auc_score(y_test, y_pred_proba)
    acc = accuracy_score(y_test, y_pred)

    if auc_score > max_auc:
        max_auc = auc_score
        max_acc = acc
        best_classifier = lr_name
        joblib.dump(lr, f"model.pkl")

        fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)
        roc_auc = auc(fpr, tpr)
        plt.figure()
        plt.plot(
            fpr,
            tpr,
            color="darkorange",
            lw=2,
            label=f"ROC curve (AUC = {roc_auc:.2f})",
        )
        plt.plot(
            [0, 1], [0, 1], color="navy", lw=2, linestyle="--"
        )  # 随机猜测基线
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title("Receiver Operating Characteristic")
        plt.legend(loc="lower right")
        plt.grid()
        plt.savefig("roc.png")
        plt.close()

    f.write(
        f"{lr_name} AUC: {auc_score:.4f}, ACC: {acc:.4f}"
        + "\n"
    )
    f.flush()

f.write(
    f"best_classifier: {best_classifier} AUC: {max_auc:.4f}, ACC: {max_acc:.4f}"
    + "\n"
)
f.flush()
相关推荐
wuxuand10 小时前
2026时序分类综述A Comprehensive Review of Time Series Classification
人工智能·深度学习·分类·数据挖掘
Fleshy数模11 小时前
基于PyTorch实现食物图像分类:从数据加载到CNN训练全流程
pytorch·分类·cnn
bulingg13 小时前
LR逻辑回归详解
算法·机器学习·逻辑回归
哈哈很哈哈1 天前
逻辑回归Logistic Regression
算法·机器学习·逻辑回归
甄心爱学习1 天前
【极大似然估计/最大化后验】为什么逻辑回归要使用交叉熵损失函数
算法·机器学习·逻辑回归
简简单单做算法1 天前
基于WOA鲸鱼优化的LSTM长短记忆网络模型的文本分类算法matlab仿真
人工智能·分类·lstm·文本分类·woa鲸鱼优化·woa-lstm
放下华子我只抽RuiKe51 天前
机器学习全景指南-进阶篇——解决分类问题的逻辑回归
人工智能·机器学习·分类·逻辑回归·文心一言·ai编程·智能体
啊阿狸不会拉杆2 天前
《计算机视觉:模型、学习和推理》第 18 章-身份与方式模型
人工智能·python·学习·计算机视觉·分类·子空间身份模型·plda
咚咚王者2 天前
人工智能之语言领域 自然语言处理 第五章 文本分类
人工智能·自然语言处理·分类
FriendshipT2 天前
YOLOs-CPP:一个免费开源的YOLO全系列C++推理库(以YOLO26为例)
c++·人工智能·yolo·目标检测·分类·开源