Day 14 训练

Day 14 训练

  • SHAP(SHapley Additive exPlanations)
  • 1.创建解释器
  • 2.将特征贡献可视化
      • 第一部分:绘制SHAP特征重要性条形图
      • 第二部分:绘制SHAP特征重要性蜂巢图

SHAP(SHapley Additive exPlanations)

旨在解释复杂机器学习模型(如随机森林、梯度提升树、神经网络等 "黑箱" 模型)对特定输入的预测原因。其核心基于合作博弈论中的 Shapley 值。

  • 将模型的特征比作玩家,预测样本输出值是游戏目标,不同特征子集合作进行预测,特征子集预测得到的值是奖励 / 价值。通过计算每个特征的 Shapley 值来确定其对预测的贡献,具体是考虑所有可能的特征组合,计算特征在每种组合下的边际贡献,再求加权平均。
  • SHAP 具有加性解释特性,模型预测值等于基准值(模型在训练或背景数据集上的平均预测输出)加上所有特征的 SHAP 值之和。
  • SHAP 需要为每个样本的每个特征计算贡献值即 SHAP 值,形成 shap_values 数组。对于回归问题,shap_values 是形状为(n_samples,n_features)的数组;对于分类问题,通常返回一个列表,列表长度等于类别数,每个元素是(n_samples,n_features)数组,表示各特征对预测各类别的贡献。总之,SHAP 通过计算特征边际贡献,将模型预测分解到每个特征上,生成 shap_values 数组来解释预测。

1.创建解释器

bash 复制代码
import shap
import matplotlib.pyplot as plt

explainer = shap.TreeExplainer(rf_model)
shap_values = explainer.shap_values(X_test)
print(shap_values)
print(shap_values.shape) # 第一维是是样本数,第二维度是特征数量,第三维度是类别数量。
print("shap_values shape:", shap_values.shape)
print("shap_values[0] shape:", shap_values[0].shape)
print("shap_values[:, :, 0] shape:", shap_values[:, :, 0].shape)
print("X_test shape:", X_test.shape)

创建解释器对象

  • explainer = shap.TreeExplainer(rf_model):创建一个 SHAP 解释器对象。shap.TreeExplainer 是 SHAP 库中用于解释基于树的模型(如随机森林、梯度提升树等)的解释器类。rf_model 是一个已经训练好的随机森林模型对象,将其传递给 TreeExplainer,解释器就会根据该模型的结构和参数来计算特征的 SHAP 值。通过这个解释器对象,我们可以进一步获取模型预测的解释信息。

计算 SHAP 值

  • shap_values = explainer.shap_values(X_test):计算测试数据集 X_test 的 SHAP 值。X_test 是一个二维数组,包含了测试样本的特征值。explainer.shap_values 方法会根据之前创建的解释器对象和测试数据集,计算出每个特征对于每个测试样本的 SHAP 值,并将结果存储在 shap_values 变量中。SHAP 值是一个与特征数量相同的数组,其中每个元素表示一个特征对模型预测的贡献度。正的 SHAP 值表示该特征对预测结果有正向影响,负的 SHAP 值表示有负向影响。通过这些 SHAP 值,我们可以分析出哪些特征对模型的预测结果起到了关键作用,以及它们是如何影响预测结果的。

2.将特征贡献可视化

bash 复制代码
print("1.shap 特征重要性条形图")
shap.summary_plot(shap_values[:,:,0], X_test, plot_type="bar",show=False)
plt.title("SHAP Feature Importance (Class 0)")
plt.show()

print("--- 2. SHAP 特征重要性蜂巢图 ---")
shap.summary_plot(shap_values[:, :, 0], X_test,plot_type="violin",show=False,max_display=10) # 这里的show=False表示不直接显示图形,这样可以继续用plt来修改元素,不然就直接输出了
plt.title("SHAP Feature Importance (Violin Plot)")
plt.show()

这段代码是用于绘制SHAP(SHapley Additive exPlanations)特征重要性图的Python代码,主要使用了shap库和matplotlib库。下面对这段代码进行逐行解释:

第一部分:绘制SHAP特征重要性条形图

python 复制代码
print("1.shap 特征重要性条形图")
  • 这行代码会在控制台打印一条消息,提示接下来将绘制SHAP特征重要性条形图。
python 复制代码
shap.summary_plot(shap_values[:,:,0], X_test, plot_type="bar", show=False)
  • shap.summary_plotshap 库中的一个函数,用于绘制特征重要性图。
  • shap_values[:,:,0]shap_values 是一个三维数组,这里取其第一个维度的所有值,通常对应于模型的某个类别(这里是类别0)。
  • X_test:这是测试数据集,用于提供特征名称和数据范围等信息。
  • plot_type="bar":指定绘制条形图。
  • show=False:表示不直接显示图形,这样可以在绘制完图形后继续使用 matplotlib 修改图形元素。

第二部分:绘制SHAP特征重要性蜂巢图

python 复制代码
print("--- 2. SHAP 特征重要性蜂巢图 ---")
  • 这行代码会在控制台打印一条消息,提示接下来将绘制SHAP特征重要性蜂巢图。
python 复制代码
shap.summary_plot(shap_values[:, :, 0], X_test, plot_type="violin", show=False, max_display=10)
  • 同样使用 shap.summary_plot 函数绘制特征重要性图。
  • plot_type="violin":指定绘制蜂巢图(小提琴图)。
  • show=False:不直接显示图形,以便后续使用 matplotlib 修改图形元素。
  • max_display=10:限制最多显示10个特征。

@浙大疏锦行

相关推荐
hello_ejb32 小时前
聊聊Spring AI Alibaba的SentenceSplitter
人工智能·python·spring
新辞旧梦3 小时前
企业微信自建消息推送应用
服务器·python·企业微信
虎头金猫3 小时前
如何解决 403 错误:请求被拒绝,无法连接到服务器
运维·服务器·python·ubuntu·chatgpt·centos·bug
dqsh065 小时前
树莓派5+Ubuntu24.04 LTS串口通信 保姆级教程
人工智能·python·物联网·ubuntu·机器人
sunshineine7 小时前
jupyter notebook运行简单程序
linux·windows·python
方博士AI机器人7 小时前
Python 3.x 内置装饰器 (4) - @dataclass
开发语言·python
万能程序员-传康Kk7 小时前
中国邮政物流管理系统(Django+mysql)
python·mysql·django
Logintern097 小时前
【每天学习一点点】使用Python的pathlib模块分割文件路径
开发语言·python·学习
开开心心_Every8 小时前
手机隐私数据彻底删除工具:回收或弃用手机前防数据恢复
android·windows·python·搜索引擎·智能手机·pdf·音视频