【漫话机器学习系列】125.普拉托变换(Platt Scaling)

普拉托变换(Platt Scaling)详解

1. 什么是普拉托变换(Platt Scaling)?

普拉托变换(Platt Scaling)是一种用于二分类支持向量机(SVM)的后处理方法,它的作用是将 SVM 的输出分数转换为概率值。由于 SVM 本身并不直接输出概率,而是输出一个决策值(即 f(x)),因此我们需要使用某种方法将这个决策值映射到 [0,1] 范围的概率值。普拉托变换正是通过拟合一个 Sigmoid 函数 来完成这个任务。


2. 线性支持向量机的输出问题

在 SVM 训练完成后,它会为每个输入样本 x 计算一个 决策函数值

这个值表示样本相对于决策边界的距离。一般来说:

  • 如果 f(x) > 0,则分类为正类( y = 1 )。
  • 如果 f(x) < 0,则分类为负类( y = 0 )。

但这个值只是一个距离,而不是概率。例如,f(x) = 2 和 f(x) = 100 都表示正类,但哪个更有可能是正类呢?这个问题就是 SVM 无法直接提供概率输出的原因。因此,我们需要一个方法将 SVM 的输出转换为概率。


3. 普拉托变换的基本思想

普拉托变换的核心思想是 使用 Sigmoid 函数拟合 SVM 的输出,使其变成概率

其中:

  • f(x) 是 SVM 计算得到的决策值(即得分)。
  • A 和 B 是两个需要拟合的参数。

换句话说,我们希望通过训练找到 A 和 B,使得 SVM 的输出 f(x) 能够很好地拟合样本的真实概率分布。


4. 如何训练 A 和 B?

训练 A 和 B 的过程如下:

  1. 在数据集上训练 SVM ,获取所有样本的 得分 f(x) 及其真实类别 y。
  2. 使用交叉验证 构建一个 逻辑回归模型 ,以 f(x) 为输入,y 为输出,拟合 Sigmoid 函数 的参数A 和 B。
  3. 计算损失函数
    • 普拉托变换使用 最大似然估计(MLE) 进行参数拟合。

    • 具体来说,最小化以下 交叉熵损失

    • 这个公式与 逻辑回归的损失函数 一样,因此可以使用标准的优化方法(如梯度下降)来求解 A 和 B。


5. 为什么要使用普拉托变换?

普拉托变换是 SVM 最常用 的概率校准方法,原因如下:

  • 简单高效:只需要额外训练一个小的逻辑回归模型(拟合 Sigmoid),计算开销较小。
  • 适用于 SVM 及其他分类器:虽然最初是为 SVM 提出的,但 Platt Scaling 也可用于其他分类器(如神经网络、随机森林等)。
  • 提高模型可解释性:许多应用(如医学、金融等)需要输出概率,而不仅仅是分类结果。

6. Platt Scaling 的局限性

尽管普拉托变换在很多场景下都能有效校准概率输出,但它也有一些局限:

  1. 对样本不均衡较敏感:如果数据集的类别不均衡,拟合的 Sigmoid 可能会偏向多数类。
  2. 假设数据符合 Sigmoid 分布:Platt Scaling 假设 SVM 输出的决策值可以用 Sigmoid 函数拟合,但在某些复杂分布下,可能效果不好。
  3. 需要额外的数据:Platt Scaling 需要使用交叉验证数据来训练 A 和 B,这可能会浪费一部分训练数据。

如果数据集较大、类别不均衡或者 Sigmoid 拟合效果不好,通常可以考虑 Isotonic Regression(等温回归) 作为替代方法。


7. 代码实现(Python 示例)

在实际应用中,Scikit-learn 提供了 Platt Scaling 的实现,示例如下:

python 复制代码
from sklearn.svm import SVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# 生成数据
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 训练 SVM(不带概率)
svm = SVC(kernel='linear', probability=False)
svm.fit(X_train, y_train)

# 使用 Platt Scaling 校准概率
svm_calibrated = CalibratedClassifierCV(svm, method='sigmoid', cv=5)
svm_calibrated.fit(X_train, y_train)

# 预测概率
probabilities = svm_calibrated.predict_proba(X_test)

print("前 5 个样本的预测概率:")
print(probabilities[:5])

运行结果

python 复制代码
前 5 个样本的预测概率:
[[0.39216467 0.60783533]
 [0.19236017 0.80763983]
 [0.51329367 0.48670633]
 [0.24076157 0.75923843]
 [0.0697871  0.9302129 ]]

8. 结论

普拉托变换(Platt Scaling)是一种 简单有效的概率校准方法 ,特别适用于 SVM 这种不直接输出概率的分类器。它通过 拟合 Sigmoid 函数,将分类得分映射到 [0,1] 范围,使得输出更具可解释性。

在实际应用中,我们可以:

  • 需要概率输出的任务(如医疗诊断、金融风控)中使用 Platt Scaling。
  • 如果数据 类别不均衡 ,可以尝试 Isotonic Regression 作为替代方案。

希望这篇文章能够帮助你更好地理解 Platt Scaling 的原理和应用!

相关推荐
GEO索引未来3 分钟前
大胆预测:国家会这样对GEO行业进行监管
大数据·人工智能·gpt·ai·chatgpt
闵孚龙5 分钟前
Prompt工程到底怎么做?从“会提问”到“能落地”的完整方法论
人工智能·prompt
AI人工智能+6 分钟前
文档抽取系统通过OCR与大语言模型融合技术,将非结构化文档(如合同、保单、表格)自动转换为结构化数据
人工智能·语言模型·ocr·文档抽取
深海鱼在掘金7 分钟前
深入浅出 LangChain —— 第十四章:可观测性与生产运维
人工智能·langchain·agent
生物信息与育种10 分钟前
实战总结:用 rMVP 做植物 GWAS 的标准工作流与避坑指南
人工智能·深度学习·职场和发展·数据分析·r语言
嵌入式小企鹅10 分钟前
大模型算法工程师面试宝典
人工智能·学习·算法·面试·职场和发展·大模型·面经
小仙女的小稀罕17 分钟前
会议转行动项处理,AI对比原生工具有何效率差异
人工智能
逻辑君23 分钟前
认知神经科学研究报告【20260030】
人工智能·神经网络·机器学习
java1234_小锋36 分钟前
能让你的 AI 编程 Token 降低 60% 以上的开源神器:目前 GitHub 狂揽约 4.2 万星标
人工智能·github·ai编程
sanshanjianke37 分钟前
AI辅助网文创作理论研究笔记(十二):L1.5——情节编排层
人工智能·ai写作