机器学习:阈值与混淆矩阵
一、混淆矩阵(Confusion Matrix)
1.1 什么是混淆矩阵
混淆矩阵是评估分类模型性能的核心工具,它将模型的预测结果与真实标签进行对比,以矩阵形式呈现。
以二分类问题为例(正类 = 1,负类 = 0):
预测值
正类(1) 负类(0)
真实值 正类(1) | TP | FN |
负类(0) | FP | TN |
| 术语 | 全称 | 含义 |
|---|---|---|
| TP | True Positive | 真正例:预测为正,实际也为正 ✅ |
| TN | True Negative | 真负例:预测为负,实际也为负 ✅ |
| FP | False Positive | 假正例:预测为正,实际为负(误报)❌ |
| FN | False Negative | 假负例:预测为负,实际为正(漏报)❌ |
1.2 具体示例
场景:用模型预测100封邮件是否为垃圾邮件
实际垃圾邮件:60封
实际正常邮件:40封
预测结果:
正确识别垃圾邮件(TP)= 50
正常邮件误判为垃圾(FP)= 5
垃圾邮件漏判为正常(FN)= 10
正确识别正常邮件(TN)= 35
混淆矩阵:
预测垃圾 预测正常
实际垃圾 | 50 | 10 |
实际正常 | 5 | 35 |
二、核心评估指标
2.1 指标公式汇总
准确率 (Accuracy)=TP+TNTP+TN+FP+FN\text{准确率 (Accuracy)} = \frac{TP + TN}{TP + TN + FP + FN}准确率 (Accuracy)=TP+TN+FP+FNTP+TN
精确率 (Precision)=TPTP+FP\text{精确率 (Precision)} = \frac{TP}{TP + FP}精确率 (Precision)=TP+FPTP
召回率 (Recall)=TPTP+FN\text{召回率 (Recall)} = \frac{TP}{TP + FN}召回率 (Recall)=TP+FNTP
F1 Score=2×Precision×RecallPrecision+Recall\text{F1 Score} = 2 \times \frac{Precision \times Recall}{Precision + Recall}F1 Score=2×Precision+RecallPrecision×Recall
假正率 (FPR)=FPFP+TN\text{假正率 (FPR)} = \frac{FP}{FP + TN}假正率 (FPR)=FP+TNFP
2.2 各指标解读
| 指标 | 计算结果(邮件示例) | 意义 |
|---|---|---|
| 准确率 | (50+35)/100 = 85% | 整体预测正确的比例 |
| 精确率 | 50/(50+5) = 90.9% | 预测为正的样本中真正是正的比例 |
| 召回率 | 50/(50+10) = 83.3% | 所有真实正样本中被正确识别的比例 |
| F1 Score | 2×(0.909×0.833)/(0.909+0.833) = 87% | 精确率和召回率的调和平均 |
2.3 精确率 vs 召回率的权衡
精确率高 → 少误报(宁可漏掉,不乱报)
召回率高 → 少漏报(宁可误报,不能漏)
医疗诊断(癌症检测):召回率优先 → 不能漏掉一个病人
垃圾邮件过滤:精确率优先 → 不能误删重要邮件
三、分类阈值(Classification Threshold)
3.1 什么是阈值
大多数分类模型输出的是概率值 (0~1之间),而非直接的类别标签。
阈值(Threshold)决定了:概率超过多少才判定为正类?
模型输出概率:0.73
默认阈值:0.5
0.73 > 0.5 → 预测为正类(1)
0.73 < 0.8 → 如果阈值调高到0.8,则预测为负类(0)
3.2 阈值对混淆矩阵的影响
以肿瘤检测为例,同一模型,不同阈值:
阈值 = 0.3(较低)
→ 更容易判定为正类
→ TP ↑, FP ↑, FN ↓
→ 召回率 ↑, 精确率 ↓
阈值 = 0.7(较高)
→ 更难判定为正类
→ TP ↓, FP ↓, FN ↑
→ 召回率 ↓, 精确率 ↑
| 阈值 | TP | FP | FN | TN | 精确率 | 召回率 |
|---|---|---|---|---|---|---|
| 0.3 | 58 | 18 | 2 | 22 | 76.3% | 96.7% |
| 0.5 | 50 | 5 | 10 | 35 | 90.9% | 83.3% |
| 0.7 | 40 | 2 | 20 | 38 | 95.2% | 66.7% |
3.3 如何选择最优阈值
方法一:根据业务需求手动指定
python
# 医疗场景优先保召回率
threshold = 0.3
# 推荐系统优先保精确率
threshold = 0.7
方法二:Youden 指数(最大化 TPR - FPR)
最优阈值 = argmax(Sensitivity + Specificity - 1)
= argmax(TPR - FPR)
方法三:F1 Score 最大化
python
from sklearn.metrics import f1_score
import numpy as np
thresholds = np.arange(0.1, 1.0, 0.01)
f1_scores = [f1_score(y_true, probs >= t) for t in thresholds]
best_threshold = thresholds[np.argmax(f1_scores)]
四、ROC 曲线与 AUC
4.1 ROC 曲线
ROC(Receiver Operating Characteristic)曲线通过遍历所有阈值,绘制:
-
X 轴:假正率 FPR = FP / (FP + TN)
-
Y 轴:真正率 TPR(召回率) = TP / (TP + FN)
TPR
1.0 | .........
| ..
| ..
| ..
0.5 | .
|.
0.0 |___________
0.0 0.5 1.0 FPR理想模型:曲线越靠近左上角越好
随机猜测:对角线(AUC = 0.5)
4.2 AUC(Area Under Curve)
AUC 是 ROC 曲线下的面积,值域 [0, 1]:
| AUC 值 | 模型质量 |
|---|---|
| 1.0 | 完美模型 |
| 0.9~1.0 | 优秀 |
| 0.8~0.9 | 良好 |
| 0.7~0.8 | 一般 |
| 0.5~0.7 | 较差 |
| 0.5 | 随机猜测 |
4.3 PR 曲线(Precision-Recall Curve)
当数据严重不平衡时(如欺诈检测,正样本极少),ROC 曲线可能过于乐观,此时应使用 PR 曲线:
-
X 轴:召回率(Recall)
-
Y 轴:精确率(Precision)
Precision
1.0 |...
| ...
0.5 | ...
| ...
0.0 |________________
0.0 0.5 1.0 RecallAP(Average Precision)= PR 曲线下面积,越大越好
五、多分类混淆矩阵
对于 K 个类别的分类问题,混淆矩阵扩展为 K×K:
预测猫 预测狗 预测鸟
实际猫 | 45 | 3 | 2 |
实际狗 | 2 | 48 | 0 |
实际鸟 | 1 | 1 | 48 |
多分类评估方式(对每个类分别计算后汇总):
python
# macro: 各类指标算术平均(不考虑样本量)
# micro: 全局统计 TP/FP/FN 再计算(受大类主导)
# weighted: 按各类样本量加权平均
from sklearn.metrics import classification_report
print(classification_report(y_true, y_pred))
六、Python 实战代码
6.1 绘制混淆矩阵
python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
# 生成数据
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
# 训练模型
model = LogisticRegression()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
# 绘制混淆矩阵
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['负类', '正类'])
disp.plot(cmap='Blues')
plt.title('混淆矩阵')
plt.show()
6.2 阈值调整与指标变化
python
from sklearn.metrics import precision_score, recall_score, f1_score
y_prob = model.predict_proba(X_test)[:, 1] # 正类概率
thresholds = np.arange(0.1, 1.0, 0.05)
results = []
for t in thresholds:
y_pred_t = (y_prob >= t).astype(int)
results.append({
'threshold': t,
'precision': precision_score(y_test, y_pred_t, zero_division=0),
'recall': recall_score(y_test, y_pred_t, zero_division=0),
'f1': f1_score(y_test, y_pred_t, zero_division=0),
})
# 找到 F1 最优阈值
best = max(results, key=lambda x: x['f1'])
print(f"最优阈值: {best['threshold']:.2f}, F1: {best['f1']:.4f}")
6.3 绘制 ROC 曲线
python
from sklearn.metrics import roc_curve, auc
fpr, tpr, thresholds = roc_curve(y_test, y_prob)
roc_auc = auc(fpr, tpr)
plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2,
label=f'ROC 曲线 (AUC = {roc_auc:.3f})')
plt.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--', label='随机猜测')
plt.xlabel('假正率 (FPR)')
plt.ylabel('真正率 (TPR)')
plt.title('ROC 曲线')
plt.legend(loc='lower right')
plt.grid(alpha=0.3)
plt.show()
6.4 绘制 PR 曲线
python
from sklearn.metrics import precision_recall_curve, average_precision_score
precision, recall, _ = precision_recall_curve(y_test, y_prob)
ap = average_precision_score(y_test, y_prob)
plt.figure(figsize=(8, 6))
plt.plot(recall, precision, color='steelblue', lw=2,
label=f'PR 曲线 (AP = {ap:.3f})')
plt.xlabel('召回率 (Recall)')
plt.ylabel('精确率 (Precision)')
plt.title('Precision-Recall 曲线')
plt.legend(loc='upper right')
plt.grid(alpha=0.3)
plt.show()
七、常见误区与注意事项
7.1 准确率陷阱
数据集:1000条,其中正类10条,负类990条
模型全部预测为负类:
准确率 = 990/1000 = 99% ← 看起来很高!
但召回率 = 0/10 = 0% ← 完全没用!
解决:使用 F1/AUC 等更全面的指标,或对数据做平衡处理
7.2 阈值不是固定的 0.5
默认阈值 0.5 只在类别平衡 且误报/漏报代价相等 时合适。
实际业务中应根据:
- 正负样本比例
- 误报与漏报的业务代价
- 实际数据分布
来调整阈值。
7.3 过度优化单一指标
不要只优化一个指标,要结合业务场景综合评估。
八、指标选择速查表
| 场景 | 推荐指标 | 原因 |
|---|---|---|
| 类别均衡、代价相等 | Accuracy | 简单直观 |
| 误报代价高(垃圾邮件) | Precision | 减少误报 |
| 漏报代价高(疾病检测) | Recall | 减少漏报 |
| 综合平衡 | F1 Score | 精确率/召回率的折中 |
| 类别严重不平衡 | PR-AUC / F1 | ROC 会过于乐观 |
| 模型排序能力 | ROC-AUC | 与阈值无关,全局衡量 |
| 多类别分类 | Macro/Weighted F1 | 综合各类别表现 |
九、总结
┌────────────────────────────────────────────────────────┐
│ 分类评估全景图 │
│ │
│ 原始输出:概率值 (0~1) │
│ ↓ 阈值选择 │
│ 二值预测:0 或 1 │
│ ↓ 与真实标签对比 │
│ 混淆矩阵:TP / TN / FP / FN │
│ ↓ 衍生指标 │
│ Accuracy / Precision / Recall / F1 │
│ ↓ 遍历所有阈值 │
│ ROC 曲线 → AUC │
│ PR 曲线 → AP │
└────────────────────────────────────────────────────────┘
理解阈值与混淆矩阵的核心在于:根据业务场景的成本权衡,选择合适的阈值和评估指标,而不是盲目追求某一个数字。