Scikit-learn Python机器学习 - 模型保存及加载

锋哥原创的Scikit-learn Python机器学习视频教程:

https://www.bilibili.com/video/BV11reUzEEPH

课程介绍

本课程主要讲解基于Scikit-learn的Python机器学习知识,包括机器学习概述,特征工程(数据集,特征抽取,特征预处理,特征降维等),分类算法(K-临近算法,朴素贝叶斯算法,决策树等),回归与聚类算法(线性回归,欠拟合,逻辑回归与二分类,K-means算法)等。

Scikit-learn Python机器学习 - 模型保存及加载

在使用 scikit-learn 进行机器学习时,我们通常会需要将训练好的模型保存起来,便于后续的使用或部署。保存模型可以避免每次都重新训练模型,尤其是在模型训练时间较长或训练数据非常大的情况下。scikit-learn 提供了多种保存和加载模型的方法,最常见的是使用 joblibpickle 库。

我们通过joblib来实现模型保存和加载,核心是通过joblib的dump()方法来保存模型,以及通过load()方法来加载模型。

我们找一个前面的随机森林分类算法RandomForestClassifierTest.py来演示下模型保存及加载。

使用模型实例

复制代码
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
​
# 1,加载数据
iris = load_iris()
X = iris.data  # 特征矩阵 (150个样本,4个特征:萼长、萼宽、瓣长、瓣宽)
y = iris.target  # 特征值 目标向量 (3类鸢尾花:0, 1, 2)
​
# 2,数据预处理
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)  # 划分训练集和测试集
​
# 3,创建和训练模型
rfc_model = RandomForestClassifier(n_estimators=100,  # 100棵树
                                   oob_score=True,  # 启用OOB评估
                                   max_depth=3  # 控制树的深度,防止过拟合
                                   )
rfc_model.fit(X_train, y_train)  # 训练模型
​
# 4,进行预测并评估模型
y_pred = rfc_model.predict(X_test)  # 在测试集上进行预测
print('随机森林预测值:', y_pred)
print('正确值       :', y_test)
​
accuracy = accuracy_score(y_test, y_pred)  # 计算准确率
print(f'测试集准确率:{accuracy:.2f}')
print('分类报告:\n', classification_report(y_test, y_pred, target_names=iris.target_names))

模型保存dump

我们新建一个模型保存测试类saveModelTest.py

复制代码
import joblib
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
​
# 1,加载数据
iris = load_iris()
X = iris.data  # 特征矩阵 (150个样本,4个特征:萼长、萼宽、瓣长、瓣宽)
y = iris.target  # 特征值 目标向量 (3类鸢尾花:0, 1, 2)
​
# 2,数据预处理
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)  # 划分训练集和测试集
​
# 3,创建和训练模型
rfc_model = RandomForestClassifier(n_estimators=100,  # 100棵树
                                   oob_score=True,  # 启用OOB评估
                                   max_depth=3  # 控制树的深度,防止过拟合
                                   )
rfc_model.fit(X_train, y_train)  # 训练模型
​
# 保存模型
joblib.dump(rfc_model, 'rfc_model.pkl')
​
# 4,进行预测并评估模型
y_pred = rfc_model.predict(X_test)  # 在测试集上进行预测
print('随机森林预测值:', y_pred)
print('正确值       :', y_test)
​
accuracy = accuracy_score(y_test, y_pred)  # 计算准确率
print(f'测试集准确率:{accuracy:.2f}')
print('分类报告:\n', classification_report(y_test, y_pred, target_names=iris.target_names))

执行完,同级目录会生成一个pkl文件。

模型加载load

我们在新建一个模型加载测试类loadModelTest.py

复制代码
import joblib
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
​
# 1,加载数据
iris = load_iris()
X = iris.data  # 特征矩阵 (150个样本,4个特征:萼长、萼宽、瓣长、瓣宽)
y = iris.target  # 特征值 目标向量 (3类鸢尾花:0, 1, 2)
​
# 2,数据预处理
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)  # 划分训练集和测试集
​
# 加载模型
rfc_model = joblib.load('rfc_model.pkl')
​
# 4,进行预测并评估模型
y_pred = rfc_model.predict(X_test)  # 在测试集上进行预测
print('随机森林预测值:', y_pred)
print('正确值       :', y_test)
​
accuracy = accuracy_score(y_test, y_pred)  # 计算准确率
print(f'测试集准确率:{accuracy:.2f}')
print('分类报告:\n', classification_report(y_test, y_pred, target_names=iris.target_names))

这样就就能重复利用训练好模型。

我们运行测试下:

复制代码
随机森林预测值: [0 2 0 1 2 2 1 2 1 2 2 0 1 0 0 2 0 2 1 0 0 0 2 2 1 0 0 2 2 2]
正确值       : [0 2 0 2 2 2 1 2 1 2 2 0 1 0 0 2 0 2 1 0 0 0 2 1 1 0 0 2 2 2]
测试集准确率:0.93
分类报告:
               precision    recall  f1-score   support
​
      setosa       1.00      1.00      1.00        11
  versicolor       0.83      0.83      0.83         6
   virginica       0.92      0.92      0.92        13
​
    accuracy                           0.93        30
   macro avg       0.92      0.92      0.92        30
weighted avg       0.93      0.93      0.93        30
相关推荐
SteveRocket2 小时前
Python机器学习与数据分析教程之pandas
python·机器学习·数据分析
koo3643 小时前
李宏毅机器学习笔记32
人工智能·笔记·机器学习
材料科学研究3 小时前
机器学习催化剂设计!
深度学习·机器学习·orr·催化剂·催化剂设计·oer
材料科学研究3 小时前
机器学习锂离子电池!预估电池!
深度学习·机器学习·锂离子电池·电池·电池健康·电池管理·电池寿命
长桥夜波3 小时前
机器学习日报04
人工智能·机器学习
bulucc5 小时前
一个简答的意图识别Agent
python·大模型·agent
Lizhihao_5 小时前
Python如何写Selenium全攻略
开发语言·python
m0_738120726 小时前
网络安全编程——TCP客户端以及服务端Python实现
python·tcp/ip·安全·web安全·网络安全
AntBlack6 小时前
不当韭菜 : 好像真有点效果 ,想藏起来自己用了
前端·后端·python