逻辑回归(Logistic Regression)进行多分类的实战

一、多分类策略

逻辑回归处理多分类主要有三种策略:

  1. OvR(One-vs-Rest)
    为每个类别训练一个二分类器

预测时选择概率最高的类别

Scikit-learn默认使用此方法

  1. OvO(One-vs-One)
    为每对类别训练一个分类器

适合类别较少但样本均衡的情况

  1. Softmax回归(Multinomial)
    直接输出多个类别的概率分布

使用交叉熵损失函数

二、完整实战代码示例

python 复制代码
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import datasets
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    classification_report, 
    confusion_matrix, 
    accuracy_score,
    roc_curve,
    auc,
    roc_auc_score
)
from sklearn.multiclass import OneVsRestClassifier
import warnings
warnings.filterwarnings('ignore')

# 设置中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

步骤1:加载和准备数据

python 复制代码
# 加载鸢尾花数据集(3个类别)
iris = datasets.load_iris()
X = iris.data
y = iris.target
feature_names = iris.feature_names
target_names = iris.target_names

print(f"特征形状: {X.shape}")
print(f"标签形状: {y.shape}")
print(f"类别: {target_names}")
print(f"特征名: {feature_names}")

# 查看数据分布
print("\n类别分布:")
for i, name in enumerate(target_names):
    print(f"{name}: {np.sum(y == i)} 个样本")

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# 标准化特征
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

步骤2:模型训练与评估

方法1:使用默认的OvR策略

python 复制代码
# 创建逻辑回归模型(默认使用OvR)
model_ovr = LogisticRegression(
    multi_class='ovr',  # One-vs-Rest
    solver='lbfgs',     # 适用于小数据集
    max_iter=1000,
    random_state=42,
    C=1.0               # 正则化强度,越小正则化越强
)

# 训练模型
model_ovr.fit(X_train_scaled, y_train)

# 预测
y_pred_ovr = model_ovr.predict(X_test_scaled)
y_pred_proba_ovr = model_ovr.predict_proba(X_test_scaled)

# 评估
print("=== OvR策略评估 ===")
print(f"准确率: {accuracy_score(y_test, y_pred_ovr):.4f}")
print("\n分类报告:")
print(classification_report(y_test, y_pred_ovr, target_names=target_names))

方法2:使用Softmax回归

python 复制代码
# 创建Softmax回归模型
model_softmax = LogisticRegression(
    multi_class='multinomial',  # Softmax回归
    solver='lbfgs',
    max_iter=1000,
    random_state=42,
    C=1.0
)

# 训练模型
model_softmax.fit(X_train_scaled, y_train)

# 预测
y_pred_softmax = model_softmax.predict(X_test_scaled)

# 评估
print("\n=== Softmax回归评估 ===")
print(f"准确率: {accuracy_score(y_test, y_pred_softmax):.4f}")
print("\n分类报告:")
print(classification_report(y_test, y_pred_softmax, target_names=target_names))

方法3:使用OneVsRestClassifier包装器

python 复制代码
# 使用包装器实现OvR
model_ovr_wrapper = OneVsRestClassifier(
    LogisticRegression(solver='lbfgs', max_iter=1000, random_state=42)
)
model_ovr_wrapper.fit(X_train_scaled, y_train)
y_pred_ovr_wrapper = model_ovr_wrapper.predict(X_test_scaled)

print("\n=== OvR包装器评估 ===")
print(f"准确率: {accuracy_score(y_test, y_pred_ovr_wrapper):.4f}")

步骤3:可视化分析

python 复制代码
def plot_confusion_matrix(y_true, y_pred, class_names, title):
    """绘制混淆矩阵"""
    cm = confusion_matrix(y_true, y_pred)
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f'混淆矩阵 - {title}', fontsize=14)
    plt.ylabel('真实标签')
    plt.xlabel('预测标签')
    plt.tight_layout()
    plt.show()

# 绘制混淆矩阵
plot_confusion_matrix(y_test, y_pred_ovr, target_names, "OvR策略")

# 绘制特征重要性
def plot_feature_importance(model, feature_names, target_names):
    """绘制特征重要性(权重)"""
    if hasattr(model, 'coef_'):
        weights = model.coef_
        
        fig, axes = plt.subplots(1, len(target_names), figsize=(15, 5))
        
        for i, (ax, class_name) in enumerate(zip(axes, target_names)):
            ax.barh(feature_names, weights[i])
            ax.set_title(f'类别: {class_name}')
            ax.set_xlabel('权重')
        
        plt.suptitle('逻辑回归特征权重(每个类别的决策边界)', fontsize=14)
        plt.tight_layout()
        plt.show()

plot_feature_importance(model_ovr, feature_names, target_names)

步骤4:概率可视化

python 复制代码
# 绘制预测概率分布
def plot_probability_distribution(y_pred_proba, y_true, target_names):
    """绘制预测概率分布"""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    for i, (ax, class_name) in enumerate(zip(axes, target_names)):
        # 获取属于当前类别的样本的概率
        true_class_mask = (y_true == i)
        prob_for_class = y_pred_proba[true_class_mask, i]
        
        ax.hist(prob_for_class, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
        ax.set_title(f'{class_name} 的预测概率分布')
        ax.set_xlabel('预测概率')
        ax.set_ylabel('样本数')
        ax.grid(True, alpha=0.3)
    
    plt.suptitle('各类别预测概率分布', fontsize=14)
    plt.tight_layout()
    plt.show()

plot_probability_distribution(y_pred_proba_ovr, y_test, target_names)

步骤5:模型调优

python 复制代码
# 使用网格搜索寻找最佳参数
param_grid = {
    'C': [0.001, 0.01, 0.1, 1, 10, 100],  # 正则化强度
    'solver': ['lbfgs', 'liblinear', 'saga'],
    'max_iter': [100, 500, 1000]
}

# 创建网格搜索
grid_search = GridSearchCV(
    LogisticRegression(multi_class='ovr', random_state=42),
    param_grid,
    cv=5,
    scoring='accuracy',
    n_jobs=-1,
    verbose=1
)

# 执行网格搜索
grid_search.fit(X_train_scaled, y_train)

print("\n=== 网格搜索结果 ===")
print(f"最佳参数: {grid_search.best_params_}")
print(f"最佳交叉验证准确率: {grid_search.best_score_:.4f}")
print(f"测试集准确率: {grid_search.score(X_test_scaled, y_test):.4f}")

# 使用最佳模型
best_model = grid_search.best_estimator_
y_pred_best = best_model.predict(X_test_scaled)

print("\n=== 最佳模型评估 ===")
print(classification_report(y_test, y_pred_best, target_names=target_names))

步骤6:交叉验证评估

python 复制代码
# 交叉验证评估模型稳定性
cv_scores = cross_val_score(
    best_model, 
    X_train_scaled, 
    y_train, 
    cv=5, 
    scoring='accuracy'
)

print("\n=== 交叉验证结果 ===")
print(f"交叉验证准确率: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")
print(f"各折准确率: {cv_scores}")

# 绘制交叉验证结果
plt.figure(figsize=(10, 6))
plt.plot(range(1, 6), cv_scores, marker='o', linewidth=2, markersize=8)
plt.axhline(y=cv_scores.mean(), color='r', linestyle='--', label=f'均值: {cv_scores.mean():.4f}')
plt.fill_between(range(1, 6), 
                 cv_scores.mean() - cv_scores.std(), 
                 cv_scores.mean() + cv_scores.std(), 
                 alpha=0.2, color='gray')
plt.title('5折交叉验证准确率', fontsize=14)
plt.xlabel('折数')
plt.ylabel('准确率')
plt.legend()
plt.grid(True, alpha=0.3)
plt.ylim([0.8, 1.0])
plt.show()

三、关键要点总结

1.策略选择:

类别较少且均衡:考虑OvO

类别较多:使用OvR或Softmax

Softmax通常更直接,但需要计算所有类别的概率

2.特征工程:

逻辑回归对特征缩放敏感,务必标准化

特征间的多重共线性会影响结果

3.正则化:

参数C控制正则化强度(C越小,正则化越强)

防止过拟合的重要工具

4.模型评估:

多分类使用准确率、混淆矩阵、分类报告

考虑使用宏平均和微平均

5.注意事项:

逻辑回归假设特征与log odds线性相关

对于非线性问题,需要特征工程或使用核方法

类别不平衡时需要调整class_weight参数

相关推荐
超级码力6668 小时前
【Latex文件架构】Latex文件架构模板
算法·数学建模·信息可视化
穿条秋裤到处跑8 小时前
每日一道leetcode(2026.04.29):二维网格图中探测环
算法·leetcode·职场和发展
Merlos_wind8 小时前
HashMap详解
算法·哈希算法·散列表
汉克老师9 小时前
GESP2025年3月认证C++五级( 第三部分编程题(1、平均分配))
c++·算法·贪心算法·排序·gesp5级·gesp五级
Yzzz-F11 小时前
Problem - 2205D - Codeforces
算法
智者知已应修善业12 小时前
【51单片机2个按键控制流水灯运行与暂停】2023-9-6
c++·经验分享·笔记·算法·51单片机
Halo_tjn12 小时前
Java Set集合相关知识点
java·开发语言·算法
生成论实验室13 小时前
《事件关系阴阳博弈动力学:识势应势之道》第四篇:降U动力学——认知确定度的自驱演化
人工智能·科技·神经网络·算法·架构
AI科技星13 小时前
全域数学·72分册:场计算机卷【乖乖数学】
算法·机器学习·数学建模·数据挖掘·量子计算
科研前沿14 小时前
镜像孪生VS视频孪生核心技术产品核心优势
大数据·人工智能·算法·重构·空间计算