Catboost 不能做多变量回归?

问题:

代码

python 复制代码
# CatBoost Regression Model
from catboost import CatBoostRegressor

model = CatBoostRegressor(
        iterations = 2000, 
        learning_rate = 0.1,
        l2_leaf_reg = 3,  
        depth = 9,
        rsm = 1,
        border_count=128,
        bagging_temperature= 10,
        verbose=False,
        loss_function='MultiRMSE',
    
                            )
    

model.fit(X_train, y_train)
y_pred = model.predict(X_test)

报错

bash 复制代码
Currently only multi-regression, multilabel and survival objectives work with multidimensional target 

原因

Catboost 默认用的回归的损失是RMSE,所以针对多变量的回归任务/分类任务需要更改损失函数

解决方法

python 复制代码
model = CatBoostRegressor(
        iterations = 2000, 
        learning_rate = 0.1,
        l2_leaf_reg = 3,  
        depth = 9,
        rsm = 1,
        border_count=128,
        bagging_temperature= 10,
        verbose=False,
        loss_function='MultiRMSE', ## 针对多变量任务更改损失函数改成MultiRMSE 或者 MultiLogit
    
                            )
相关推荐
SharkWeek.38 分钟前
【力扣Hot 100】普通数组2
数据结构·算法·leetcode
ZStack开发者社区2 小时前
AI应用、轻量云、虚拟化|云轴科技ZStack参编金融行标与报告
人工智能·科技·金融
真想骂*4 小时前
人工智能如何重塑音频、视觉及多模态领域的应用格局
人工智能·音视频
赛丽曼6 小时前
机器学习-K近邻算法
人工智能·机器学习·近邻算法
啊波次得饿佛哥7 小时前
7. 计算机视觉
人工智能·计算机视觉·视觉检测
XianxinMao8 小时前
RLHF技术应用探析:从安全任务到高阶能力提升
人工智能·python·算法
hefaxiang8 小时前
【C++】函数重载
开发语言·c++·算法
Swift社区8 小时前
【分布式日志篇】从工具选型到实战部署:全面解析日志采集与管理路径
人工智能·spring boot·分布式
Quz9 小时前
OpenCV:高通滤波之索贝尔、沙尔和拉普拉斯
图像处理·人工智能·opencv·计算机视觉·矩阵
去往火星9 小时前
OpenCV文字绘制支持中文显示
人工智能·opencv·计算机视觉