记一次使用catboost训练不平衡数据

CatBoost 是一种基于梯度提升决策树的机器学习算法,它在处理类别特征方面有独特的优势,并且通常能够提供比其他梯度提升框架更好的性能。下面是一个我最近使用 SMOTE 和 CatBoost 库进行分类任务的基本示例。

需要安装的包

bash 复制代码
pip install catboost
pip install imblearn

示例代码

这里是一个简单的例子,展示如何使用 CatBoostClassifier 进行分类任务:

  • 导入必要的库。
  • 准备数据集。
  • 划分训练集和测试集。
  • 创建并训练模型。
  • 评估模型性能。

步骤 1: 导入库

python 复制代码
import pandas as pd
from sklearn.model_selection import train_test_split
from catboost import CatBoostClassifier, Pool
from sklearn.metrics import accuracy_score, classification_report, ConfusionMatrixDisplay
from imblearn.over_sampling import SMOTE
import numpy as np # 截至2024年8月15日,catboost不支持NumPy 2.0,建议使用NumPy 1.26

步骤 2: 准备数据集

假设我们有一个 XLSX 文件 data.xlsx 包含特征和标签。

python 复制代码
# 加载数据
data = pd.read_excel('data.xlsx')

# 分离特征和标签
X = data.drop('target', axis=1)
y = data['target']

# SMOTE采样
X_resampled, y_resampled = SMOTE().fit_resample(X, y)

步骤 3: 划分训练集和测试集

python 复制代码
# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X_resampled, y_resampled, test_size=0.2, random_state=42)

步骤 4: 创建并训练模型

python 复制代码
# 定义分类器
model = CatBoostClassifier(iterations=1000, learning_rate=0.1, depth=8)

# 指定类别特征的索引(如果有的话)
categorical_features_indices = np.where(X.dtypes != np.float64)[0]

# 训练模型
model.fit(
    X_train, y_train,
    cat_features=categorical_features_indices,
    eval_set=(X_test, y_test),
    verbose=False
)

步骤 5: 评估模型性能

python 复制代码
# 预测
predictions = model.predict(X_test)

# 打印结果
print("Accuracy:", accuracy_score(y_test, predictions))
print(classification_report(y_test, predictions, digits=8))
_ = ConfusionMatrixDisplay.from_estimator(model, X_test, y_test)
bash 复制代码
Accuracy: 0.9818376068376068
              precision    recall  f1-score   support

           0  0.99433798 0.96942675 0.98172436      2355
           1  0.96979866 0.99440860 0.98194946      2325

    accuracy                      0.98183761      4680
   macro avg  0.98206832 0.98191768 0.98183691      4680
weighted avg  0.98214697 0.98183761 0.98183619      4680

效果不错,但这种使用SMOTE制造数据的方式,可能存在过拟合的问题。

相关推荐
苏打水com13 分钟前
0基础学前端:100天拿offer实战课(第3天)—— CSS基础美化:给网页“精装修”的5大核心技巧
人工智能·python·tensorflow
顾安r41 分钟前
11.5 脚本 本地网站收藏(解封归来)
linux·服务器·c语言·python·bash
Blossom.1181 小时前
把AI“贴”进路灯柱:1KB决策树让老旧路灯自己报「灯头松动」
java·人工智能·python·深度学习·算法·决策树·机器学习
❀͜͡傀儡师1 小时前
快速定位并解决Java应用CPU占用过高问题
java·开发语言·python
极客BIM工作室1 小时前
ControlNet:Adding Conditional Control to Text-to-Image Diffusion Models
人工智能·深度学习·机器学习
linuxxx1101 小时前
django中request.GET.urlencode的使用
后端·python·django
冬天vs不冷1 小时前
Java基础(十五):注解(Annotation)详解
android·java·python
汤姆yu1 小时前
基于大数据的全国降水可视化分析预测系统
大数据·开发语言·python
星空的资源小屋3 小时前
Text Grab,一款OCR 截图文字识别工具
python·django·ocr·scikit-learn
寒秋丶3 小时前
Milvus:Json字段详解(十)
数据库·人工智能·python·ai·milvus·向量数据库·rag