XGB-20:XGBoost中不同参数的预测函数

有许多在XGBoost中具有不同参数的预测函数。

预测选项

xgboost.Booster.predict() 方法有许多不同的预测选项,从 pred_contribspred_leaf 不等。输出形状取决于预测的类型。对于多类分类问题,XGBoost为每个类构建一棵树,每个类的树称为树的"组",因此输出维度可能会因所使用的模型而改变。

在1.4版本后,添加了 strict_shape 的新参数。可以将其设置为 True,以指示希望获得更受限制的输出。假设正在使用 xgboost.Booster,以下是可能的返回列表:

  • 使用 strict_shape 设置为 True 进行正常预测时:

    • 输出是一个2维数组,第一维是行数,第二维是组数 。对于回归/生存/排序/二分类,这相当于一个形状为shape[1] == 1的列向量。但对于多类别问题,使用 multi:softprob 时,列数等于类别数。如果 strict_shape 设置为 False,输出1维或2维数组
  • 使用 output_margin 避免转换且 strict_shape 设置为 True 时:

    • 输出是一个2维数组,除了 multi:softmax 由于去掉了转换而具有与 multi:softprob 相等的输出形状。如果 strict_shape 设置为 False,则输出可以具有1维或2维,具体取决于所使用的模型
  • 使用 pred_contribsstrict_shape 设置为 True 时:

    • 输出是一个3维数组,形状为(行数,组数,列数+1)。是否使用 approx_contribs 不会改变输出形状。如果未设置 strict_shape 参数,则它可以是2维或3维数组,具体取决于是否使用多类别模型
  • 使用 pred_interactionsstrict_shape 设置为 True 时:

    • 输出是一个4维数组,形状为(行数,组数,列数+1,列数+1)。是否使用 approx_contribs 不会改变输出形状。如果 strict_shape 设置为 False,则它可以具有3维或4维,具体取决于底层模型
  • 使用 pred_leafstrict_shape 设置为 True 时:

    • 输出是一个4维数组,形状为(n_samples, n_iterations, n_classes, n_trees_in_forest)。 n_trees_in_forest 在训练过程中由 num_parallel_tree 指定。当 strict_shape 设置为 False 时,输出是一个2维数组,最后3维连接成1维。如果最后一维等于1,则会删除最后一维。

对于 R 包,当指定 strict_shape 时,将返回一个数组,其值与 Python 相同, R 数组是列主序的,而 Python 的 numpy 数组是行主序的 ,因此所有维度都被颠倒。例如,对于在 strict_shape=True 的情况下通过 Python predict_leaf 获得的输出有4个维度:(n_samples, n_iterations, n_classes, n_trees_in_forest),而在 R 中 strict_shape=TRUE 的输出是 (n_trees_in_forest, n_classes, n_iterations, n_samples)。

除了这些预测类型之外,还有一个称为 iteration_range 的参数,类似于模型切片。但与实际将模型拆分为多个堆栈不同,它只是返回由范围内的树形成的预测。每次迭代创建的树的数量等于num_parallel_tree。因此,如果正在训练大小为4的增强随机森林,对于3类别分类数据集,并且想要使用前2次迭代的树进行预测,需要提供 iteration_range=(0, 2)。然后将在此预测中使用前

棵树。

提前停止Early Stopping

在使用提前停止进行训练时,原生 Python 接口和 sklearn/R 接口之间存在一种不一致的行为 。默认情况下,在 R 和 sklearn 接口上,会自动使用 best_iteration,因此预测将来自最佳模型。但是在原生 Python 接口中,xgboost.Booster.predict()xgboost.Booster.inplace_predict() 默认使用完整模型。用户可以使用 iteration_range 参数和 best_iteration 属性来实现相同的行为。此外,xgboost.callback.EarlyStoppingsave_best 参数可能会很有用。

基准分数Base Margin

XGBoost 中有一个名为 base_score 的训练参数,以及一个 DMatrix 的元数据称为 base_margin。它们指定了增强模型的全局偏差。如果提供了后者,则会忽略前者。base_margin 可用于基于其他模型训练 XGBoost 模型。

阶段性预测

使用 DMatrix 的原生接口,可以对预测进行阶段性(或缓存)。例如,可以首先对前4棵树进行预测,然后在8棵树上运行预测。在运行第一个预测后,前4棵树的结果被缓存,因此当您在8棵树上运行预测时,XGBoost 可以重复使用先前预测的结果。缓存会在下一次预测、训练或评估时自动过期,如果缓存的 DMatrix 对象已过期(例如,超出作用域并被语言环境中的垃圾回收器收集)。

阶段性预测

使用原生接口和 DMatrix,可以对预测进行阶段性(或缓存) 。例如,可以首先对前4棵树进行预测,然后在8棵树上运行预测。在运行第一个预测后,前4棵树的结果被缓存,因此当在8棵树上运行预测时,XGBoost 可以重复使用先前预测的结 果。如果缓存的 DMatrix 对象已过期(例如,超出作用域并被语言环境中的垃圾回收器收集),则缓存会在下一次预测、训练或评估时自动过期

In-place预测

传统上,XGBoost 只接受 DMatrix 进行预测,使用诸如 scikit-learn 接口之类的包装器时,构建过程会在内部发生。添加了对就地预测的支持,以绕过 DMatrix 的构建,这种构建方式速度较慢且占用内存 。新的预测函数具有有限的功能,但通常对于简单的推断任务已经足够。它接受 Python 中一些常见的数据类型,如 numpy.ndarrayscipy.sparse.csr_matrixcudf.DataFrame,而不是 xgboost.DMatrix。可以调用 xgboost.Booster.inplace_predict() 来使用它。请注意,就地预测的输出取决于输入数据类型,当输入在 GPU 数据上时,输出为 cupy.ndarray,否则返回 numpy.ndarray

线程安全

在 1.4 版本之后,所有的预测函数,包括具有各种参数的正常预测(如 shap 值计算和 inplace_predict),在底层 booster 为 gbtree 或 dart 时是线程安全的,这意味着只要使用树模型,预测本身就应该是线程安全的 。但是安全性仅在预测方面得到保证。如果尝试在一个线程中训练模型,并在另一个线程中使用相同的模型进行预测 ,则行为是未定义的。这比人们可能期望的更容易发生,例如可能会在预测函数内部意外地调用 clf.set_params()

python 复制代码
def predict_fn(clf: xgb.XGBClassifier, X):
    X = preprocess(X)
    clf.set_params(n_jobs=1)  # NOT safe!
    return clf.predict_proba(X, iteration_range=(0, 10))

with ThreadPoolExecutor(max_workers=10) as e:
    e.submit(predict_fn, ...)

隐私保护预测

Concrete ML 是由 Zama 开发的第三方开源库,提供了类似于梯度提升类,但直接在加密数据上进行预测的功能,这得益于全同态加密。一个简单的例子如下:

python 复制代码
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from concrete.ml.sklearn import XGBClassifier

x, y = make_classification(n_samples=100, class_sep=2, n_features=30, random_state=42)
X_train, X_test, y_train, y_test = train_test_split(
    x, y, test_size=10, random_state=42
)

# Train in the clear and quantize the weights
model = XGBClassifier()
model.fit(X_train, y_train)

# Simulate the predictions in the clear
y_pred_clear = model.predict(X_test)

# Compile in FHE
model.compile(X_train)

# Generate keys
model.fhe_circuit.keygen()

# Run the inference on encrypted inputs!
y_pred_fhe = model.predict(X_test, fhe="execute")

print("In clear:", y_pred_clear)
print("In FHE:", y_pred_fhe)
print(f"Similarity: {int((y_pred_fhe == y_pred_clear).mean()*100)}%")

参考

相关推荐
南宫理的日知录18 分钟前
99、Python并发编程:多线程的问题、临界资源以及同步机制
开发语言·python·学习·编程学习
coberup27 分钟前
django Forbidden (403)错误解决方法
python·django·403错误
龙哥说跨境1 小时前
如何利用指纹浏览器爬虫绕过Cloudflare的防护?
服务器·网络·python·网络爬虫
小白学大数据1 小时前
正则表达式在Kotlin中的应用:提取图片链接
开发语言·python·selenium·正则表达式·kotlin
flashman9111 小时前
python在word中插入图片
python·microsoft·自动化·word
菜鸟的人工智能之路1 小时前
桑基图在医学数据分析中的更复杂应用示例
python·数据分析·健康医疗
懒大王爱吃狼2 小时前
Python教程:python枚举类定义和使用
开发语言·前端·javascript·python·python基础·python编程·python书籍
秃头佛爷4 小时前
Python学习大纲总结及注意事项
开发语言·python·学习
深度学习lover5 小时前
<项目代码>YOLOv8 苹果腐烂识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·苹果腐烂识别
API快乐传递者6 小时前
淘宝反爬虫机制的主要手段有哪些?
爬虫·python