机器学习:阈值与混淆矩阵

机器学习:阈值与混淆矩阵


一、混淆矩阵(Confusion Matrix)

1.1 什么是混淆矩阵

混淆矩阵是评估分类模型性能的核心工具,它将模型的预测结果与真实标签进行对比,以矩阵形式呈现。

以二分类问题为例(正类 = 1,负类 = 0):

复制代码
                    预测值
                  正类(1)   负类(0)
真实值  正类(1)  |   TP   |   FN   |
        负类(0)  |   FP   |   TN   |
术语 全称 含义
TP True Positive 真正例:预测为正,实际也为正 ✅
TN True Negative 真负例:预测为负,实际也为负 ✅
FP False Positive 假正例:预测为正,实际为负(误报)❌
FN False Negative 假负例:预测为负,实际为正(漏报)❌

1.2 具体示例

场景:用模型预测100封邮件是否为垃圾邮件

复制代码
实际垃圾邮件:60封
实际正常邮件:40封

预测结果:
  正确识别垃圾邮件(TP)= 50
  正常邮件误判为垃圾(FP)= 5
  垃圾邮件漏判为正常(FN)= 10
  正确识别正常邮件(TN)= 35

混淆矩阵:

复制代码
              预测垃圾   预测正常
实际垃圾  |    50    |    10    |
实际正常  |     5    |    35    |

二、核心评估指标

2.1 指标公式汇总

准确率 (Accuracy)=TP+TNTP+TN+FP+FN\text{准确率 (Accuracy)} = \frac{TP + TN}{TP + TN + FP + FN}准确率 (Accuracy)=TP+TN+FP+FNTP+TN

精确率 (Precision)=TPTP+FP\text{精确率 (Precision)} = \frac{TP}{TP + FP}精确率 (Precision)=TP+FPTP

召回率 (Recall)=TPTP+FN\text{召回率 (Recall)} = \frac{TP}{TP + FN}召回率 (Recall)=TP+FNTP

F1 Score=2×Precision×RecallPrecision+Recall\text{F1 Score} = 2 \times \frac{Precision \times Recall}{Precision + Recall}F1 Score=2×Precision+RecallPrecision×Recall

假正率 (FPR)=FPFP+TN\text{假正率 (FPR)} = \frac{FP}{FP + TN}假正率 (FPR)=FP+TNFP

2.2 各指标解读

指标 计算结果(邮件示例) 意义
准确率 (50+35)/100 = 85% 整体预测正确的比例
精确率 50/(50+5) = 90.9% 预测为正的样本中真正是正的比例
召回率 50/(50+10) = 83.3% 所有真实正样本中被正确识别的比例
F1 Score 2×(0.909×0.833)/(0.909+0.833) = 87% 精确率和召回率的调和平均

2.3 精确率 vs 召回率的权衡

复制代码
精确率高 → 少误报(宁可漏掉,不乱报)
召回率高 → 少漏报(宁可误报,不能漏)

医疗诊断(癌症检测):召回率优先 → 不能漏掉一个病人
垃圾邮件过滤:精确率优先 → 不能误删重要邮件

三、分类阈值(Classification Threshold)

3.1 什么是阈值

大多数分类模型输出的是概率值 (0~1之间),而非直接的类别标签。

阈值(Threshold)决定了:概率超过多少才判定为正类?

复制代码
模型输出概率:0.73
默认阈值:0.5

0.73 > 0.5  →  预测为正类(1)
0.73 < 0.8  →  如果阈值调高到0.8,则预测为负类(0)

3.2 阈值对混淆矩阵的影响

以肿瘤检测为例,同一模型,不同阈值:

复制代码
阈值 = 0.3(较低)
  → 更容易判定为正类
  → TP ↑, FP ↑, FN ↓
  → 召回率 ↑, 精确率 ↓

阈值 = 0.7(较高)
  → 更难判定为正类
  → TP ↓, FP ↓, FN ↑
  → 召回率 ↓, 精确率 ↑
阈值 TP FP FN TN 精确率 召回率
0.3 58 18 2 22 76.3% 96.7%
0.5 50 5 10 35 90.9% 83.3%
0.7 40 2 20 38 95.2% 66.7%

3.3 如何选择最优阈值

方法一:根据业务需求手动指定

python 复制代码
# 医疗场景优先保召回率
threshold = 0.3

# 推荐系统优先保精确率
threshold = 0.7

方法二:Youden 指数(最大化 TPR - FPR)

复制代码
最优阈值 = argmax(Sensitivity + Specificity - 1)
         = argmax(TPR - FPR)

方法三:F1 Score 最大化

python 复制代码
from sklearn.metrics import f1_score
import numpy as np

thresholds = np.arange(0.1, 1.0, 0.01)
f1_scores = [f1_score(y_true, probs >= t) for t in thresholds]
best_threshold = thresholds[np.argmax(f1_scores)]

四、ROC 曲线与 AUC

4.1 ROC 曲线

ROC(Receiver Operating Characteristic)曲线通过遍历所有阈值,绘制:

  • X 轴:假正率 FPR = FP / (FP + TN)

  • Y 轴:真正率 TPR(召回率) = TP / (TP + FN)

    TPR
    1.0 | .........
    | ..
    | ..
    | ..
    0.5 | .
    |.
    0.0 |___________
    0.0 0.5 1.0 FPR

    理想模型:曲线越靠近左上角越好
    随机猜测:对角线(AUC = 0.5)

4.2 AUC(Area Under Curve)

AUC 是 ROC 曲线下的面积,值域 [0, 1]:

AUC 值 模型质量
1.0 完美模型
0.9~1.0 优秀
0.8~0.9 良好
0.7~0.8 一般
0.5~0.7 较差
0.5 随机猜测

4.3 PR 曲线(Precision-Recall Curve)

当数据严重不平衡时(如欺诈检测,正样本极少),ROC 曲线可能过于乐观,此时应使用 PR 曲线:

  • X 轴:召回率(Recall)

  • Y 轴:精确率(Precision)

    Precision
    1.0 |...
    | ...
    0.5 | ...
    | ...
    0.0 |________________
    0.0 0.5 1.0 Recall

    AP(Average Precision)= PR 曲线下面积,越大越好


五、多分类混淆矩阵

对于 K 个类别的分类问题,混淆矩阵扩展为 K×K:

复制代码
              预测猫   预测狗   预测鸟
实际猫   |    45   |    3    |    2   |
实际狗   |     2   |   48    |    0   |
实际鸟   |     1   |    1    |   48   |

多分类评估方式(对每个类分别计算后汇总):

python 复制代码
# macro: 各类指标算术平均(不考虑样本量)
# micro: 全局统计 TP/FP/FN 再计算(受大类主导)
# weighted: 按各类样本量加权平均

from sklearn.metrics import classification_report
print(classification_report(y_true, y_pred))

六、Python 实战代码

6.1 绘制混淆矩阵

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# 生成数据
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# 训练模型
model = LogisticRegression()
model.fit(X_train, y_train)
y_pred = model.predict(X_test)

# 绘制混淆矩阵
cm = confusion_matrix(y_test, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['负类', '正类'])
disp.plot(cmap='Blues')
plt.title('混淆矩阵')
plt.show()

6.2 阈值调整与指标变化

python 复制代码
from sklearn.metrics import precision_score, recall_score, f1_score

y_prob = model.predict_proba(X_test)[:, 1]  # 正类概率

thresholds = np.arange(0.1, 1.0, 0.05)
results = []

for t in thresholds:
    y_pred_t = (y_prob >= t).astype(int)
    results.append({
        'threshold': t,
        'precision': precision_score(y_test, y_pred_t, zero_division=0),
        'recall':    recall_score(y_test, y_pred_t, zero_division=0),
        'f1':        f1_score(y_test, y_pred_t, zero_division=0),
    })

# 找到 F1 最优阈值
best = max(results, key=lambda x: x['f1'])
print(f"最优阈值: {best['threshold']:.2f}, F1: {best['f1']:.4f}")

6.3 绘制 ROC 曲线

python 复制代码
from sklearn.metrics import roc_curve, auc

fpr, tpr, thresholds = roc_curve(y_test, y_prob)
roc_auc = auc(fpr, tpr)

plt.figure(figsize=(8, 6))
plt.plot(fpr, tpr, color='darkorange', lw=2,
         label=f'ROC 曲线 (AUC = {roc_auc:.3f})')
plt.plot([0, 1], [0, 1], color='navy', lw=1, linestyle='--', label='随机猜测')
plt.xlabel('假正率 (FPR)')
plt.ylabel('真正率 (TPR)')
plt.title('ROC 曲线')
plt.legend(loc='lower right')
plt.grid(alpha=0.3)
plt.show()

6.4 绘制 PR 曲线

python 复制代码
from sklearn.metrics import precision_recall_curve, average_precision_score

precision, recall, _ = precision_recall_curve(y_test, y_prob)
ap = average_precision_score(y_test, y_prob)

plt.figure(figsize=(8, 6))
plt.plot(recall, precision, color='steelblue', lw=2,
         label=f'PR 曲线 (AP = {ap:.3f})')
plt.xlabel('召回率 (Recall)')
plt.ylabel('精确率 (Precision)')
plt.title('Precision-Recall 曲线')
plt.legend(loc='upper right')
plt.grid(alpha=0.3)
plt.show()

七、常见误区与注意事项

7.1 准确率陷阱

复制代码
数据集:1000条,其中正类10条,负类990条
模型全部预测为负类:
  准确率 = 990/1000 = 99% ← 看起来很高!
  但召回率 = 0/10 = 0%   ← 完全没用!

解决:使用 F1/AUC 等更全面的指标,或对数据做平衡处理

7.2 阈值不是固定的 0.5

默认阈值 0.5 只在类别平衡误报/漏报代价相等 时合适。

实际业务中应根据:

  • 正负样本比例
  • 误报与漏报的业务代价
  • 实际数据分布

来调整阈值。

7.3 过度优化单一指标

不要只优化一个指标,要结合业务场景综合评估。


八、指标选择速查表

场景 推荐指标 原因
类别均衡、代价相等 Accuracy 简单直观
误报代价高(垃圾邮件) Precision 减少误报
漏报代价高(疾病检测) Recall 减少漏报
综合平衡 F1 Score 精确率/召回率的折中
类别严重不平衡 PR-AUC / F1 ROC 会过于乐观
模型排序能力 ROC-AUC 与阈值无关,全局衡量
多类别分类 Macro/Weighted F1 综合各类别表现

九、总结

复制代码
┌────────────────────────────────────────────────────────┐
│                    分类评估全景图                        │
│                                                        │
│  原始输出:概率值 (0~1)                                  │
│      ↓  阈值选择                                        │
│  二值预测:0 或 1                                        │
│      ↓  与真实标签对比                                   │
│  混淆矩阵:TP / TN / FP / FN                            │
│      ↓  衍生指标                                        │
│  Accuracy / Precision / Recall / F1                    │
│      ↓  遍历所有阈值                                     │
│  ROC 曲线 → AUC                                         │
│  PR 曲线  → AP                                          │
└────────────────────────────────────────────────────────┘

理解阈值与混淆矩阵的核心在于:根据业务场景的成本权衡,选择合适的阈值和评估指标,而不是盲目追求某一个数字。

相关推荐
鱼骨不是鱼翅2 小时前
机器学习(1)-----基础概念
人工智能·机器学习
xiao5kou4chang6kai42 小时前
蒸散发与光合作用阻抗理论 → ArcGIS自动化 → 区域ET/GPP产品融合
人工智能·蒸散发·植被生产力估算·penman-monteith
cd_949217212 小时前
骁龙与F1的故事:一场连接与速度的深度对话
人工智能
新加坡内哥谈技术2 小时前
大语言模型的上下文工程指南
人工智能
Gofarlic_OMS2 小时前
装备制造企业Fluent许可证成本分点典型案例
java·大数据·开发语言·人工智能·自动化·制造
2501_948114242 小时前
DeepSeek V4 全面实测:万亿参数开源模型的工程落地与成本推演
人工智能·ai·开源
程序员雷欧2 小时前
大模型应用开发学习第八天
大数据·人工智能·学习
liukuang1102 小时前
伊利、蒙牛、飞鹤与光明乳业:存量时代的攻守之道与价值分化
大数据·人工智能·物联网
前进的李工2 小时前
LangChain使用AI工具赋能:解锁大语言模型无限潜力
开发语言·人工智能·语言模型·langchain·大模型