一、多分类策略
逻辑回归处理多分类主要有三种策略:
- OvR(One-vs-Rest)
为每个类别训练一个二分类器
预测时选择概率最高的类别
Scikit-learn默认使用此方法
- OvO(One-vs-One)
为每对类别训练一个分类器
适合类别较少但样本均衡的情况
- Softmax回归(Multinomial)
直接输出多个类别的概率分布
使用交叉熵损失函数
二、完整实战代码示例
python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import datasets
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
classification_report,
confusion_matrix,
accuracy_score,
roc_curve,
auc,
roc_auc_score
)
from sklearn.multiclass import OneVsRestClassifier
import warnings
warnings.filterwarnings('ignore')
# 设置中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
步骤1:加载和准备数据
python
# 加载鸢尾花数据集(3个类别)
iris = datasets.load_iris()
X = iris.data
y = iris.target
feature_names = iris.feature_names
target_names = iris.target_names
print(f"特征形状: {X.shape}")
print(f"标签形状: {y.shape}")
print(f"类别: {target_names}")
print(f"特征名: {feature_names}")
# 查看数据分布
print("\n类别分布:")
for i, name in enumerate(target_names):
print(f"{name}: {np.sum(y == i)} 个样本")
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.3, random_state=42, stratify=y
)
# 标准化特征
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
步骤2:模型训练与评估
方法1:使用默认的OvR策略
python
# 创建逻辑回归模型(默认使用OvR)
model_ovr = LogisticRegression(
multi_class='ovr', # One-vs-Rest
solver='lbfgs', # 适用于小数据集
max_iter=1000,
random_state=42,
C=1.0 # 正则化强度,越小正则化越强
)
# 训练模型
model_ovr.fit(X_train_scaled, y_train)
# 预测
y_pred_ovr = model_ovr.predict(X_test_scaled)
y_pred_proba_ovr = model_ovr.predict_proba(X_test_scaled)
# 评估
print("=== OvR策略评估 ===")
print(f"准确率: {accuracy_score(y_test, y_pred_ovr):.4f}")
print("\n分类报告:")
print(classification_report(y_test, y_pred_ovr, target_names=target_names))
方法2:使用Softmax回归
python
# 创建Softmax回归模型
model_softmax = LogisticRegression(
multi_class='multinomial', # Softmax回归
solver='lbfgs',
max_iter=1000,
random_state=42,
C=1.0
)
# 训练模型
model_softmax.fit(X_train_scaled, y_train)
# 预测
y_pred_softmax = model_softmax.predict(X_test_scaled)
# 评估
print("\n=== Softmax回归评估 ===")
print(f"准确率: {accuracy_score(y_test, y_pred_softmax):.4f}")
print("\n分类报告:")
print(classification_report(y_test, y_pred_softmax, target_names=target_names))
方法3:使用OneVsRestClassifier包装器
python
# 使用包装器实现OvR
model_ovr_wrapper = OneVsRestClassifier(
LogisticRegression(solver='lbfgs', max_iter=1000, random_state=42)
)
model_ovr_wrapper.fit(X_train_scaled, y_train)
y_pred_ovr_wrapper = model_ovr_wrapper.predict(X_test_scaled)
print("\n=== OvR包装器评估 ===")
print(f"准确率: {accuracy_score(y_test, y_pred_ovr_wrapper):.4f}")
步骤3:可视化分析
python
def plot_confusion_matrix(y_true, y_pred, class_names, title):
"""绘制混淆矩阵"""
cm = confusion_matrix(y_true, y_pred)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
xticklabels=class_names, yticklabels=class_names)
plt.title(f'混淆矩阵 - {title}', fontsize=14)
plt.ylabel('真实标签')
plt.xlabel('预测标签')
plt.tight_layout()
plt.show()
# 绘制混淆矩阵
plot_confusion_matrix(y_test, y_pred_ovr, target_names, "OvR策略")
# 绘制特征重要性
def plot_feature_importance(model, feature_names, target_names):
"""绘制特征重要性(权重)"""
if hasattr(model, 'coef_'):
weights = model.coef_
fig, axes = plt.subplots(1, len(target_names), figsize=(15, 5))
for i, (ax, class_name) in enumerate(zip(axes, target_names)):
ax.barh(feature_names, weights[i])
ax.set_title(f'类别: {class_name}')
ax.set_xlabel('权重')
plt.suptitle('逻辑回归特征权重(每个类别的决策边界)', fontsize=14)
plt.tight_layout()
plt.show()
plot_feature_importance(model_ovr, feature_names, target_names)
步骤4:概率可视化
python
# 绘制预测概率分布
def plot_probability_distribution(y_pred_proba, y_true, target_names):
"""绘制预测概率分布"""
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for i, (ax, class_name) in enumerate(zip(axes, target_names)):
# 获取属于当前类别的样本的概率
true_class_mask = (y_true == i)
prob_for_class = y_pred_proba[true_class_mask, i]
ax.hist(prob_for_class, bins=20, alpha=0.7, color='skyblue', edgecolor='black')
ax.set_title(f'{class_name} 的预测概率分布')
ax.set_xlabel('预测概率')
ax.set_ylabel('样本数')
ax.grid(True, alpha=0.3)
plt.suptitle('各类别预测概率分布', fontsize=14)
plt.tight_layout()
plt.show()
plot_probability_distribution(y_pred_proba_ovr, y_test, target_names)
步骤5:模型调优
python
# 使用网格搜索寻找最佳参数
param_grid = {
'C': [0.001, 0.01, 0.1, 1, 10, 100], # 正则化强度
'solver': ['lbfgs', 'liblinear', 'saga'],
'max_iter': [100, 500, 1000]
}
# 创建网格搜索
grid_search = GridSearchCV(
LogisticRegression(multi_class='ovr', random_state=42),
param_grid,
cv=5,
scoring='accuracy',
n_jobs=-1,
verbose=1
)
# 执行网格搜索
grid_search.fit(X_train_scaled, y_train)
print("\n=== 网格搜索结果 ===")
print(f"最佳参数: {grid_search.best_params_}")
print(f"最佳交叉验证准确率: {grid_search.best_score_:.4f}")
print(f"测试集准确率: {grid_search.score(X_test_scaled, y_test):.4f}")
# 使用最佳模型
best_model = grid_search.best_estimator_
y_pred_best = best_model.predict(X_test_scaled)
print("\n=== 最佳模型评估 ===")
print(classification_report(y_test, y_pred_best, target_names=target_names))
步骤6:交叉验证评估
python
# 交叉验证评估模型稳定性
cv_scores = cross_val_score(
best_model,
X_train_scaled,
y_train,
cv=5,
scoring='accuracy'
)
print("\n=== 交叉验证结果 ===")
print(f"交叉验证准确率: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")
print(f"各折准确率: {cv_scores}")
# 绘制交叉验证结果
plt.figure(figsize=(10, 6))
plt.plot(range(1, 6), cv_scores, marker='o', linewidth=2, markersize=8)
plt.axhline(y=cv_scores.mean(), color='r', linestyle='--', label=f'均值: {cv_scores.mean():.4f}')
plt.fill_between(range(1, 6),
cv_scores.mean() - cv_scores.std(),
cv_scores.mean() + cv_scores.std(),
alpha=0.2, color='gray')
plt.title('5折交叉验证准确率', fontsize=14)
plt.xlabel('折数')
plt.ylabel('准确率')
plt.legend()
plt.grid(True, alpha=0.3)
plt.ylim([0.8, 1.0])
plt.show()
三、关键要点总结
1.策略选择:
类别较少且均衡:考虑OvO
类别较多:使用OvR或Softmax
Softmax通常更直接,但需要计算所有类别的概率
2.特征工程:
逻辑回归对特征缩放敏感,务必标准化
特征间的多重共线性会影响结果
3.正则化:
参数C控制正则化强度(C越小,正则化越强)
防止过拟合的重要工具
4.模型评估:
多分类使用准确率、混淆矩阵、分类报告
考虑使用宏平均和微平均
5.注意事项:
逻辑回归假设特征与log odds线性相关
对于非线性问题,需要特征工程或使用核方法
类别不平衡时需要调整class_weight参数