逻辑回归的多分类实战:以鸢尾花数据集为例

文章目录

引言:从二分类到多分类

  • 逻辑回归是机器学习中最基础也最重要的算法之一,但初学者常常困惑:逻辑回归明明是二分类算法,如何能处理多分类问题呢?本文将带你深入了解逻辑回归的多分类策略,并通过完整的鸢尾花分类代码实现。

一、多分类问题无处不在

  • 在我们的日常生活和工作中,多分类问题比比皆是:
  • 邮件分类:工作(y=1)、朋友(y=2)、家庭(y=3)、爱好(y=4)
  • 天气预测:晴天(y=1)、多云(y=2)、雨天(y=3)、雪天(y=4)
  • 医疗诊断:健康(y=1)、感冒(y=2)、流感(y=3)

这些场景都需要算法能够区分多个类别,而逻辑回归通过巧妙的扩展就能胜任这些任务。

二、One-vs-All策略揭秘

1. 核心思想

  • One-vs-All(一对多,也称为One-vs-Rest)策略将多分类问题转化为多个二分类问题:
  1. 对于N个类别,训练N个独立的二分类器
  2. 第i个分类器将第i类作为正类,其余所有类别作为负类
  3. 预测时,选择所有分类器中预测概率最高的类别

2. 数学表达

对于第i类,我们的假设函数为:
h θ ( i ) ( x ) = P ( y = i ∣ x ; θ ) h_\theta^{(i)}(x) = P(y = i|x;\theta) hθ(i)(x)=P(y=i∣x;θ)

预测时选择:
max ⁡ i h θ ( i ) ( x ) \max_i h_\theta^{(i)}(x) imaxhθ(i)(x)

三、鸢尾花分类完整实现

  • 使用Python和scikit-learn库完整实现鸢尾花的多分类任务。

1. 环境准备

python 复制代码
# 导入必要的库
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import (accuracy_score,
                            confusion_matrix,
                            classification_report)
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.multiclass import OneVsRestClassifier
from sklearn.preprocessing import StandardScaler

2. 数据加载与探索

python 复制代码
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']  # 或 'Microsoft YaHei'
plt.rcParams['axes.unicode_minus'] = False

# 加载鸢尾花数据集
iris = load_iris()
# 特征名称: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
# sepal length:花萼长度 sepal width:花萼宽度 petal length: 花瓣长度 petal width: 花瓣宽度
X = iris.data  # 特征矩阵 (150, 4)
# 目标类别: ['setosa' 'versicolor' 'virginica']
# setosa 山鸢尾 versicolor 变色鸢尾 virginica 维吉尼亚鸢尾
y = iris.target  # 标签 (150,)

# 查看特征名称和目标类别
print("特征名称:", iris.feature_names)
print("目标类别:", iris.target_names)

# 将数据转换为DataFrame便于可视化
iris_df = pd.DataFrame(X, columns=iris.feature_names)
iris_df['species'] = y
iris_df['species'] = iris_df['species'].map({0: 'setosa', 1: 'versicolor', 2: 'virginica'})

# 绘制特征分布图
sns.pairplot(iris_df, hue='species', palette='husl')
plt.suptitle("鸢尾花特征分布", y=1.02)
plt.show()
bash 复制代码
特征名称: ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']
目标类别: ['setosa' 'versicolor' 'virginica']

3. 数据预处理

python 复制代码
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42, stratify=y)

print(f"训练集样本数: {len(X_train)}")
print(f"测试集样本数: {len(X_test)}")

# 特征标准化(逻辑回归通常需要)
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
bash 复制代码
训练集样本数: 120
测试集样本数: 30

4. 模型训练与评估

python 复制代码
# 构建逻辑回归模型
log_reg = LogisticRegression(
    C=1000,                # 正则化强度的倒数
    solver='sag',          # 随机平均梯度下降
    max_iter=1000,         # 最大迭代次数
    random_state=42
)
# 使用 OneVsRestClassifier 包装
ovr_classifier = OneVsRestClassifier(log_reg)

# 训练模型
ovr_classifier.fit(X_train, y_train)

# 在训练集和测试集上评估
train_acc = ovr_classifier.score(X_train, y_train)
test_acc = ovr_classifier.score(X_test, y_test)

print(f"训练集准确率: {train_acc:.2%}")
print(f"测试集准确率: {test_acc:.2%}")

# 更详细的评估报告
y_pred = ovr_classifier.predict(X_test)
print("\n分类报告:")
print(classification_report(y_test, y_pred, target_names=iris.target_names))

# 绘制混淆矩阵
cm = confusion_matrix(y_test, y_pred)
plt.figure(figsize=(6,6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=iris.target_names,
            yticklabels=iris.target_names)
plt.xlabel('预测标签')
plt.ylabel('真实标签')
plt.title('混淆矩阵')
plt.show()
bash 复制代码
训练集准确率: 96.67%
测试集准确率: 96.67%

分类报告:
              precision    recall  f1-score   support

      setosa       1.00      1.00      1.00        10
  versicolor       1.00      0.90      0.95        10
   virginica       0.91      1.00      0.95        10

    accuracy                           0.97        30
   macro avg       0.97      0.97      0.97        30
weighted avg       0.97      0.97      0.97        30

5. 决策边界可视化

python 复制代码
# 为可视化,只使用两个主要特征
X_train_2d = X_train[:, :2]
X_test_2d = X_test[:, :2]

# 重新训练一个2D模型
log_reg_2d = LogisticRegression(
    C=1000,
    solver='sag',
    max_iter=2000,
    random_state=42
)

ovr_classifier_2d = OneVsRestClassifier(log_reg_2d)
ovr_classifier_2d.fit(X_train_2d, y_train)  # 必须先训练模型

# 创建网格点
x_min, x_max = X_train_2d[:, 0].min() - 1, X_train_2d[:, 0].max() + 1
y_min, y_max = X_train_2d[:, 1].min() - 1, X_train_2d[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02), np.arange(y_min, y_max, 0.02))

# 预测每个网格点的类别
Z = ovr_classifier_2d.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

# 预测每个网格点的类别
Z = ovr_classifier_2d.predict(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

# 绘制决策边界
plt.figure(figsize=(10, 6))
plt.contourf(xx, yy, Z, alpha=0.4, cmap='Pastel2')
scatter = plt.scatter(X_train_2d[:, 0], X_train_2d[:, 1], c=y_train,
                      cmap='Dark2', edgecolor='black')

# 添加图例和标签
legend_elements = scatter.legend_elements()[0]
plt.legend(legend_elements,iris.target_names,title="鸢尾花种类")
plt.xlabel(iris.feature_names[0])
plt.ylabel(iris.feature_names[1])
plt.title("逻辑回归多分类决策边界")
plt.show()

四、关键参数解析

在构建逻辑回归模型时的重要参数:

  1. C=1000:正则化强度的倒数,较小的值表示更强的正则化。这里设为较大的值,相当于减少正则化。
  2. multi_class='ovr':指定使用One-vs-Rest策略处理多分类问题。scikit-learn还支持'multinomial'选项,使用softmax函数直接进行多分类。
  3. solver='sag':优化算法选择随机平均梯度下降(Stochastic Average Gradient),适合大数据集。其他可选算法包括:'liblinear':适合小数据集;'newton-cg':牛顿法;'lbfgs':拟牛顿法。
  4. max_iter=1000:最大迭代次数,确保模型能够收敛。

五、总结

  1. 逻辑回归如何通过One-vs-All策略处理多分类问题
  2. 完整的鸢尾花分类实现流程
  3. 模型评估与可视化方法
相关推荐
超龄超能程序猿9 小时前
YOLOv8 五大核心模型:从检测到分类的介绍
yolo·分类·数据挖掘
fie888910 小时前
NSCT(非下采样轮廓波变换)的分解和重建程序
算法
晨晖211 小时前
单链表逆转,c语言
c语言·数据结构·算法
im_AMBER12 小时前
Leetcode 78 识别数组中的最大异常值 | 镜像对之间最小绝对距离
笔记·学习·算法·leetcode
鼾声鼾语13 小时前
matlab的ros2发布的消息,局域网内其他设备收不到情况吗?但是matlab可以订阅其他局域网的ros2发布的消息(问题总结)
开发语言·人工智能·深度学习·算法·matlab·isaaclab
LYFlied13 小时前
【每日算法】LeetCode 25. K 个一组翻转链表
算法·leetcode·链表
Swizard13 小时前
别再迷信“准确率”了!一文读懂 AI 图像分割的黄金标尺 —— Dice 系数
python·算法·训练
s090713613 小时前
紧凑型3D成像声纳实现路径
算法·3d·声呐·前视多波束
可爱的小小小狼13 小时前
算法:二叉树遍历
算法
廋到被风吹走13 小时前
【数据库】【MySQL】分库分表策略 分类、优势与短板
数据库·mysql·分类