逻辑回归(Logistic Regression)进行多分类的实战

一、多分类策略

逻辑回归处理多分类主要有三种策略:

  1. OvR(One-vs-Rest)
    为每个类别训练一个二分类器

预测时选择概率最高的类别

Scikit-learn默认使用此方法

  1. OvO(One-vs-One)
    为每对类别训练一个分类器

适合类别较少但样本均衡的情况

  1. Softmax回归(Multinomial)
    直接输出多个类别的概率分布

使用交叉熵损失函数

二、完整实战代码示例

python 复制代码
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import datasets
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    classification_report, 
    confusion_matrix, 
    accuracy_score,
    roc_curve,
    auc,
    roc_auc_score
)
from sklearn.multiclass import OneVsRestClassifier
import warnings
warnings.filterwarnings('ignore')

# 设置中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

步骤1:加载和准备数据

python 复制代码
# 加载鸢尾花数据集(3个类别)
iris = datasets.load_iris()
X = iris.data
y = iris.target
feature_names = iris.feature_names
target_names = iris.target_names

print(f"特征形状: {X.shape}")
print(f"标签形状: {y.shape}")
print(f"类别: {target_names}")
print(f"特征名: {feature_names}")

# 查看数据分布
print("\n类别分布:")
for i, name in enumerate(target_names):
    print(f"{name}: {np.sum(y == i)} 个样本")

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42, stratify=y
)

# 标准化特征
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

步骤2:模型训练与评估

方法1:使用默认的OvR策略

python 复制代码
# 创建逻辑回归模型(默认使用OvR)
model_ovr = LogisticRegression(
    multi_class='ovr',  # One-vs-Rest
    solver='lbfgs',     # 适用于小数据集
    max_iter=1000,
    random_state=42,
    C=1.0               # 正则化强度,越小正则化越强
)

# 训练模型
model_ovr.fit(X_train_scaled, y_train)

# 预测
y_pred_ovr = model_ovr.predict(X_test_scaled)
y_pred_proba_ovr = model_ovr.predict_proba(X_test_scaled)

# 评估
print("=== OvR策略评估 ===")
print(f"准确率: {accuracy_score(y_test, y_pred_ovr):.4f}")
print("\n分类报告:")
print(classification_report(y_test, y_pred_ovr, target_names=target_names))

方法2:使用Softmax回归

python 复制代码
# 创建Softmax回归模型
model_softmax = LogisticRegression(
    multi_class='multinomial',  # Softmax回归
    solver='lbfgs',
    max_iter=1000,
    random_state=42,
    C=1.0
)

# 训练模型
model_softmax.fit(X_train_scaled, y_train)

# 预测
y_pred_softmax = model_softmax.predict(X_test_scaled)

# 评估
print("\n=== Softmax回归评估 ===")
print(f"准确率: {accuracy_score(y_test, y_pred_softmax):.4f}")
print("\n分类报告:")
print(classification_report(y_test, y_pred_softmax, target_names=target_names))

方法3:使用OneVsRestClassifier包装器

python 复制代码
# 使用包装器实现OvR
model_ovr_wrapper = OneVsRestClassifier(
    LogisticRegression(solver='lbfgs', max_iter=1000, random_state=42)
)
model_ovr_wrapper.fit(X_train_scaled, y_train)
y_pred_ovr_wrapper = model_ovr_wrapper.predict(X_test_scaled)

print("\n=== OvR包装器评估 ===")
print(f"准确率: {accuracy_score(y_test, y_pred_ovr_wrapper):.4f}")

步骤3:可视化分析

python 复制代码
def plot_confusion_matrix(y_true, y_pred, class_names, title):
    """绘制混淆矩阵"""
    cm = confusion_matrix(y_true, y_pred)
    
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=class_names, yticklabels=class_names)
    plt.title(f'混淆矩阵 - {title}', fontsize=14)
    plt.ylabel('真实标签')
    plt.xlabel('预测标签')
    plt.tight_layout()
    plt.show()

# 绘制混淆矩阵
plot_confusion_matrix(y_test, y_pred_ovr, target_names, "OvR策略")

# 绘制特征重要性
def plot_feature_importance(model, feature_names, target_names):
    """绘制特征重要性(权重)"""
    if hasattr(model, 'coef_'):
        weights = model.coef_
        
        fig, axes = plt.subplots(1, len(target_names), figsize=(15, 5))
        
        for i, (ax, class_name) in enumerate(zip(axes, target_names)):
            ax.barh(feature_names, weights[i])
            ax.set_title(f'类别: {class_name}')
            ax.set_xlabel('权重')
        
        plt.suptitle('逻辑回归特征权重(每个类别的决策边界)', fontsize=14)
        plt.tight_layout()
        plt.show()

plot_feature_importance(model_ovr, feature_names, target_names)

步骤4:概率可视化

python 复制代码
# 绘制预测概率分布
def plot_probability_distribution(y_pred_proba, y_true, target_names):
    """绘制预测概率分布"""
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    
    for i, (ax, class_name) in enumerate(zip(axes, target_names)):
        # 获取属于当前类别的样本的概率
        true_class_mask = (y_true == i)
        prob_for_class = y_pred_proba[true_class_mask, i]
        
        ax.hist(prob_for_class, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
        ax.set_title(f'{class_name} 的预测概率分布')
        ax.set_xlabel('预测概率')
        ax.set_ylabel('样本数')
        ax.grid(True, alpha=0.3)
    
    plt.suptitle('各类别预测概率分布', fontsize=14)
    plt.tight_layout()
    plt.show()

plot_probability_distribution(y_pred_proba_ovr, y_test, target_names)

步骤5:模型调优

python 复制代码
# 使用网格搜索寻找最佳参数
param_grid = {
    'C': [0.001, 0.01, 0.1, 1, 10, 100],  # 正则化强度
    'solver': ['lbfgs', 'liblinear', 'saga'],
    'max_iter': [100, 500, 1000]
}

# 创建网格搜索
grid_search = GridSearchCV(
    LogisticRegression(multi_class='ovr', random_state=42),
    param_grid,
    cv=5,
    scoring='accuracy',
    n_jobs=-1,
    verbose=1
)

# 执行网格搜索
grid_search.fit(X_train_scaled, y_train)

print("\n=== 网格搜索结果 ===")
print(f"最佳参数: {grid_search.best_params_}")
print(f"最佳交叉验证准确率: {grid_search.best_score_:.4f}")
print(f"测试集准确率: {grid_search.score(X_test_scaled, y_test):.4f}")

# 使用最佳模型
best_model = grid_search.best_estimator_
y_pred_best = best_model.predict(X_test_scaled)

print("\n=== 最佳模型评估 ===")
print(classification_report(y_test, y_pred_best, target_names=target_names))

步骤6:交叉验证评估

python 复制代码
# 交叉验证评估模型稳定性
cv_scores = cross_val_score(
    best_model, 
    X_train_scaled, 
    y_train, 
    cv=5, 
    scoring='accuracy'
)

print("\n=== 交叉验证结果 ===")
print(f"交叉验证准确率: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")
print(f"各折准确率: {cv_scores}")

# 绘制交叉验证结果
plt.figure(figsize=(10, 6))
plt.plot(range(1, 6), cv_scores, marker='o', linewidth=2, markersize=8)
plt.axhline(y=cv_scores.mean(), color='r', linestyle='--', label=f'均值: {cv_scores.mean():.4f}')
plt.fill_between(range(1, 6), 
                 cv_scores.mean() - cv_scores.std(), 
                 cv_scores.mean() + cv_scores.std(), 
                 alpha=0.2, color='gray')
plt.title('5折交叉验证准确率', fontsize=14)
plt.xlabel('折数')
plt.ylabel('准确率')
plt.legend()
plt.grid(True, alpha=0.3)
plt.ylim([0.8, 1.0])
plt.show()

三、关键要点总结

1.策略选择:

类别较少且均衡:考虑OvO

类别较多:使用OvR或Softmax

Softmax通常更直接,但需要计算所有类别的概率

2.特征工程:

逻辑回归对特征缩放敏感,务必标准化

特征间的多重共线性会影响结果

3.正则化:

参数C控制正则化强度(C越小,正则化越强)

防止过拟合的重要工具

4.模型评估:

多分类使用准确率、混淆矩阵、分类报告

考虑使用宏平均和微平均

5.注意事项:

逻辑回归假设特征与log odds线性相关

对于非线性问题,需要特征工程或使用核方法

类别不平衡时需要调整class_weight参数

相关推荐
纤纡.12 小时前
逻辑回归实战进阶:交叉验证与采样技术破解数据痛点(二)
算法·机器学习·逻辑回归
czhc114007566312 小时前
协议 25
java·开发语言·算法
范纹杉想快点毕业12 小时前
状态机设计与嵌入式系统开发完整指南从面向过程到面向对象,从理论到实践的全面解析
linux·服务器·数据库·c++·算法·mongodb·mfc
fish-man12 小时前
测试加粗效果
算法
晓131312 小时前
第二章 【C语言篇:入门】 C 语言基础入门
c语言·算法
小徐xxx12 小时前
Softmax回归(分类问题)学习记录
深度学习·分类·回归·softmax·学习记录
yong999012 小时前
MATLAB面波频散曲线反演程序
开发语言·算法·matlab
AAD5558889913 小时前
YOLOv8-MAN-Faster电容器缺陷检测:七类组件识别与分类系统
yolo·分类·数据挖掘
JicasdC123asd13 小时前
【工业检测】基于YOLO13-C3k2-EIEM的铸造缺陷检测与分类系统_1
人工智能·算法·分类
Not Dr.Wang42213 小时前
自动控制系统稳定性研究及判据分析
算法