逻辑回归(Logistic Regression)进行多分类的实战

一、多分类策略

逻辑回归处理多分类主要有三种策略:

  1. OvR(One-vs-Rest)
    为每个类别训练一个二分类器

预测时选择概率最高的类别

Scikit-learn默认使用此方法

  1. OvO(One-vs-One)
    为每对类别训练一个分类器

适合类别较少但样本均衡的情况

  1. Softmax回归(Multinomial)
    直接输出多个类别的概率分布

使用交叉熵损失函数

二、完整实战代码示例

python 复制代码
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import datasets
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    classification_report, 
    confusion_matrix, 
    accuracy_score,
    roc_curve,
    auc,
    roc_auc_score
)
from sklearn.multiclass import OneVsRestClassifier
import warnings
warnings.filterwarnings('ignore')

# 设置中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

步骤1:加载和准备数据

python 复制代码
# 加载鸢尾花数据集(3个类别)
iris = datasets.load_iris()
X = iris.data
y = iris.target
feature_names = iris.feature_names
target_names = iris.target_names

print(f"特征形状: {X.shape}")
print(f"标签形状: {y.shape}")
print(f"类别: {target_names}")
print(f"特征名: {feature_names}")

# 查看数据分布
print("\n类别分布:")
for i, name in enumerate(target_names):
    print(f"{name}: {np.sum(y == i)} 个样本")

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# 标准化特征
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

步骤2:模型训练与评估

方法1:使用默认的OvR策略

python 复制代码
# 创建逻辑回归模型(默认使用OvR)
model_ovr = LogisticRegression(
    multi_class='ovr',  # One-vs-Rest
    solver='lbfgs',     # 适用于小数据集
    max_iter=1000,
    random_state=42,
    C=1.0               # 正则化强度,越小正则化越强
)

# 训练模型
model_ovr.fit(X_train_scaled, y_train)

# 预测
y_pred_ovr = model_ovr.predict(X_test_scaled)
y_pred_proba_ovr = model_ovr.predict_proba(X_test_scaled)

# 评估
print("=== OvR策略评估 ===")
print(f"准确率: {accuracy_score(y_test, y_pred_ovr):.4f}")
print("\n分类报告:")
print(classification_report(y_test, y_pred_ovr, target_names=target_names))

方法2:使用Softmax回归

python 复制代码
# 创建Softmax回归模型
model_softmax = LogisticRegression(
    multi_class='multinomial',  # Softmax回归
    solver='lbfgs',
    max_iter=1000,
    random_state=42,
    C=1.0
)

# 训练模型
model_softmax.fit(X_train_scaled, y_train)

# 预测
y_pred_softmax = model_softmax.predict(X_test_scaled)

# 评估
print("\n=== Softmax回归评估 ===")
print(f"准确率: {accuracy_score(y_test, y_pred_softmax):.4f}")
print("\n分类报告:")
print(classification_report(y_test, y_pred_softmax, target_names=target_names))

方法3:使用OneVsRestClassifier包装器

python 复制代码
# 使用包装器实现OvR
model_ovr_wrapper = OneVsRestClassifier(
    LogisticRegression(solver='lbfgs', max_iter=1000, random_state=42)
)
model_ovr_wrapper.fit(X_train_scaled, y_train)
y_pred_ovr_wrapper = model_ovr_wrapper.predict(X_test_scaled)

print("\n=== OvR包装器评估 ===")
print(f"准确率: {accuracy_score(y_test, y_pred_ovr_wrapper):.4f}")

步骤3:可视化分析

python 复制代码
def plot_confusion_matrix(y_true, y_pred, class_names, title):
    """绘制混淆矩阵"""
    cm = confusion_matrix(y_true, y_pred)
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f'混淆矩阵 - {title}', fontsize=14)
    plt.ylabel('真实标签')
    plt.xlabel('预测标签')
    plt.tight_layout()
    plt.show()

# 绘制混淆矩阵
plot_confusion_matrix(y_test, y_pred_ovr, target_names, "OvR策略")

# 绘制特征重要性
def plot_feature_importance(model, feature_names, target_names):
    """绘制特征重要性(权重)"""
    if hasattr(model, 'coef_'):
        weights = model.coef_
        
        fig, axes = plt.subplots(1, len(target_names), figsize=(15, 5))
        
        for i, (ax, class_name) in enumerate(zip(axes, target_names)):
            ax.barh(feature_names, weights[i])
            ax.set_title(f'类别: {class_name}')
            ax.set_xlabel('权重')
        
        plt.suptitle('逻辑回归特征权重(每个类别的决策边界)', fontsize=14)
        plt.tight_layout()
        plt.show()

plot_feature_importance(model_ovr, feature_names, target_names)

步骤4:概率可视化

python 复制代码
# 绘制预测概率分布
def plot_probability_distribution(y_pred_proba, y_true, target_names):
    """绘制预测概率分布"""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    for i, (ax, class_name) in enumerate(zip(axes, target_names)):
        # 获取属于当前类别的样本的概率
        true_class_mask = (y_true == i)
        prob_for_class = y_pred_proba[true_class_mask, i]
        
        ax.hist(prob_for_class, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
        ax.set_title(f'{class_name} 的预测概率分布')
        ax.set_xlabel('预测概率')
        ax.set_ylabel('样本数')
        ax.grid(True, alpha=0.3)
    
    plt.suptitle('各类别预测概率分布', fontsize=14)
    plt.tight_layout()
    plt.show()

plot_probability_distribution(y_pred_proba_ovr, y_test, target_names)

步骤5:模型调优

python 复制代码
# 使用网格搜索寻找最佳参数
param_grid = {
    'C': [0.001, 0.01, 0.1, 1, 10, 100],  # 正则化强度
    'solver': ['lbfgs', 'liblinear', 'saga'],
    'max_iter': [100, 500, 1000]
}

# 创建网格搜索
grid_search = GridSearchCV(
    LogisticRegression(multi_class='ovr', random_state=42),
    param_grid,
    cv=5,
    scoring='accuracy',
    n_jobs=-1,
    verbose=1
)

# 执行网格搜索
grid_search.fit(X_train_scaled, y_train)

print("\n=== 网格搜索结果 ===")
print(f"最佳参数: {grid_search.best_params_}")
print(f"最佳交叉验证准确率: {grid_search.best_score_:.4f}")
print(f"测试集准确率: {grid_search.score(X_test_scaled, y_test):.4f}")

# 使用最佳模型
best_model = grid_search.best_estimator_
y_pred_best = best_model.predict(X_test_scaled)

print("\n=== 最佳模型评估 ===")
print(classification_report(y_test, y_pred_best, target_names=target_names))

步骤6:交叉验证评估

python 复制代码
# 交叉验证评估模型稳定性
cv_scores = cross_val_score(
    best_model, 
    X_train_scaled, 
    y_train, 
    cv=5, 
    scoring='accuracy'
)

print("\n=== 交叉验证结果 ===")
print(f"交叉验证准确率: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")
print(f"各折准确率: {cv_scores}")

# 绘制交叉验证结果
plt.figure(figsize=(10, 6))
plt.plot(range(1, 6), cv_scores, marker='o', linewidth=2, markersize=8)
plt.axhline(y=cv_scores.mean(), color='r', linestyle='--', label=f'均值: {cv_scores.mean():.4f}')
plt.fill_between(range(1, 6), 
                 cv_scores.mean() - cv_scores.std(), 
                 cv_scores.mean() + cv_scores.std(), 
                 alpha=0.2, color='gray')
plt.title('5折交叉验证准确率', fontsize=14)
plt.xlabel('折数')
plt.ylabel('准确率')
plt.legend()
plt.grid(True, alpha=0.3)
plt.ylim([0.8, 1.0])
plt.show()

三、关键要点总结

1.策略选择:

类别较少且均衡:考虑OvO

类别较多:使用OvR或Softmax

Softmax通常更直接,但需要计算所有类别的概率

2.特征工程:

逻辑回归对特征缩放敏感,务必标准化

特征间的多重共线性会影响结果

3.正则化:

参数C控制正则化强度(C越小,正则化越强)

防止过拟合的重要工具

4.模型评估:

多分类使用准确率、混淆矩阵、分类报告

考虑使用宏平均和微平均

5.注意事项:

逻辑回归假设特征与log odds线性相关

对于非线性问题,需要特征工程或使用核方法

类别不平衡时需要调整class_weight参数

相关推荐
元亓亓亓2 小时前
LeetCode热题100--215. 数组中的第K个最大元素--中等
算法·leetcode·职场和发展
CoderYanger2 小时前
C.滑动窗口-求子数组个数-越长越合法——2962. 统计最大元素出现至少 K 次的子数组
java·数据结构·算法·leetcode·职场和发展
Eiceblue2 小时前
通过 C# 将 RTF 文档转换为图片
开发语言·算法·c#
alphaTao2 小时前
LeetCode 每日一题 2025/12/8-2025/12/14
算法·leetcode
玖日大大2 小时前
ModelEngine 可视化编排实战:从智能会议助手到企业级 AI 应用构建全指南
大数据·人工智能·算法
月明长歌2 小时前
【码道初阶】Leetcode面试题02.04:分割链表[中等难度]
java·数据结构·算法·leetcode·链表
如竟没有火炬2 小时前
快乐数——哈希表
数据结构·python·算法·leetcode·散列表
TL滕2 小时前
从0开始学算法——第十四天(数组与搜索练习)
笔记·学习·算法
SoleMotive.2 小时前
bio、nio、aio的区别以及使用场景
python·算法·nio