这段代码是一个完整的示例,展示了如何使用逻辑回归对鸢尾花数据集进行训练、保存模型,并允许用户输入数据进行预测。以下是对这段代码的总结:
功能: 这段代码演示了如何使用逻辑回归对鸢尾花数据集进行训练,并将训练好的模型保存到文件中。然后,它允许用户输入新的鸢尾花特征数据,使用保存的模型进行预测,并输出预测结果。
步骤概述:
-
加载数据和预处理: 使用 Scikit-Learn 中的
datasets
模块加载鸢尾花数据集,并提取前两个特征。然后,划分数据集为训练集和测试集,并对特征数据进行标准化处理。 -
训练和保存模型: 创建逻辑回归模型,并在训练集上训练模型。训练完成后,使用
joblib
库将训练好的模型保存到文件中。 -
预测: 使用保存的模型,接受用户输入的鸢尾花特征数据(花萼长度和花萼宽度),将其组织成特征向量,然后进行预测。
-
结果输出: 根据预测结果输出对应的分类标签,指示预测的鸢尾花属于 Setosa 类别还是非 Setosa 类别(Versicolor 或 Virginica)。
使用方法: 运行代码后,它会首先训练模型并将其保存。然后,你可以输入新的鸢尾花特征数据以进行预测,系统将输出预测结果。
注意事项: 这个示例使用了 joblib
库来保存和加载模型,你也可以使用其他库如 pickle
。此外,这个示例演示了逻辑回归在一个简单数据集上的应用,实际应用中可能需要更多的数据处理、模型调优和评估步骤。
python
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import joblib # 用于保存和加载模型
def train_logistic_regression():
# 加载鸢尾花数据集
iris = datasets.load_iris()
# 只使用前两个特征以方便可视化
X = iris.data[:, :2]
# 将标签转换为二分类问题
y = (iris.target != 0).astype(int)
# 划分数据集为训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# 特征标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
# 创建逻辑回归模型
model = LogisticRegression()
# 在训练集上训练模型
model.fit(X_train, y_train)
# 保存训练好的模型
joblib.dump(model, 'logistic_regression_model.pkl')
def predict_with_saved_model():
# 加载保存的模型
model = joblib.load('logistic_regression_model.pkl')
# 用户输入特征数据
sepal_length = float(input("Enter sepal length: "))
sepal_width = float(input("Enter sepal width: "))
input_data = np.array([[sepal_length, sepal_width]])
# 进行预测
prediction = model.predict(input_data)
if prediction[0] == 0:
print("Predicted class: Setosa")
else:
print("Predicted class: Non-Setosa (Versicolor or Virginica)")
# 训练模型并保存
train_logistic_regression()
# 使用保存的模型进行预测
predict_with_saved_model()
输出结果:
Enter sepal length: 5
Enter sepal width: 7
Predicted class: Non-Setosa (Versicolor or Virginica)
备注
在这个示例中,sepal length(花萼长度)和 sepal width(花萼宽度)是用于输入的特征。这些特征是鸢尾花数据集中的两个测量值。这些测量值的单位是厘米(cm)。
对于鸢尾花数据集中的这两个特征,以下是一些参考值范围:
sepal length: 大约为 4.3 至 7.9 厘米。
sepal width: 大约为 2.0 至 4.4 厘米。
请注意,这些参考值是基于鸢尾花数据集的统计信息,并且会根据具体数据而有所变化。当你输入新的花萼长度和花萼宽度值进行预测时,请确保输入的值在合理的范围内。