记一次使用catboost训练不平衡数据

CatBoost 是一种基于梯度提升决策树的机器学习算法,它在处理类别特征方面有独特的优势,并且通常能够提供比其他梯度提升框架更好的性能。下面是一个我最近使用 SMOTE 和 CatBoost 库进行分类任务的基本示例。

需要安装的包

bash 复制代码
pip install catboost
pip install imblearn

示例代码

这里是一个简单的例子,展示如何使用 CatBoostClassifier 进行分类任务:

  • 导入必要的库。
  • 准备数据集。
  • 划分训练集和测试集。
  • 创建并训练模型。
  • 评估模型性能。

步骤 1: 导入库

python 复制代码
import pandas as pd
from sklearn.model_selection import train_test_split
from catboost import CatBoostClassifier, Pool
from sklearn.metrics import accuracy_score, classification_report, ConfusionMatrixDisplay
from imblearn.over_sampling import SMOTE
import numpy as np # 截至2024年8月15日,catboost不支持NumPy 2.0,建议使用NumPy 1.26

步骤 2: 准备数据集

假设我们有一个 XLSX 文件 data.xlsx 包含特征和标签。

python 复制代码
# 加载数据
data = pd.read_excel('data.xlsx')

# 分离特征和标签
X = data.drop('target', axis=1)
y = data['target']

# SMOTE采样
X_resampled, y_resampled = SMOTE().fit_resample(X, y)

步骤 3: 划分训练集和测试集

python 复制代码
# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(X_resampled, y_resampled, test_size=0.2, random_state=42)

步骤 4: 创建并训练模型

python 复制代码
# 定义分类器
model = CatBoostClassifier(iterations=1000, learning_rate=0.1, depth=8)

# 指定类别特征的索引(如果有的话)
categorical_features_indices = np.where(X.dtypes != np.float64)[0]

# 训练模型
model.fit(
    X_train, y_train,
    cat_features=categorical_features_indices,
    eval_set=(X_test, y_test),
    verbose=False
)

步骤 5: 评估模型性能

python 复制代码
# 预测
predictions = model.predict(X_test)

# 打印结果
print("Accuracy:", accuracy_score(y_test, predictions))
print(classification_report(y_test, predictions, digits=8))
_ = ConfusionMatrixDisplay.from_estimator(model, X_test, y_test)
bash 复制代码
Accuracy: 0.9818376068376068
              precision    recall  f1-score   support

           0  0.99433798 0.96942675 0.98172436      2355
           1  0.96979866 0.99440860 0.98194946      2325

    accuracy                      0.98183761      4680
   macro avg  0.98206832 0.98191768 0.98183691      4680
weighted avg  0.98214697 0.98183761 0.98183619      4680

效果不错,但这种使用SMOTE制造数据的方式,可能存在过拟合的问题。

相关推荐
weixin_527550407 分钟前
初级程序员入门指南
javascript·python·算法
程序员的世界你不懂27 分钟前
Appium+python自动化(十)- 元素定位
python·appium·自动化
CryptoPP1 小时前
使用WebSocket实时获取印度股票数据源(无调用次数限制)实战
后端·python·websocket·网络协议·区块链
树叶@1 小时前
Python数据分析7
开发语言·python
老胖闲聊2 小时前
Python Rio 【图像处理】库简介
开发语言·图像处理·python
码界奇点3 小时前
Python Flask文件处理与异常处理实战指南
开发语言·python·自然语言处理·flask·python3.11
浠寒AI3 小时前
智能体模式篇(上)- 深入 ReAct:LangGraph构建能自主思考与行动的 AI
人工智能·python
行云流水剑4 小时前
【学习记录】如何使用 Python 提取 PDF 文件中的内容
python·学习·pdf
心扬5 小时前
python生成器
开发语言·python
mouseliu5 小时前
python之二:docker部署项目
前端·python