Scikit-learn Python机器学习 - 模型保存及加载

锋哥原创的Scikit-learn Python机器学习视频教程:

https://www.bilibili.com/video/BV11reUzEEPH

课程介绍

本课程主要讲解基于Scikit-learn的Python机器学习知识,包括机器学习概述,特征工程(数据集,特征抽取,特征预处理,特征降维等),分类算法(K-临近算法,朴素贝叶斯算法,决策树等),回归与聚类算法(线性回归,欠拟合,逻辑回归与二分类,K-means算法)等。

Scikit-learn Python机器学习 - 模型保存及加载

在使用 scikit-learn 进行机器学习时,我们通常会需要将训练好的模型保存起来,便于后续的使用或部署。保存模型可以避免每次都重新训练模型,尤其是在模型训练时间较长或训练数据非常大的情况下。scikit-learn 提供了多种保存和加载模型的方法,最常见的是使用 joblibpickle 库。

我们通过joblib来实现模型保存和加载,核心是通过joblib的dump()方法来保存模型,以及通过load()方法来加载模型。

我们找一个前面的随机森林分类算法RandomForestClassifierTest.py来演示下模型保存及加载。

使用模型实例

复制代码
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
​
# 1,加载数据
iris = load_iris()
X = iris.data  # 特征矩阵 (150个样本,4个特征:萼长、萼宽、瓣长、瓣宽)
y = iris.target  # 特征值 目标向量 (3类鸢尾花:0, 1, 2)
​
# 2,数据预处理
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)  # 划分训练集和测试集
​
# 3,创建和训练模型
rfc_model = RandomForestClassifier(n_estimators=100,  # 100棵树
                                   oob_score=True,  # 启用OOB评估
                                   max_depth=3  # 控制树的深度,防止过拟合
                                   )
rfc_model.fit(X_train, y_train)  # 训练模型
​
# 4,进行预测并评估模型
y_pred = rfc_model.predict(X_test)  # 在测试集上进行预测
print('随机森林预测值:', y_pred)
print('正确值       :', y_test)
​
accuracy = accuracy_score(y_test, y_pred)  # 计算准确率
print(f'测试集准确率:{accuracy:.2f}')
print('分类报告:\n', classification_report(y_test, y_pred, target_names=iris.target_names))

模型保存dump

我们新建一个模型保存测试类saveModelTest.py

复制代码
import joblib
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
​
# 1,加载数据
iris = load_iris()
X = iris.data  # 特征矩阵 (150个样本,4个特征:萼长、萼宽、瓣长、瓣宽)
y = iris.target  # 特征值 目标向量 (3类鸢尾花:0, 1, 2)
​
# 2,数据预处理
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)  # 划分训练集和测试集
​
# 3,创建和训练模型
rfc_model = RandomForestClassifier(n_estimators=100,  # 100棵树
                                   oob_score=True,  # 启用OOB评估
                                   max_depth=3  # 控制树的深度,防止过拟合
                                   )
rfc_model.fit(X_train, y_train)  # 训练模型
​
# 保存模型
joblib.dump(rfc_model, 'rfc_model.pkl')
​
# 4,进行预测并评估模型
y_pred = rfc_model.predict(X_test)  # 在测试集上进行预测
print('随机森林预测值:', y_pred)
print('正确值       :', y_test)
​
accuracy = accuracy_score(y_test, y_pred)  # 计算准确率
print(f'测试集准确率:{accuracy:.2f}')
print('分类报告:\n', classification_report(y_test, y_pred, target_names=iris.target_names))

执行完,同级目录会生成一个pkl文件。

模型加载load

我们在新建一个模型加载测试类loadModelTest.py

复制代码
import joblib
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
​
# 1,加载数据
iris = load_iris()
X = iris.data  # 特征矩阵 (150个样本,4个特征:萼长、萼宽、瓣长、瓣宽)
y = iris.target  # 特征值 目标向量 (3类鸢尾花:0, 1, 2)
​
# 2,数据预处理
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)  # 划分训练集和测试集
​
# 加载模型
rfc_model = joblib.load('rfc_model.pkl')
​
# 4,进行预测并评估模型
y_pred = rfc_model.predict(X_test)  # 在测试集上进行预测
print('随机森林预测值:', y_pred)
print('正确值       :', y_test)
​
accuracy = accuracy_score(y_test, y_pred)  # 计算准确率
print(f'测试集准确率:{accuracy:.2f}')
print('分类报告:\n', classification_report(y_test, y_pred, target_names=iris.target_names))

这样就就能重复利用训练好模型。

我们运行测试下:

复制代码
随机森林预测值: [0 2 0 1 2 2 1 2 1 2 2 0 1 0 0 2 0 2 1 0 0 0 2 2 1 0 0 2 2 2]
正确值       : [0 2 0 2 2 2 1 2 1 2 2 0 1 0 0 2 0 2 1 0 0 0 2 1 1 0 0 2 2 2]
测试集准确率:0.93
分类报告:
               precision    recall  f1-score   support
​
      setosa       1.00      1.00      1.00        11
  versicolor       0.83      0.83      0.83         6
   virginica       0.92      0.92      0.92        13
​
    accuracy                           0.93        30
   macro avg       0.92      0.92      0.92        30
weighted avg       0.93      0.93      0.93        30
相关推荐
睿思达DBA_WGX2 小时前
使用 python-docx 库操作 word 文档(1):文件操作
开发语言·python·word
jackylzh4 小时前
深度学习中, WIN32为 Windows API 标识,匹配 Windows 系统,含 32/64 位
人工智能·python·深度学习
LateFrames5 小时前
用 【C# + Winform + MediaPipe】 实现人脸468点识别
python·c#·.net·mediapipe
人工干智能8 小时前
科普:Python 中,字典的“动态创建键”特性
开发语言·python
开心-开心急了11 小时前
主窗口(QMainWindow)如何放入文本编辑器(QPlainTextEdit)等继承自QWidget的对象--(重构版)
python·ui·pyqt
moshumu112 小时前
局域网访问Win11下的WSL中的jupyter notebook
ide·python·深度学习·神经网络·机器学习·jupyter
大饼酥12 小时前
吴恩达机器学习笔记(10)—支持向量机
机器学习·支持向量机·吴恩达·高斯核函数
计算机毕设残哥13 小时前
基于Hadoop+Spark的人体体能数据分析与可视化系统开源实现
大数据·hadoop·python·scrapy·数据分析·spark·dash
芒果量化13 小时前
ML4T - 第7章第8节 利用LR预测股票价格走势Predicting stock price moves with Logistic Regression
算法·机器学习·线性回归