基于 sklearn 工具和鸢尾花数据集,进行逻辑回归实战

目录

一、前言

二、什么是鸢尾花数据集

三、项目整体流程

四、安装依赖

五、加载鸢尾花数据集

六、划分训练集与测试集

七、数据标准化

八、创建逻辑回归模型

九、模型预测

十、预测概率分析

十一、模型准确率评估

十二、混淆矩阵分析

十三、分类报告分析

Precision

Recall

[F1 Score](#F1 Score)

十四、查看模型参数

十五、完整项目代码

十六、项目执行流程总结

十七、面试高频问题

为什么逻辑回归叫回归却是分类算法?

为什么要标准化数据?

predict和predict_proba区别?

LogisticRegression中的max_iter是什么?

Iris数据集有几个类别?

十八、总结


一、前言

在机器学习入门阶段,逻辑回归(Logistic Regression)几乎是每个开发者都会接触到的第一个分类算法。

虽然名字中带有:

复制代码
Regression(回归)

但实际上:

复制代码
逻辑回归是分类算法

它广泛应用于:

复制代码
垃圾邮件识别

用户流失预测

疾病诊断

广告点击预测

金融风控

等场景。

本文将通过 Python 中最常用的机器学习框架:

复制代码
Scikit-Learn(sklearn)

结合经典的:

复制代码
Iris 鸢尾花数据集

完成一个完整的逻辑回归分类实战。

通过本文你将掌握:

复制代码
sklearn数据集加载

数据集分析

训练集测试集划分

逻辑回归模型训练

模型预测

模型评估

混淆矩阵分析

分类报告分析

二、什么是鸢尾花数据集

Iris 数据集是机器学习领域最经典的数据集之一。

数据来源:

复制代码
1936年

英国统计学家 Ronald Fisher

收集整理。

数据集包含:

复制代码
150条数据

共分为三类鸢尾花:

类别 编号
Setosa 0
Versicolor 1
Virginica 2

每种花:

复制代码
50条样本

包含4个特征:

特征 含义
sepal length 萼片长度
sepal width 萼片宽度
petal length 花瓣长度
petal width 花瓣宽度

目标:

复制代码
根据花朵特征

预测花朵种类

三、项目整体流程

本次实战流程如下:

这是机器学习项目最经典的开发流程。


四、安装依赖

安装 sklearn:

复制代码
pip install scikit-learn

安装完成后测试:

python 复制代码
import sklearn

print(sklearn.__version__)

五、加载鸢尾花数据集

代码如下:

python 复制代码
from sklearn.datasets import load_iris

iris = load_iris()

print(type(iris))

输出:

复制代码
<class 'sklearn.utils._bunch.Bunch'>

查看数据集信息:

python 复制代码
print(iris.keys())

输出:

python 复制代码
dict_keys([
    'data',
    'target',
    'target_names',
    'feature_names',
    'DESCR'
])

查看特征名称:

复制代码
print(iris.feature_names)

输出:

复制代码
[
 'sepal length (cm)',
 'sepal width (cm)',
 'petal length (cm)',
 'petal width (cm)'
]

查看前5条数据:

复制代码
print(iris.data[:5])

查看标签:

复制代码
print(iris.target[:5])

六、划分训练集与测试集

机器学习训练时不能全部用于训练。

需要划分:

复制代码
训练集

测试集

流程:


代码:

python 复制代码
from sklearn.model_selection import train_test_split

X = iris.data
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42
)

参数说明:

复制代码
test_size=0.2

20%作为测试集

random_state=42

保证每次划分结果一致

查看数据规模:

python 复制代码
print(X_train.shape)
print(X_test.shape)

输出:

复制代码
(120,4)

(30,4)

七、数据标准化

逻辑回归属于距离敏感算法。

通常建议先标准化。

代码:

python 复制代码
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()

X_train = scaler.fit_transform(X_train)

X_test = scaler.transform(X_test)

标准化后:

复制代码
均值≈0

标准差≈1

有助于模型更快收敛。


八、创建逻辑回归模型

导入模型:

python 复制代码
from sklearn.linear_model import LogisticRegression

创建对象:

python 复制代码
model = LogisticRegression(
    max_iter=200
)

参数说明:

复制代码
max_iter

最大迭代次数

开始训练:

复制代码
model.fit(
    X_train,
    y_train
)

训练完成:

复制代码
逻辑回归模型构建成功

九、模型预测

预测测试集:

python 复制代码
y_pred = model.predict(X_test)

print(y_pred)

输出:

复制代码
[
 1 0 2 1 1
 ...
]

查看真实值:

复制代码
print(y_test)

对比即可看到预测情况。


十、预测概率分析

逻辑回归本质上输出的是概率。

代码:

python 复制代码
prob = model.predict_proba(X_test)

print(prob[:5])

输出:

复制代码
[
 [0.98,0.01,0.01],
 [0.01,0.95,0.04],
 [0.01,0.05,0.94]
]

含义:

复制代码
属于各个类别的概率

例如:

复制代码
[0.01,0.05,0.94]

表示:

复制代码
类别0概率1%

类别1概率5%

类别2概率94%

因此预测结果为:

复制代码
类别2

十一、模型准确率评估

计算 Accuracy:

python 复制代码
from sklearn.metrics import accuracy_score

acc = accuracy_score(
    y_test,
    y_pred
)

print(acc)

输出:

复制代码
1.0

表示:

复制代码
100%分类正确

当然不同随机种子下结果可能略有差异。


十二、混淆矩阵分析

混淆矩阵能够查看:

复制代码
哪些类别预测正确

哪些类别预测错误

代码:

python 复制代码
from sklearn.metrics import confusion_matrix

cm = confusion_matrix(
    y_test,
    y_pred
)

print(cm)

输出:

复制代码
[[10 0 0]
 [0 9 0]
 [0 0 11]]

解释:

复制代码
类别0

预测正确10次

类别1

预测正确9次

类别2

预测正确11次

本次测试:

复制代码
没有预测错误

十三、分类报告分析

实际项目中更常用:

复制代码
Classification Report

代码:

python 复制代码
from sklearn.metrics import classification_report

print(
    classification_report(
        y_test,
        y_pred
    )
)

输出:

复制代码
precision

recall

f1-score

support

指标解释:

Precision

精确率

复制代码
预测为正的样本

有多少是真的

Recall

召回率

复制代码
真实正样本

找回来多少

F1 Score

综合指标

复制代码
Precision

+

Recall

十四、查看模型参数

逻辑回归训练后会产生权重参数。

查看系数:

复制代码
print(model.coef_)

输出类似:

复制代码
[
 [ -1.2  1.1 -2.3 -2.0]
 [ 0.5 -0.4  0.3 -0.2]
 [ 0.7 -0.7  2.0  2.2]
]

查看截距:

复制代码
print(model.intercept_)

这些参数决定:

复制代码
模型最终分类结果

十五、完整项目代码

python 复制代码
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix
from sklearn.metrics import classification_report

# 加载数据
iris = load_iris()

X = iris.data
y = iris.target

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42
)

# 标准化
scaler = StandardScaler()

X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)

# 创建模型
model = LogisticRegression(
    max_iter=200
)

# 训练模型
model.fit(
    X_train,
    y_train
)

# 预测
y_pred = model.predict(X_test)

# 准确率
acc = accuracy_score(
    y_test,
    y_pred
)

print("准确率:", acc)

# 混淆矩阵
print(
    confusion_matrix(
        y_test,
        y_pred
    )
)

# 分类报告
print(
    classification_report(
        y_test,
        y_pred
    )
)

十六、项目执行流程总结

整个逻辑回归实战流程如下:


十七、面试高频问题

为什么逻辑回归叫回归却是分类算法?

复制代码
因为底层使用回归思想

最终通过Sigmoid函数输出分类概率

为什么要标准化数据?

复制代码
避免不同特征量纲影响训练结果

predict和predict_proba区别?

复制代码
predict

返回最终类别

predict_proba

返回类别概率

LogisticRegression中的max_iter是什么?

复制代码
最大迭代次数

防止模型无法收敛

Iris数据集有几个类别?

复制代码
3个类别

每类50条数据

总计150条样本

十八、总结

本文基于 sklearn 和经典的 Iris 鸢尾花数据集,实现了一个完整的逻辑回归分类项目。

整个过程包括:

复制代码
数据加载
    ↓
数据划分
    ↓
数据标准化
    ↓
逻辑回归训练
    ↓
结果预测
    ↓
模型评估

通过这个案例,我们不仅掌握了:

复制代码
LogisticRegression

accuracy_score

confusion_matrix

classification_report

等核心工具的使用方法,也理解了一个标准机器学习项目的开发流程。

对于机器学习初学者来说:

Iris + Logistic Regression 是进入机器学习世界的第一块敲门砖。当你能够独立完成这个案例后,就已经具备了继续学习决策树、随机森林、XGBoost、神经网络等高级算法的基础。

相关推荐
财经资讯数据_灵砚智能1 小时前
基于全球经济类多源新闻的NLP情感分析与数据可视化(夜间-次晨)2026年6月5日
大数据·人工智能·python·ai·信息可视化·自然语言处理·灵砚智能
garmin Chen1 小时前
Prompt工程入门:让AI按你的要求工作(2)--Prompt 高阶优化与结构化设计
java·人工智能·python·ai·prompt
AC赳赳老秦1 小时前
用 OpenClaw 整理团队技术分享:自动提取 PPT 内容、生成文字稿、同步到知识库
开发语言·python·自动化·powerpoint·wpf·deepseek·openclaw
澹锦汐1 小时前
AI 重构开发工作流:从 Prompt 工程到智能化研发效能革命
人工智能
编程大师哥1 小时前
推导式和生成器表达式有什么区别?
python
稳如磐石.1 小时前
北京工业计算机
大数据·人工智能·python·物联网
牛栓柱1 小时前
【后端实战】用 Supabase + React/TS 零成本构建高并发 Multi-Agent 服务
前端·数据库·人工智能·后端·react.js·前端框架
暗夜猎手-大魔王1 小时前
转载--Hermes Agent 16 | 扩展机制:General Plugin、Memory Provider、Context Engine 三条扩展线
人工智能
微软技术栈1 小时前
技术速递|面向初学者的 GitHub Copilot CLI:交互模式与非交互模式
人工智能·github·copilot