【漫话机器学习系列】125.普拉托变换(Platt Scaling)

普拉托变换(Platt Scaling)详解

1. 什么是普拉托变换(Platt Scaling)?

普拉托变换(Platt Scaling)是一种用于二分类支持向量机(SVM)的后处理方法,它的作用是将 SVM 的输出分数转换为概率值。由于 SVM 本身并不直接输出概率,而是输出一个决策值(即 f(x)),因此我们需要使用某种方法将这个决策值映射到 [0,1] 范围的概率值。普拉托变换正是通过拟合一个 Sigmoid 函数 来完成这个任务。


2. 线性支持向量机的输出问题

在 SVM 训练完成后,它会为每个输入样本 x 计算一个 决策函数值

这个值表示样本相对于决策边界的距离。一般来说:

  • 如果 f(x) > 0,则分类为正类( y = 1 )。
  • 如果 f(x) < 0,则分类为负类( y = 0 )。

但这个值只是一个距离,而不是概率。例如,f(x) = 2 和 f(x) = 100 都表示正类,但哪个更有可能是正类呢?这个问题就是 SVM 无法直接提供概率输出的原因。因此,我们需要一个方法将 SVM 的输出转换为概率。


3. 普拉托变换的基本思想

普拉托变换的核心思想是 使用 Sigmoid 函数拟合 SVM 的输出,使其变成概率

其中:

  • f(x) 是 SVM 计算得到的决策值(即得分)。
  • A 和 B 是两个需要拟合的参数。

换句话说,我们希望通过训练找到 A 和 B,使得 SVM 的输出 f(x) 能够很好地拟合样本的真实概率分布。


4. 如何训练 A 和 B?

训练 A 和 B 的过程如下:

  1. 在数据集上训练 SVM ,获取所有样本的 得分 f(x) 及其真实类别 y。
  2. 使用交叉验证 构建一个 逻辑回归模型 ,以 f(x) 为输入,y 为输出,拟合 Sigmoid 函数 的参数A 和 B。
  3. 计算损失函数
    • 普拉托变换使用 最大似然估计(MLE) 进行参数拟合。

    • 具体来说,最小化以下 交叉熵损失

    • 这个公式与 逻辑回归的损失函数 一样,因此可以使用标准的优化方法(如梯度下降)来求解 A 和 B。


5. 为什么要使用普拉托变换?

普拉托变换是 SVM 最常用 的概率校准方法,原因如下:

  • 简单高效:只需要额外训练一个小的逻辑回归模型(拟合 Sigmoid),计算开销较小。
  • 适用于 SVM 及其他分类器:虽然最初是为 SVM 提出的,但 Platt Scaling 也可用于其他分类器(如神经网络、随机森林等)。
  • 提高模型可解释性:许多应用(如医学、金融等)需要输出概率,而不仅仅是分类结果。

6. Platt Scaling 的局限性

尽管普拉托变换在很多场景下都能有效校准概率输出,但它也有一些局限:

  1. 对样本不均衡较敏感:如果数据集的类别不均衡,拟合的 Sigmoid 可能会偏向多数类。
  2. 假设数据符合 Sigmoid 分布:Platt Scaling 假设 SVM 输出的决策值可以用 Sigmoid 函数拟合,但在某些复杂分布下,可能效果不好。
  3. 需要额外的数据:Platt Scaling 需要使用交叉验证数据来训练 A 和 B,这可能会浪费一部分训练数据。

如果数据集较大、类别不均衡或者 Sigmoid 拟合效果不好,通常可以考虑 Isotonic Regression(等温回归) 作为替代方法。


7. 代码实现(Python 示例)

在实际应用中,Scikit-learn 提供了 Platt Scaling 的实现,示例如下:

python 复制代码
from sklearn.svm import SVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# 生成数据
X, y = make_classification(n_samples=1000, n_features=20, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# 训练 SVM(不带概率)
svm = SVC(kernel='linear', probability=False)
svm.fit(X_train, y_train)

# 使用 Platt Scaling 校准概率
svm_calibrated = CalibratedClassifierCV(svm, method='sigmoid', cv=5)
svm_calibrated.fit(X_train, y_train)

# 预测概率
probabilities = svm_calibrated.predict_proba(X_test)

print("前 5 个样本的预测概率:")
print(probabilities[:5])

运行结果

python 复制代码
前 5 个样本的预测概率:
[[0.39216467 0.60783533]
 [0.19236017 0.80763983]
 [0.51329367 0.48670633]
 [0.24076157 0.75923843]
 [0.0697871  0.9302129 ]]

8. 结论

普拉托变换(Platt Scaling)是一种 简单有效的概率校准方法 ,特别适用于 SVM 这种不直接输出概率的分类器。它通过 拟合 Sigmoid 函数,将分类得分映射到 [0,1] 范围,使得输出更具可解释性。

在实际应用中,我们可以:

  • 需要概率输出的任务(如医疗诊断、金融风控)中使用 Platt Scaling。
  • 如果数据 类别不均衡 ,可以尝试 Isotonic Regression 作为替代方案。

希望这篇文章能够帮助你更好地理解 Platt Scaling 的原理和应用!

相关推荐
泡芙萝莉酱12 分钟前
各省份发电量数据(2005-2022年)-社科数据
大数据·人工智能·深度学习·数据挖掘·数据分析·毕业论文·数据统计
threelab13 分钟前
02.three官方示例+编辑器+AI快速学习webgl_animation_skinning_blending
人工智能·学习·编辑器
wei_shuo3 小时前
OB Cloud 云数据库V4.3:SQL +AI全新体验
数据库·人工智能·sql
努力的搬砖人.3 小时前
AI生成视频推荐
人工智能
想要成为计算机高手4 小时前
Helix:一种用于通用人形控制的视觉语言行动模型
人工智能·计算机视觉·自然语言处理·大模型·vla
Mory_Herbert4 小时前
5.1 神经网络: 层和块
人工智能·深度学习·神经网络
Evand J5 小时前
MATLAB程序演示与编程思路,相对导航,四个小车的形式,使用集中式扩展卡尔曼滤波(fullyCN-EKF)
人工智能·算法
知来者逆6 小时前
在与大语言模型交互中的礼貌现象:技术影响、社会行为与文化意义的多维度探讨
人工智能·深度学习·语言模型·自然语言处理·llm
xwz小王子8 小时前
Taccel:一个高性能的GPU加速视触觉机器人模拟平台
人工智能·机器人
深空数字孪生9 小时前
AI时代的数据可视化:未来已来
人工智能·信息可视化