逻辑回归(Logistic Regression)进行多分类的实战

一、多分类策略

逻辑回归处理多分类主要有三种策略:

  1. OvR(One-vs-Rest)
    为每个类别训练一个二分类器

预测时选择概率最高的类别

Scikit-learn默认使用此方法

  1. OvO(One-vs-One)
    为每对类别训练一个分类器

适合类别较少但样本均衡的情况

  1. Softmax回归(Multinomial)
    直接输出多个类别的概率分布

使用交叉熵损失函数

二、完整实战代码示例

python 复制代码
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import datasets
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    classification_report, 
    confusion_matrix, 
    accuracy_score,
    roc_curve,
    auc,
    roc_auc_score
)
from sklearn.multiclass import OneVsRestClassifier
import warnings
warnings.filterwarnings('ignore')

# 设置中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

步骤1:加载和准备数据

python 复制代码
# 加载鸢尾花数据集(3个类别)
iris = datasets.load_iris()
X = iris.data
y = iris.target
feature_names = iris.feature_names
target_names = iris.target_names

print(f"特征形状: {X.shape}")
print(f"标签形状: {y.shape}")
print(f"类别: {target_names}")
print(f"特征名: {feature_names}")

# 查看数据分布
print("\n类别分布:")
for i, name in enumerate(target_names):
    print(f"{name}: {np.sum(y == i)} 个样本")

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# 标准化特征
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

步骤2:模型训练与评估

方法1:使用默认的OvR策略

python 复制代码
# 创建逻辑回归模型(默认使用OvR)
model_ovr = LogisticRegression(
    multi_class='ovr',  # One-vs-Rest
    solver='lbfgs',     # 适用于小数据集
    max_iter=1000,
    random_state=42,
    C=1.0               # 正则化强度,越小正则化越强
)

# 训练模型
model_ovr.fit(X_train_scaled, y_train)

# 预测
y_pred_ovr = model_ovr.predict(X_test_scaled)
y_pred_proba_ovr = model_ovr.predict_proba(X_test_scaled)

# 评估
print("=== OvR策略评估 ===")
print(f"准确率: {accuracy_score(y_test, y_pred_ovr):.4f}")
print("\n分类报告:")
print(classification_report(y_test, y_pred_ovr, target_names=target_names))

方法2:使用Softmax回归

python 复制代码
# 创建Softmax回归模型
model_softmax = LogisticRegression(
    multi_class='multinomial',  # Softmax回归
    solver='lbfgs',
    max_iter=1000,
    random_state=42,
    C=1.0
)

# 训练模型
model_softmax.fit(X_train_scaled, y_train)

# 预测
y_pred_softmax = model_softmax.predict(X_test_scaled)

# 评估
print("\n=== Softmax回归评估 ===")
print(f"准确率: {accuracy_score(y_test, y_pred_softmax):.4f}")
print("\n分类报告:")
print(classification_report(y_test, y_pred_softmax, target_names=target_names))

方法3:使用OneVsRestClassifier包装器

python 复制代码
# 使用包装器实现OvR
model_ovr_wrapper = OneVsRestClassifier(
    LogisticRegression(solver='lbfgs', max_iter=1000, random_state=42)
)
model_ovr_wrapper.fit(X_train_scaled, y_train)
y_pred_ovr_wrapper = model_ovr_wrapper.predict(X_test_scaled)

print("\n=== OvR包装器评估 ===")
print(f"准确率: {accuracy_score(y_test, y_pred_ovr_wrapper):.4f}")

步骤3:可视化分析

python 复制代码
def plot_confusion_matrix(y_true, y_pred, class_names, title):
    """绘制混淆矩阵"""
    cm = confusion_matrix(y_true, y_pred)
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f'混淆矩阵 - {title}', fontsize=14)
    plt.ylabel('真实标签')
    plt.xlabel('预测标签')
    plt.tight_layout()
    plt.show()

# 绘制混淆矩阵
plot_confusion_matrix(y_test, y_pred_ovr, target_names, "OvR策略")

# 绘制特征重要性
def plot_feature_importance(model, feature_names, target_names):
    """绘制特征重要性(权重)"""
    if hasattr(model, 'coef_'):
        weights = model.coef_
        
        fig, axes = plt.subplots(1, len(target_names), figsize=(15, 5))
        
        for i, (ax, class_name) in enumerate(zip(axes, target_names)):
            ax.barh(feature_names, weights[i])
            ax.set_title(f'类别: {class_name}')
            ax.set_xlabel('权重')
        
        plt.suptitle('逻辑回归特征权重(每个类别的决策边界)', fontsize=14)
        plt.tight_layout()
        plt.show()

plot_feature_importance(model_ovr, feature_names, target_names)

步骤4:概率可视化

python 复制代码
# 绘制预测概率分布
def plot_probability_distribution(y_pred_proba, y_true, target_names):
    """绘制预测概率分布"""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    for i, (ax, class_name) in enumerate(zip(axes, target_names)):
        # 获取属于当前类别的样本的概率
        true_class_mask = (y_true == i)
        prob_for_class = y_pred_proba[true_class_mask, i]
        
        ax.hist(prob_for_class, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
        ax.set_title(f'{class_name} 的预测概率分布')
        ax.set_xlabel('预测概率')
        ax.set_ylabel('样本数')
        ax.grid(True, alpha=0.3)
    
    plt.suptitle('各类别预测概率分布', fontsize=14)
    plt.tight_layout()
    plt.show()

plot_probability_distribution(y_pred_proba_ovr, y_test, target_names)

步骤5:模型调优

python 复制代码
# 使用网格搜索寻找最佳参数
param_grid = {
    'C': [0.001, 0.01, 0.1, 1, 10, 100],  # 正则化强度
    'solver': ['lbfgs', 'liblinear', 'saga'],
    'max_iter': [100, 500, 1000]
}

# 创建网格搜索
grid_search = GridSearchCV(
    LogisticRegression(multi_class='ovr', random_state=42),
    param_grid,
    cv=5,
    scoring='accuracy',
    n_jobs=-1,
    verbose=1
)

# 执行网格搜索
grid_search.fit(X_train_scaled, y_train)

print("\n=== 网格搜索结果 ===")
print(f"最佳参数: {grid_search.best_params_}")
print(f"最佳交叉验证准确率: {grid_search.best_score_:.4f}")
print(f"测试集准确率: {grid_search.score(X_test_scaled, y_test):.4f}")

# 使用最佳模型
best_model = grid_search.best_estimator_
y_pred_best = best_model.predict(X_test_scaled)

print("\n=== 最佳模型评估 ===")
print(classification_report(y_test, y_pred_best, target_names=target_names))

步骤6:交叉验证评估

python 复制代码
# 交叉验证评估模型稳定性
cv_scores = cross_val_score(
    best_model, 
    X_train_scaled, 
    y_train, 
    cv=5, 
    scoring='accuracy'
)

print("\n=== 交叉验证结果 ===")
print(f"交叉验证准确率: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")
print(f"各折准确率: {cv_scores}")

# 绘制交叉验证结果
plt.figure(figsize=(10, 6))
plt.plot(range(1, 6), cv_scores, marker='o', linewidth=2, markersize=8)
plt.axhline(y=cv_scores.mean(), color='r', linestyle='--', label=f'均值: {cv_scores.mean():.4f}')
plt.fill_between(range(1, 6), 
                 cv_scores.mean() - cv_scores.std(), 
                 cv_scores.mean() + cv_scores.std(), 
                 alpha=0.2, color='gray')
plt.title('5折交叉验证准确率', fontsize=14)
plt.xlabel('折数')
plt.ylabel('准确率')
plt.legend()
plt.grid(True, alpha=0.3)
plt.ylim([0.8, 1.0])
plt.show()

三、关键要点总结

1.策略选择:

类别较少且均衡:考虑OvO

类别较多:使用OvR或Softmax

Softmax通常更直接,但需要计算所有类别的概率

2.特征工程:

逻辑回归对特征缩放敏感,务必标准化

特征间的多重共线性会影响结果

3.正则化:

参数C控制正则化强度(C越小,正则化越强)

防止过拟合的重要工具

4.模型评估:

多分类使用准确率、混淆矩阵、分类报告

考虑使用宏平均和微平均

5.注意事项:

逻辑回归假设特征与log odds线性相关

对于非线性问题,需要特征工程或使用核方法

类别不平衡时需要调整class_weight参数

相关推荐
爱看科技1 小时前
量子计算赋能图像智能新突破,微美全息(NASDAQ:WIMI)PQCNN并行混合架构引领多类分类性能跃升
分类·数据挖掘·量子计算
郝学胜-神的一滴2 小时前
Leetcode 969 煎饼排序✨:翻转间的数组排序艺术
数据结构·c++·算法·leetcode·面试
I_LPL9 小时前
hot100贪心专题
数据结构·算法·leetcode·贪心
颜酱9 小时前
DFS 岛屿系列题全解析
javascript·后端·算法
WolfGang00732110 小时前
代码随想录算法训练营 Day16 | 二叉树 part06
算法
算法玩不起11 小时前
以乳腺癌诊断数据为例的医学AI分类建模方法入门
人工智能·分类·数据挖掘
2401_8318249611 小时前
代码性能剖析工具
开发语言·c++·算法
Sunshine for you12 小时前
C++中的职责链模式实战
开发语言·c++·算法
qq_4160187212 小时前
C++中的状态模式
开发语言·c++·算法
2401_8845632412 小时前
模板代码生成工具
开发语言·c++·算法