Catboost 不能做多变量回归?

问题:

代码

python 复制代码
# CatBoost Regression Model
from catboost import CatBoostRegressor

model = CatBoostRegressor(
        iterations = 2000, 
        learning_rate = 0.1,
        l2_leaf_reg = 3,  
        depth = 9,
        rsm = 1,
        border_count=128,
        bagging_temperature= 10,
        verbose=False,
        loss_function='MultiRMSE',
    
                            )
    

model.fit(X_train, y_train)
y_pred = model.predict(X_test)

报错

bash 复制代码
Currently only multi-regression, multilabel and survival objectives work with multidimensional target 

原因

Catboost 默认用的回归的损失是RMSE,所以针对多变量的回归任务/分类任务需要更改损失函数

解决方法

python 复制代码
model = CatBoostRegressor(
        iterations = 2000, 
        learning_rate = 0.1,
        l2_leaf_reg = 3,  
        depth = 9,
        rsm = 1,
        border_count=128,
        bagging_temperature= 10,
        verbose=False,
        loss_function='MultiRMSE', ## 针对多变量任务更改损失函数改成MultiRMSE 或者 MultiLogit
    
                            )
相关推荐
dundunmm23 分钟前
【每天一个知识点】训推一体机
人工智能·大模型·硬件·软件·训练·推理
johnny2331 小时前
OCR、文档解析工具合集(下)
人工智能
Moshow郑锴3 小时前
实践题:智能客服机器人设计
人工智能·机器人·智能客服
2501_924889554 小时前
商超高峰客流统计误差↓75%!陌讯多模态融合算法在智慧零售的实战解析
大数据·人工智能·算法·计算机视觉·零售
jingfeng5144 小时前
C++模板进阶
java·c++·算法
维基框架5 小时前
维基框架 (Wiki Framework) 1.1.0 版本发布 提供多模型AI辅助开发
人工智能
西猫雷婶5 小时前
神经网络|(十二)概率论基础知识-先验/后验/似然概率基本概念
人工智能·神经网络·机器学习·回归·概率论
地平线开发者5 小时前
征程 6X | 常用工具介绍
算法·自动驾驶
地平线开发者5 小时前
理想汽车智驾方案介绍 2|MindVLA 方案详解
算法·自动驾驶
计算机编程小咖6 小时前
《基于大数据的农产品交易数据分析与可视化系统》选题不当,毕业答辩可能直接挂科
java·大数据·hadoop·python·数据挖掘·数据分析·spark