sklearn学习(6)决策树

sklearn的决策树相关在DecisiontreeClassifiter这个函数中,一般决策树节点的选择有两种,其一是基尼指数gini,其二是信息增益entropy,而基尼指数要求越小越好,信息增益越大越好。

为了预防决策树过拟合,可以在决策树中加入剪枝,又分为预剪枝和后剪枝。下面通过一个数据演示决策树

导入包

复制代码
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_iris
import matplotlib.pyplot as plt
from sklearn.tree import plot_tree
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score

加载数据并划分,决策树数据一般不需要标准化或均一化

复制代码
# 加载数据集
iris = load_iris()
X = iris.data
y = iris.target
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

创建决策树,创建了三种,无剪枝、预剪枝和后剪枝

复制代码
# 1. 创建未剪枝的决策树
dt_full = DecisionTreeClassifier(random_state=42)
dt_full.fit(X_train, y_train)
y1_pre=dt_full.predict(X_test)

# 2. 创建预剪枝的决策树
dt_pruned = DecisionTreeClassifier(
    max_depth=3,              # 限制树的最大深度
    min_samples_split=10,    # 内部节点分裂所需的最小样本数
    min_samples_leaf=5,      # 叶节点所需的最小样本数
    max_leaf_nodes=10,       # 最大叶子节点数
    random_state=42
)
dt_pruned.fit(X_train, y_train)
y2_pre=dt_pruned.predict(X_test)

# 3. 使用后剪枝(代价复杂度剪枝)
path = dt_full.cost_complexity_pruning_path(X_train, y_train)
ccp_alphas = path.ccp_alphas
dt_ccp = DecisionTreeClassifier(random_state=42, ccp_alpha=0.01)  # 设置合适的ccp_alpha值
dt_ccp.fit(X_train, y_train)
y3_pre=dt_ccp.predict(X_test)

可视化

复制代码
# 可视化不同决策树
plt.figure(figsize=(40, 10),dpi=300)

# 未剪枝的树
plt.subplot(131)
plot_tree(dt_full, filled=True, feature_names=iris.feature_names,
         class_names=iris.target_names, rounded=True)
plt.title('no_jianzhi')

# 预剪枝的树
plt.subplot(132)
plot_tree(dt_pruned, filled=True, feature_names=iris.feature_names,
         class_names=iris.target_names, rounded=True)
plt.title('yu_jianzhi')

# 后剪枝的树
plt.subplot(133)
plot_tree(dt_ccp, filled=True, feature_names=iris.feature_names,
         class_names=iris.target_names, rounded=True)
plt.title('hou_jianzhi')

plt.tight_layout()
plt.show()

性能比较

复制代码
# 比较不同模型的性能
print("未剪枝模型的训练集准确率:",accuracy_score(y_test,y1_pre))

print("预剪枝模型的测试集准确率:", accuracy_score(y_test,y2_pre))

print("后剪枝模型的测试集准确率:", accuracy_score(y_test,y3_pre))

#很多时候节点处都是非连续量,怎么办呢,要把其转化为数字来做决策树,例子如下:

复制代码
#离散特征的决策树
import pandas as pd
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# 示例数据:假设我们要根据一些特征预测一个人的职业
data = {
    '学历': ['本科', '硕士', '博士', '本科', '硕士', '本科'],
    '专业': ['计算机', '数学', '物理', '计算机', '数学', '物理'],
    '技能': ['编程', '分析', '研究', '编程', '分析', '研究'],
    '目标': ['工程师', '研究员', '教授', '工程师', '研究员', '教授']
}

# 创建DataFrame
df = pd.DataFrame(data)

# 将汉字特征转换为数值
le_dict = {}
for column in df.columns:
    le = LabelEncoder()#创建编码器
    df[column] = le.fit_transform(df[column])#编码变成数
    le_dict[column] = le  # 保存每个特征的编码器,以便后续转换


# 准备特征和目标变量
X = df.drop('目标', axis=1) # 特征数据(去掉答案列)
y = df['目标']#只要答案列


# 分割训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)


# 创建并训练决策树模型
dt = DecisionTreeClassifier(random_state=42)
dt.fit(X_train, y_train)

#绘制
# 设置中文字体
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False  # 用来正常显示负号
plt.figure(figsize=(40, 10),dpi=300)
plot_tree(dt, filled=True, feature_names=df.columns[:-1], class_names=le_dict['目标'].classes_)
plt.show()

# 进行预测
y_pred = dt.predict(X_test)

# 计算准确率
accuracy = accuracy_score(y_test, y_pred)
print(f"模型准确率: {accuracy:.2f}")

# 示例:预测新的数据
new_data = {
    '学历': ['硕士'],
    '专业': ['计算机'],
    '技能': ['编程']
}
new_df = pd.DataFrame(new_data)

# 使用之前保存的编码器转换新数据
for column in new_df.columns:
    new_df[column] = le_dict[column].transform(new_df[column])

# 进行预测
prediction = dt.predict(new_df)
predicted_label = le_dict['目标'].inverse_transform(prediction)#反变换变为汉字
print(f"预测的职业: {predicted_label[0]}")
相关推荐
冬奇Lab1 小时前
Workflow 系列(06):安全——跨步骤注入传播与四层防御
人工智能·工作流引擎
冬奇Lab1 小时前
每日一个开源项目(第149篇):RAG-Anything - 把图片、表格、公式当成一等公民的多模态 RAG 框架
人工智能·开源
米小虾1 小时前
AI Agent 安全实战指南:当智能体开始"不听话",开发者该如何应对?
人工智能·安全·agent
IT_陈寒3 小时前
Vite的热更新突然不香了,排查三小时差点砸键盘
前端·人工智能·后端
用户8356290780514 小时前
Python 实现 PDF 文件加密与解密方法
后端·python
用户8356290780514 小时前
使用 Python 冻结与拆分 Excel 窗格教程
后端·python
阿里云大数据AI技术5 小时前
构建高转化海外电商搜索:阿里云OpenSearch行业算法版的全链路智能优化策略实战
人工智能·搜索引擎
Awu12275 小时前
⚡从零开发 Agent CLI(五)实现一个可治理、可扩展的工具系统
前端·人工智能·claude
字节跳动视频云技术团队5 小时前
让 Agent 成为音视频工作台:AI MediaKit CLI + Skill 发布
人工智能·音视频开发