【机器学习】鸢尾花分类-逻辑回归示例

复制代码
这段代码是一个完整的示例,展示了如何使用逻辑回归对鸢尾花数据集进行训练、保存模型,并允许用户输入数据进行预测。以下是对这段代码的总结:

功能: 这段代码演示了如何使用逻辑回归对鸢尾花数据集进行训练,并将训练好的模型保存到文件中。然后,它允许用户输入新的鸢尾花特征数据,使用保存的模型进行预测,并输出预测结果。

步骤概述:

  1. 加载数据和预处理: 使用 Scikit-Learn 中的 datasets 模块加载鸢尾花数据集,并提取前两个特征。然后,划分数据集为训练集和测试集,并对特征数据进行标准化处理。

  2. 训练和保存模型: 创建逻辑回归模型,并在训练集上训练模型。训练完成后,使用 joblib 库将训练好的模型保存到文件中。

  3. 预测: 使用保存的模型,接受用户输入的鸢尾花特征数据(花萼长度和花萼宽度),将其组织成特征向量,然后进行预测。

  4. 结果输出: 根据预测结果输出对应的分类标签,指示预测的鸢尾花属于 Setosa 类别还是非 Setosa 类别(Versicolor 或 Virginica)。

使用方法: 运行代码后,它会首先训练模型并将其保存。然后,你可以输入新的鸢尾花特征数据以进行预测,系统将输出预测结果。

注意事项: 这个示例使用了 joblib 库来保存和加载模型,你也可以使用其他库如 pickle。此外,这个示例演示了逻辑回归在一个简单数据集上的应用,实际应用中可能需要更多的数据处理、模型调优和评估步骤。

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
import joblib  # 用于保存和加载模型
def train_logistic_regression():
    # 加载鸢尾花数据集
    iris = datasets.load_iris()
    # 只使用前两个特征以方便可视化
    X = iris.data[:, :2]  
    # 将标签转换为二分类问题
    y = (iris.target != 0).astype(int)  

    # 划分数据集为训练集和测试集
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # 特征标准化
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # 创建逻辑回归模型
    model = LogisticRegression()

    # 在训练集上训练模型
    model.fit(X_train, y_train)

    # 保存训练好的模型
    joblib.dump(model, 'logistic_regression_model.pkl')

def predict_with_saved_model():
    # 加载保存的模型
    model = joblib.load('logistic_regression_model.pkl')

    # 用户输入特征数据
    sepal_length = float(input("Enter sepal length: "))
    sepal_width = float(input("Enter sepal width: "))
    input_data = np.array([[sepal_length, sepal_width]])

    # 进行预测
    prediction = model.predict(input_data)

    if prediction[0] == 0:
        print("Predicted class: Setosa")
    else:
        print("Predicted class: Non-Setosa (Versicolor or Virginica)")

# 训练模型并保存
train_logistic_regression()

# 使用保存的模型进行预测
predict_with_saved_model()
复制代码
输出结果:

Enter sepal length: 5

Enter sepal width: 7

Predicted class: Non-Setosa (Versicolor or Virginica)

复制代码
备注

在这个示例中,sepal length(花萼长度)和 sepal width(花萼宽度)是用于输入的特征。这些特征是鸢尾花数据集中的两个测量值。这些测量值的单位是厘米(cm)。

对于鸢尾花数据集中的这两个特征,以下是一些参考值范围:

sepal length: 大约为 4.3 至 7.9 厘米。

sepal width: 大约为 2.0 至 4.4 厘米。

请注意,这些参考值是基于鸢尾花数据集的统计信息,并且会根据具体数据而有所变化。当你输入新的花萼长度和花萼宽度值进行预测时,请确保输入的值在合理的范围内。

复制代码
相关推荐
夜松云1 小时前
从对数变换到深度框架:逻辑回归与交叉熵的数学原理及PyTorch实战
pytorch·算法·逻辑回归·梯度下降·交叉熵·对数变换·sigmoid函数
Blossom.1181 小时前
可解释人工智能(XAI):让机器决策透明化
人工智能·驱动开发·深度学习·目标检测·机器学习·aigc·硬件架构
Robot2511 小时前
「地平线」创始人余凯:自动驾驶尚未成熟,人形机器人更无从谈起
人工智能·科技·机器学习·机器人·自动驾驶
深蓝学院2 小时前
开源|上海AILab:自动驾驶仿真平台LimSim Series,兼容端到端/知识驱动/模块化技术路线
人工智能·机器学习·自动驾驶
一点.点2 小时前
LLM应用于自动驾驶方向相关论文整理(大模型在自动驾驶方向的相关研究)
人工智能·深度学习·机器学习·语言模型·自动驾驶·端到端大模型
云天徽上2 小时前
【数据可视化-41】15年NVDA, AAPL, MSFT, GOOGL & AMZ股票数据集可视化分析
人工智能·机器学习·信息可视化·数据挖掘·数据分析
云天徽上3 小时前
【数据可视化-42】杂货库存数据集可视化分析
人工智能·机器学习·信息可视化·数据挖掘·数据分析
自由随风飘3 小时前
机器学习第三篇 模型评估(交叉验证)
人工智能·机器学习
硅谷秋水4 小时前
MANIPTRANS:通过残差学习实现高效的灵巧双手操作迁移
人工智能·深度学习·机器学习·计算机视觉
盼小辉丶6 小时前
PyTorch生成式人工智能实战(3)——分类任务详解
人工智能·pytorch·分类