如何在Sklearn Pipeline中运行CatBoost

介绍

CatBoost的一大特点是可以很好的处理类别特征(Categorical Features)。当我们将其结合到Sklearn的Pipeline中时,会发生如下报错:

shell 复制代码
_catboost.CatBoostError: 'data' is numpy array of floating point numerical type, it means no categorical features, but 'cat_features' parameter specifies nonzero number of categorical features

因为CatBoost需要检查输入训练数据pandas.DataFrame中对应的cat_features。如果我们使用Pipeline后,输入给.fit()的数据是被修改过的,DataFrame中的columns的名字变为了数字。

解决方案

我们提前在数据上使用Pipeline,然后将原始数据转换为Pipeline处理后的数据,然后检索出其中包含的类别特征,将其传输给Catboost。

python 复制代码
# define your pipeline
pipeline = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('classifier', model),
])

preprocessor.fit(X_train)
transformed_X_train = pd.DataFrame(preprocessor.transform(X_train)).convert_dtypes()

new_cat_feature_idx = [transformed_X_train.columns.get_loc(col) for col in transformed_X_train.select_dtypes(include=['int64', 'bool']).columns]

pipeline.fit(X_train, y_train, classifier__cat_features=new_cat_feature_idx)
相关推荐
IT_陈寒26 分钟前
5个Java 21新特性实战技巧,让你的代码性能飙升200%!
前端·人工智能·后端
dlraba80239 分钟前
YOLOv3:目标检测领域的经典之作
人工智能·yolo·目标检测
科新数智1 小时前
破解商家客服困局:真人工AI回复如何成为转型核心
人工智能·#agent #智能体
你才是向阳花1 小时前
如何用python来做小游戏
开发语言·python·pygame
'需尽欢'2 小时前
基于 Flask+Vue+MySQL的研学网站
python·mysql·flask
szxinmai主板定制专家3 小时前
【NI测试方案】基于ARM+FPGA的整车仿真与电池标定
arm开发·人工智能·yolo·fpga开发
新子y3 小时前
【小白笔记】最大交换 (Maximum Swap)问题
笔记·python
ygyqinghuan4 小时前
读懂目标检测
人工智能·目标检测·目标跟踪
华东数交4 小时前
企业与国有数据资产:入表全流程管理及资产化闭环理论解析
大数据·人工智能
程序员爱钓鱼4 小时前
Python编程实战 · 基础入门篇 | Python的缩进与代码块
后端·python