【漫话机器学习系列】125.普拉托变换(Platt Scaling)

普拉托变换(Platt Scaling)详解

1. 什么是普拉托变换(Platt Scaling)?

普拉托变换(Platt Scaling)是一种用于二分类支持向量机(SVM)的后处理方法,它的作用是将 SVM 的输出分数转换为概率值。由于 SVM 本身并不直接输出概率,而是输出一个决策值(即 f(x)),因此我们需要使用某种方法将这个决策值映射到 [0,1] 范围的概率值。普拉托变换正是通过拟合一个 Sigmoid 函数 来完成这个任务。


2. 线性支持向量机的输出问题

在 SVM 训练完成后,它会为每个输入样本 x 计算一个 决策函数值

这个值表示样本相对于决策边界的距离。一般来说:

  • 如果 f(x) > 0,则分类为正类( y = 1 )。
  • 如果 f(x) < 0,则分类为负类( y = 0 )。

但这个值只是一个距离,而不是概率。例如,f(x) = 2 和 f(x) = 100 都表示正类,但哪个更有可能是正类呢?这个问题就是 SVM 无法直接提供概率输出的原因。因此,我们需要一个方法将 SVM 的输出转换为概率。


3. 普拉托变换的基本思想

普拉托变换的核心思想是 使用 Sigmoid 函数拟合 SVM 的输出,使其变成概率

其中:

  • f(x) 是 SVM 计算得到的决策值(即得分)。
  • A 和 B 是两个需要拟合的参数。

换句话说,我们希望通过训练找到 A 和 B,使得 SVM 的输出 f(x) 能够很好地拟合样本的真实概率分布。


4. 如何训练 A 和 B?

训练 A 和 B 的过程如下:

  1. 在数据集上训练 SVM ,获取所有样本的 得分 f(x) 及其真实类别 y。
  2. 使用交叉验证 构建一个 逻辑回归模型 ,以 f(x) 为输入,y 为输出,拟合 Sigmoid 函数 的参数A 和 B。
  3. 计算损失函数
    • 普拉托变换使用 最大似然估计(MLE) 进行参数拟合。

    • 具体来说,最小化以下 交叉熵损失

    • 这个公式与 逻辑回归的损失函数 一样,因此可以使用标准的优化方法(如梯度下降)来求解 A 和 B。


5. 为什么要使用普拉托变换?

普拉托变换是 SVM 最常用 的概率校准方法,原因如下:

  • 简单高效:只需要额外训练一个小的逻辑回归模型(拟合 Sigmoid),计算开销较小。
  • 适用于 SVM 及其他分类器:虽然最初是为 SVM 提出的,但 Platt Scaling 也可用于其他分类器(如神经网络、随机森林等)。
  • 提高模型可解释性:许多应用(如医学、金融等)需要输出概率,而不仅仅是分类结果。

6. Platt Scaling 的局限性

尽管普拉托变换在很多场景下都能有效校准概率输出,但它也有一些局限:

  1. 对样本不均衡较敏感:如果数据集的类别不均衡,拟合的 Sigmoid 可能会偏向多数类。
  2. 假设数据符合 Sigmoid 分布:Platt Scaling 假设 SVM 输出的决策值可以用 Sigmoid 函数拟合,但在某些复杂分布下,可能效果不好。
  3. 需要额外的数据:Platt Scaling 需要使用交叉验证数据来训练 A 和 B,这可能会浪费一部分训练数据。

如果数据集较大、类别不均衡或者 Sigmoid 拟合效果不好,通常可以考虑 Isotonic Regression(等温回归) 作为替代方法。


7. 代码实现(Python 示例)

在实际应用中,Scikit-learn 提供了 Platt Scaling 的实现,示例如下:

python 复制代码
from sklearn.svm import SVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# 生成数据
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 训练 SVM(不带概率)
svm = SVC(kernel='linear', probability=False)
svm.fit(X_train, y_train)

# 使用 Platt Scaling 校准概率
svm_calibrated = CalibratedClassifierCV(svm, method='sigmoid', cv=5)
svm_calibrated.fit(X_train, y_train)

# 预测概率
probabilities = svm_calibrated.predict_proba(X_test)

print("前 5 个样本的预测概率:")
print(probabilities[:5])

运行结果

python 复制代码
前 5 个样本的预测概率:
[[0.39216467 0.60783533]
 [0.19236017 0.80763983]
 [0.51329367 0.48670633]
 [0.24076157 0.75923843]
 [0.0697871  0.9302129 ]]

8. 结论

普拉托变换(Platt Scaling)是一种 简单有效的概率校准方法 ,特别适用于 SVM 这种不直接输出概率的分类器。它通过 拟合 Sigmoid 函数,将分类得分映射到 [0,1] 范围,使得输出更具可解释性。

在实际应用中,我们可以:

  • 需要概率输出的任务(如医疗诊断、金融风控)中使用 Platt Scaling。
  • 如果数据 类别不均衡 ,可以尝试 Isotonic Regression 作为替代方案。

希望这篇文章能够帮助你更好地理解 Platt Scaling 的原理和应用!

相关推荐
白-胖-子3 小时前
深入剖析大模型在文本生成式 AI 产品架构中的核心地位
人工智能·架构
想要成为计算机高手4 小时前
11. isaacsim4.2教程-Transform 树与Odometry
人工智能·机器人·自动驾驶·ros·rviz·isaac sim·仿真环境
NeoFii5 小时前
Day 22: 复习
机器学习
静心问道5 小时前
InstructBLIP:通过指令微调迈向通用视觉-语言模型
人工智能·多模态·ai技术应用
宇称不守恒4.06 小时前
2025暑期—06神经网络-常见网络2
网络·人工智能·神经网络
小楓12016 小时前
醫護行業在未來會被AI淘汰嗎?
人工智能·醫療·護理·職業
数据与人工智能律师6 小时前
数字迷雾中的安全锚点:解码匿名化与假名化的法律边界与商业价值
大数据·网络·人工智能·云计算·区块链
chenchihwen6 小时前
大模型应用班-第2课 DeepSeek使用与提示词工程课程重点 学习ollama 安装 用deepseek-r1:1.5b 分析PDF 内容
人工智能·学习
说私域6 小时前
公域流量向私域流量转化策略研究——基于开源AI智能客服、AI智能名片与S2B2C商城小程序的融合应用
人工智能·小程序
Java樱木6 小时前
AI 编程工具 Trae 重要的升级。。。
人工智能