记一次使用catboost训练不平衡数据

CatBoost 是一种基于梯度提升决策树的机器学习算法,它在处理类别特征方面有独特的优势,并且通常能够提供比其他梯度提升框架更好的性能。下面是一个我最近使用 SMOTE 和 CatBoost 库进行分类任务的基本示例。

需要安装的包

bash 复制代码
pip install catboost
pip install imblearn

示例代码

这里是一个简单的例子,展示如何使用 CatBoostClassifier 进行分类任务:

  • 导入必要的库。
  • 准备数据集。
  • 划分训练集和测试集。
  • 创建并训练模型。
  • 评估模型性能。

步骤 1: 导入库

python 复制代码
import pandas as pd
from sklearn.model_selection import train_test_split
from catboost import CatBoostClassifier, Pool
from sklearn.metrics import accuracy_score, classification_report, ConfusionMatrixDisplay
from imblearn.over_sampling import SMOTE
import numpy as np # 截至2024年8月15日,catboost不支持NumPy 2.0,建议使用NumPy 1.26

步骤 2: 准备数据集

假设我们有一个 XLSX 文件 data.xlsx 包含特征和标签。

python 复制代码
# 加载数据
data = pd.read_excel('data.xlsx')

# 分离特征和标签
X = data.drop('target', axis=1)
y = data['target']

# SMOTE采样
X_resampled, y_resampled = SMOTE().fit_resample(X, y)

步骤 3: 划分训练集和测试集

python 复制代码
# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X_resampled, y_resampled, test_size=0.2, random_state=42)

步骤 4: 创建并训练模型

python 复制代码
# 定义分类器
model = CatBoostClassifier(iterations=1000, learning_rate=0.1, depth=8)

# 指定类别特征的索引(如果有的话)
categorical_features_indices = np.where(X.dtypes != np.float64)[0]

# 训练模型
model.fit(
    X_train, y_train,
    cat_features=categorical_features_indices,
    eval_set=(X_test, y_test),
    verbose=False
)

步骤 5: 评估模型性能

python 复制代码
# 预测
predictions = model.predict(X_test)

# 打印结果
print("Accuracy:", accuracy_score(y_test, predictions))
print(classification_report(y_test, predictions, digits=8))
_ = ConfusionMatrixDisplay.from_estimator(model, X_test, y_test)
bash 复制代码
Accuracy: 0.9818376068376068
              precision    recall  f1-score   support

           0  0.99433798 0.96942675 0.98172436      2355
           1  0.96979866 0.99440860 0.98194946      2325

    accuracy                      0.98183761      4680
   macro avg  0.98206832 0.98191768 0.98183691      4680
weighted avg  0.98214697 0.98183761 0.98183619      4680

效果不错,但这种使用SMOTE制造数据的方式,可能存在过拟合的问题。

相关推荐
R-G-B1 小时前
OpenCV Python——Numpy基本操作(Numpy 矩阵操作、Numpy 矩阵的检索与赋值、Numpy 操作ROI)
python·opencv·numpy·numpy基本操作·numpy 矩阵操作·numpy 矩阵的检索与赋值·numpy 操作roi
细节处有神明1 小时前
Jupyter 中实现交互式图表:ipywidgets 从入门到部署
ide·python·jupyter
小小码农一只1 小时前
Python 爬虫实战:玩转 Playwright 跨浏览器自动化(Chromium/Firefox/WebKit 全支持)
爬虫·python·自动化
拾零吖1 小时前
吴恩达 Machine Learning(Class 1)
人工智能·机器学习
深盾安全2 小时前
Python脚本安全防护策略全解析(上)
python
杜子不疼.2 小时前
《Python学习之使用标准库:从入门到实战》
开发语言·python·学习
胡耀超3 小时前
从哲学(业务)视角看待数据挖掘:从认知到实践的螺旋上升
人工智能·python·数据挖掘·大模型·特征工程·crisp-dm螺旋认知·批判性思维
tomelrg3 小时前
多台服务器批量发布arcgisserver服务并缓存切片
服务器·python·arcgis
wydaicls3 小时前
用函数实现方程函数解题
人工智能·算法·机器学习
A尘埃3 小时前
Java+Python混合微服务OCR系统设计
java·python·微服务·混合