逻辑回归的多分类实战:以鸢尾花数据集为例

文章目录

引言:从二分类到多分类

  • 逻辑回归是机器学习中最基础也最重要的算法之一,但初学者常常困惑:逻辑回归明明是二分类算法,如何能处理多分类问题呢?本文将带你深入了解逻辑回归的多分类策略,并通过完整的鸢尾花分类代码实现。

一、多分类问题无处不在

  • 在我们的日常生活和工作中,多分类问题比比皆是:
  • 邮件分类:工作(y=1)、朋友(y=2)、家庭(y=3)、爱好(y=4)
  • 天气预测:晴天(y=1)、多云(y=2)、雨天(y=3)、雪天(y=4)
  • 医疗诊断:健康(y=1)、感冒(y=2)、流感(y=3)

这些场景都需要算法能够区分多个类别,而逻辑回归通过巧妙的扩展就能胜任这些任务。

二、One-vs-All策略揭秘

1. 核心思想

  • One-vs-All(一对多,也称为One-vs-Rest)策略将多分类问题转化为多个二分类问题:
  1. 对于N个类别,训练N个独立的二分类器
  2. 第i个分类器将第i类作为正类,其余所有类别作为负类
  3. 预测时,选择所有分类器中预测概率最高的类别

2. 数学表达

对于第i类,我们的假设函数为:
h θ ( i ) ( x ) = P ( y = i ∣ x ; θ ) h_\theta^{(i)}(x) = P(y = i|x;\theta) hθ(i)(x)=P(y=i∣x;θ)

预测时选择:
max ⁡ i h θ ( i ) ( x ) \max_i h_\theta^{(i)}(x) imaxhθ(i)(x)

三、鸢尾花分类完整实现

  • 使用Python和scikit-learn库完整实现鸢尾花的多分类任务。

1. 环境准备

python 复制代码
# 导入必要的库
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import (accuracy_score,
                            confusion_matrix,
                            classification_report)
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import StandardScaler

2. 数据加载与探索

python 复制代码
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']  # 或 'Microsoft YaHei'
plt.rcParams['axes.unicode_minus'] = False

# 加载鸢尾花数据集
iris = load_iris()
# 特征名称: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
# sepal length:花萼长度 sepal width:花萼宽度 petal length: 花瓣长度 petal width: 花瓣宽度
X = iris.data  # 特征矩阵 (150, 4)
# 目标类别: ['setosa' 'versicolor' 'virginica']
# setosa 山鸢尾 versicolor 变色鸢尾 virginica 维吉尼亚鸢尾
y = iris.target  # 标签 (150,)

# 查看特征名称和目标类别
print("特征名称:", iris.feature_names)
print("目标类别:", iris.target_names)

# 将数据转换为DataFrame便于可视化
iris_df = pd.DataFrame(X, columns=iris.feature_names)
iris_df['species'] = y
iris_df['species'] = iris_df['species'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})

# 绘制特征分布图
sns.pairplot(iris_df, hue='species', palette='husl')
plt.suptitle("鸢尾花特征分布", y=1.02)
plt.show()
bash 复制代码
特征名称: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
目标类别: ['setosa' 'versicolor' 'virginica']

3. 数据预处理

python 复制代码
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

print(f"训练集样本数: {len(X_train)}")
print(f"测试集样本数: {len(X_test)}")

# 特征标准化(逻辑回归通常需要)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
bash 复制代码
训练集样本数: 120
测试集样本数: 30

4. 模型训练与评估

python 复制代码
# 构建逻辑回归模型
log_reg = LogisticRegression(
    C=1000,                # 正则化强度的倒数
    solver='sag',          # 随机平均梯度下降
    max_iter=1000,         # 最大迭代次数
    random_state=42
)
# 使用 OneVsRestClassifier 包装
ovr_classifier = OneVsRestClassifier(log_reg)

# 训练模型
ovr_classifier.fit(X_train, y_train)

# 在训练集和测试集上评估
train_acc = ovr_classifier.score(X_train, y_train)
test_acc = ovr_classifier.score(X_test, y_test)

print(f"训练集准确率: {train_acc:.2%}")
print(f"测试集准确率: {test_acc:.2%}")

# 更详细的评估报告
y_pred = ovr_classifier.predict(X_test)
print("\n分类报告:")
print(classification_report(y_test, y_pred, target_names=iris.target_names))

# 绘制混淆矩阵
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(6,6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=iris.target_names,
            yticklabels=iris.target_names)
plt.xlabel('预测标签')
plt.ylabel('真实标签')
plt.title('混淆矩阵')
plt.show()
bash 复制代码
训练集准确率: 96.67%
测试集准确率: 96.67%

分类报告:
              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        10
  versicolor       1.00      0.90      0.95        10
   virginica       0.91      1.00      0.95        10

    accuracy                           0.97        30
   macro avg       0.97      0.97      0.97        30
weighted avg       0.97      0.97      0.97        30

5. 决策边界可视化

python 复制代码
# 为可视化,只使用两个主要特征
X_train_2d = X_train[:, :2]
X_test_2d = X_test[:, :2]

# 重新训练一个2D模型
log_reg_2d = LogisticRegression(
    C=1000,
    solver='sag',
    max_iter=2000,
    random_state=42
)

ovr_classifier_2d = OneVsRestClassifier(log_reg_2d)
ovr_classifier_2d.fit(X_train_2d, y_train)  # 必须先训练模型

# 创建网格点
x_min, x_max = X_train_2d[:, 0].min() - 1, X_train_2d[:, 0].max() + 1
y_min, y_max = X_train_2d[:, 1].min() - 1, X_train_2d[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02), np.arange(y_min, y_max, 0.02))

# 预测每个网格点的类别
Z = ovr_classifier_2d.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

# 预测每个网格点的类别
Z = ovr_classifier_2d.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

# 绘制决策边界
plt.figure(figsize=(10, 6))
plt.contourf(xx, yy, Z, alpha=0.4, cmap='Pastel2')
scatter = plt.scatter(X_train_2d[:, 0], X_train_2d[:, 1], c=y_train,
                      cmap='Dark2', edgecolor='black')

# 添加图例和标签
legend_elements = scatter.legend_elements()[0]
plt.legend(legend_elements,iris.target_names,title="鸢尾花种类")
plt.xlabel(iris.feature_names[0])
plt.ylabel(iris.feature_names[1])
plt.title("逻辑回归多分类决策边界")
plt.show()

四、关键参数解析

在构建逻辑回归模型时的重要参数:

  1. C=1000:正则化强度的倒数,较小的值表示更强的正则化。这里设为较大的值,相当于减少正则化。
  2. multi_class='ovr':指定使用One-vs-Rest策略处理多分类问题。scikit-learn还支持'multinomial'选项,使用softmax函数直接进行多分类。
  3. solver='sag':优化算法选择随机平均梯度下降(Stochastic Average Gradient),适合大数据集。其他可选算法包括:'liblinear':适合小数据集;'newton-cg':牛顿法;'lbfgs':拟牛顿法。
  4. max_iter=1000:最大迭代次数,确保模型能够收敛。

五、总结

  1. 逻辑回归如何通过One-vs-All策略处理多分类问题
  2. 完整的鸢尾花分类实现流程
  3. 模型评估与可视化方法
相关推荐
侯小啾8 小时前
【03】C语言 强制类型转换 与 进制转换
c语言·数据结构·算法
Xの哲學8 小时前
Linux NAPI 架构详解
linux·网络·算法·架构·边缘计算
京东零售技术11 小时前
扛起技术大梁的零售校招生们 | 1024技术人特别篇
算法
爱coding的橙子12 小时前
每日算法刷题Day78:10.23:leetcode 一般树7道题,用时1h30min
算法·leetcode·深度优先
Swift社区12 小时前
LeetCode 403 - 青蛙过河
算法·leetcode·职场和发展
地平线开发者12 小时前
三种 Badcase 精度验证方案详解与 hbm_infer 部署实录
算法·自动驾驶
papership12 小时前
【入门级-算法-5、数值处理算法:高精度的减法】
算法·1024程序员节
lingran__12 小时前
算法沉淀第十天(牛客2025秋季算法编程训练联赛2-基础组 和 奇怪的电梯)
c++·算法
DuHz12 小时前
基于MIMO FMCW雷达的二维角度分析多径抑制技术——论文阅读
论文阅读·物联网·算法·信息与通信·毫米波雷达
Dragon_D.13 小时前
排序算法大全——插入排序
算法·排序算法·c·学习方法