如何在Sklearn Pipeline中运行CatBoost

介绍

CatBoost的一大特点是可以很好的处理类别特征(Categorical Features)。当我们将其结合到Sklearn的Pipeline中时,会发生如下报错:

shell 复制代码
_catboost.CatBoostError: 'data' is numpy array of floating point numerical type, it means no categorical features, but 'cat_features' parameter specifies nonzero number of categorical features

因为CatBoost需要检查输入训练数据pandas.DataFrame中对应的cat_features。如果我们使用Pipeline后,输入给.fit()的数据是被修改过的,DataFrame中的columns的名字变为了数字。

解决方案

我们提前在数据上使用Pipeline,然后将原始数据转换为Pipeline处理后的数据,然后检索出其中包含的类别特征,将其传输给Catboost。

python 复制代码
# define your pipeline
pipeline = Pipeline(steps=[
    ('preprocessor', preprocessor),
    ('classifier', model),
])

preprocessor.fit(X_train)
transformed_X_train = pd.DataFrame(preprocessor.transform(X_train)).convert_dtypes()

new_cat_feature_idx = [transformed_X_train.columns.get_loc(col) for col in transformed_X_train.select_dtypes(include=['int64', 'bool']).columns]

pipeline.fit(X_train, y_train, classifier__cat_features=new_cat_feature_idx)
相关推荐
新知图书几秒前
工作分解结构辅助生成(使用千问)
人工智能·千问·高效办公
love530love5 分钟前
ComfyUI MediaPipe 终极填坑:解决 incompatible function arguments 报错,基于代理模式的猴子补丁升级版
人工智能·windows·comfyui·mediapipe·猴子补丁·monkey patch·python 3.12
dingzd958 分钟前
Facebook强化原创内容分发后跨境品牌如何重做素材策略
大数据·人工智能·新媒体运营·内容营销·跨境
卢子墨10 分钟前
Hermes Agent + 钉钉适配文档(重点解决图片引用识别问题)
人工智能·aigc·harness
2301_8159019714 分钟前
SQL如何将多行记录聚合成逗号分隔字符串_GROUP_CONCAT技巧
jvm·数据库·python
西索斯15 分钟前
Claude API 报 529 Overloaded 怎么办?3 种方案实测,最后一种最省心
python·claude
小民AI实战笔记15 分钟前
GitHub Actions + 钉钉,半小时搭个免费的热榜推送机器人
人工智能·aigc·ai编程
Flittly15 分钟前
【LangGraph新手村系列】(3)PostgreSQL 持久化检查点:让状态跨越进程与重启
人工智能·python·langchain
麦芽糖021917 分钟前
大模型二 Agent入门实战(AI私厨)
人工智能
.柒宇.19 分钟前
FastAPI 基础指南:从入门到实战
开发语言·python·fastapi