XGB-25:Callback函数

本文档提供了XGBoost Python包中使用的回调API的基本概述。在XGBoost 1.3中,为Python包设计了一个新的回调接口,它为设计各种扩展提供了灵活性,用于训练。此外,XGBoost还预定义了许多回调函数,用于支持提前停止early stopping、检查点checkpoints等。

使用内置回调函数

默认情况下,XGBoost 中的训练方法具有参数,如 early_stopping_roundsverbose/verbose_eval,当指定这些参数时,训练过程将在内部定义相应的回调函数。例如,当指定了 early_stopping_rounds 时,EarlyStopping 回调将在迭代循环内调用。也可以直接将此回调函数传递给 XGBoost:

python 复制代码
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

import xgboost as xgb
import numpy as np

X, y = load_breast_cancer(return_X_y=True)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, stratify=y, random_state=94)

D_train = xgb.DMatrix(X_train, y_train)
D_valid = xgb.DMatrix(X_valid, y_valid)

# Define a custom evaluation metric used for early stopping.
def eval_error_metric(predt, dtrain: xgb.DMatrix):
    label = dtrain.get_label()
    r = np.zeros(predt.shape)
    gt = predt > 0.5
    r[gt] = 1 - label[gt]
    le = predt <= 0.5
    r[le] = label[le]
    return 'CustomErr', np.sum(r)

# Specify which dataset and which metric should be used for early stopping.
early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds,
                                        metric_name='CustomErr',
                                        data_name='Train')

booster = xgb.train(
    {'objective': 'binary:logistic',
     'eval_metric': ['error', 'rmse'],
     'tree_method': 'hist'}, D_train,
    evals=[(D_train, 'Train'), (D_valid, 'Valid')],
    feval=eval_error_metric,
    num_boost_round=1000,
    callbacks=[early_stop],
    verbose_eval=False)

dump = booster.get_dump(dump_format='json')
assert len(early_stop.stopping_history['Train']['CustomErr']) == len(dump)

定义自己的回调函数

XGBoost提供了一个回调接口类:TrainingCallback,用户定义的回调应该继承这个类并覆盖相应的方法。在示例中有使用和定义回调函数的工作示例。

python 复制代码
import argparse
import os
import tempfile
from typing import Dict

import numpy as np
from matplotlib import pyplot as plt
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

import xgboost as xgb


class Plotting(xgb.callback.TrainingCallback):
    """Plot evaluation result during training.  Only for demonstration purpose as it's
    quite slow to draw using matplotlib.

    """

    def __init__(self, rounds: int) -> None:
        self.fig = plt.figure()
        self.ax = self.fig.add_subplot(111)
        self.rounds = rounds
        self.lines: Dict[str, plt.Line2D] = {}
        self.fig.show()
        self.x = np.linspace(0, self.rounds, self.rounds)
        plt.ion()

    def _get_key(self, data: str, metric: str) -> str:
        return f"{data}-{metric}"

    def after_iteration(
        self, model: xgb.Booster, epoch: int, evals_log: Dict[str, dict]
    ) -> bool:
        """Update the plot."""
        if not self.lines:
            for data, metric in evals_log.items():
                for metric_name, log in metric.items():
                    key = self._get_key(data, metric_name)
                    expanded = log + [0] * (self.rounds - len(log))
                    (self.lines[key],) = self.ax.plot(self.x, expanded, label=key)
                    self.ax.legend()
        else:
            # https://pythonspot.com/matplotlib-update-plot/
            for data, metric in evals_log.items():
                for metric_name, log in metric.items():
                    key = self._get_key(data, metric_name)
                    expanded = log + [0] * (self.rounds - len(log))
                    self.lines[key].set_ydata(expanded)
            self.fig.canvas.draw()
        # False to indicate training should not stop.
        return False


def custom_callback() -> None:
    """Demo for defining a custom callback function that plots evaluation result during
    training."""
    X, y = load_breast_cancer(return_X_y=True)
    X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)

    D_train = xgb.DMatrix(X_train, y_train)
    D_valid = xgb.DMatrix(X_valid, y_valid)

    num_boost_round = 100
    plotting = Plotting(num_boost_round)

    # Pass it to the `callbacks` parameter as a list.
    xgb.train(
        {
            "objective": "binary:logistic",
            "eval_metric": ["error", "rmse"],
            "tree_method": "hist",
            "device": "cuda",
        },
        D_train,
        evals=[(D_train, "Train"), (D_valid, "Valid")],
        num_boost_round=num_boost_round,
        callbacks=[plotting],
    )


def check_point_callback() -> None:
    """Demo for using the checkpoint callback. Custom logic for handling output is
    usually required and users are encouraged to define their own callback for
    checkpointing operations. The builtin one can be used as a starting point.

    """
    # Only for demo, set a larger value (like 100) in practice as checkpointing is quite
    # slow.
    rounds = 2

    def check(as_pickle: bool) -> None:
        for i in range(0, 10, rounds):
            if i == 0:
                continue
            if as_pickle:
                path = os.path.join(tmpdir, "model_" + str(i) + ".pkl")
            else:
                path = os.path.join(
                    tmpdir,
                    f"model_{i}.{xgb.callback.TrainingCheckPoint.default_format}",
                )
            assert os.path.exists(path)

    X, y = load_breast_cancer(return_X_y=True)
    m = xgb.DMatrix(X, y)
    # Check point to a temporary directory for demo
    with tempfile.TemporaryDirectory() as tmpdir:
        # Use callback class from xgboost.callback
        # Feel free to subclass/customize it to suit your need.
        check_point = xgb.callback.TrainingCheckPoint(
            directory=tmpdir, interval=rounds, name="model"
        )
        xgb.train(
            {"objective": "binary:logistic"},
            m,
            num_boost_round=10,
            verbose_eval=False,
            callbacks=[check_point],
        )
        check(False)

        # This version of checkpoint saves everything including parameters and
        # model.  See: doc/tutorials/saving_model.rst
        check_point = xgb.callback.TrainingCheckPoint(
            directory=tmpdir, interval=rounds, as_pickle=True, name="model"
        )
        xgb.train(
            {"objective": "binary:logistic"},
            m,
            num_boost_round=10,
            verbose_eval=False,
            callbacks=[check_point],
        )
        check(True)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--plot", default=1, type=int)
    args = parser.parse_args()

    check_point_callback()

    if args.plot:
        custom_callback()

参考

相关推荐
ZZHow10242 小时前
02OpenCV基本操作
python·opencv·计算机视觉
计算机学长felix2 小时前
基于Django的“酒店推荐系统”设计与开发(源码+数据库+文档+PPT)
数据库·python·mysql·django·vue
站大爷IP2 小时前
Python随机数函数全解析:5个核心工具的实战指南
python
悟乙己2 小时前
使用 Python 中的强化学习最大化简单 RAG 性能
开发语言·python·agent·rag·n8n
max5006002 小时前
图像处理:实现多图点重叠效果
开发语言·图像处理·人工智能·python·深度学习·音视频
AI原吾2 小时前
玩转物联网只需十行代码,可它为何悄悄停止维护
python·物联网·hbmqtt
云动雨颤2 小时前
Python单元测试入门:3个核心断言方法,帮你快速定位代码bug
python·单元测试
SunnyDays10113 小时前
Python 实现 HTML 转 Word 和 PDF
python·html转word·html转pdf·html转docx·html转doc
跟橙姐学代码3 小时前
Python异常处理:告别程序崩溃,让代码更优雅!
前端·python·ipython
蓝纹绿茶4 小时前
Python程序使用了Ffmpeg,结束程序后,文件夹中仍然生成音频、视频文件
python·ubuntu·ffmpeg·音视频