SHAP中使用shap.summary_plot对多分类任务模型中特征重要性绘图

在文心一言中输入:

使用shap.summary_plot展示各个特征对模型输出类别的重要性

其输出的代码为(不正确):

python 复制代码
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import svm
import shap

# 加载数据集(这里使用iris数据集作为例子)
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化SVM模型并进行训练
clf = svm.SVC(kernel='linear', probability=True, random_state=42)
clf.fit(X_train, y_train)

# 初始化SHAP解释器
explainer = shap.Explainer(clf.predict_proba, X_train)

# 计算测试集上每个预测的SHAP值
# 注意:这里我们使用predict_proba方法,因为它返回了每个类别的概率
shap_values = explainer(X_test)

# 使用summary_plot可视化特征重要性
shap.summary_plot(shap_values, X_test, feature_names=iris.feature_names)

以上代码没有静态错误,但是运行报错:

Traceback (most recent call last):

File "D:\Ethan\Projects\fattyLiver\test_shap_iris.py", line 27, in <module>

shap.summary_plot(shap_values, X_test, feature_names=iris.feature_names)

File "D:\Ethan\Projects\fattyLiver\venv\lib\site-packages\shap\plots\_beeswarm.py", line 605, in summary_legacy

feature_names=feature_names[sort_inds],

TypeError: only integer scalar arrays can be converted to a scalar index

修改为如下代码(正确):

python 复制代码
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import svm
import shap

# 加载数据集(这里使用iris数据集作为例子)
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化SVM模型并进行训练
clf = svm.SVC(kernel='linear', probability=True, random_state=42)
clf.fit(X_train, y_train)

# 初始化SHAP解释器
explainer = shap.Explainer(clf.predict_proba, X_train)

# 计算测试集上每个预测的SHAP值
# 注意:这里我们使用predict_proba方法,因为它返回了每个类别的概率
shap_values = explainer(X_test)

# 使用summary_plot可视化特征重要性
# shap.summary_plot(shap_values, X_test, feature_names=iris.feature_names)
list_of_2d_arrays = [shap_values.values[:, :, i] for i in range(3)]
shap.summary_plot(list_of_2d_arrays, X_test, feature_names=iris.feature_names, class_names=iris.target_names)

输出图片:

相关推荐
2501_930799242 分钟前
访答知识库,企业知识库,访答浏览器,Al编辑器,RAG,Pdf转word。个人知识库,访答RAG,云知识库,私有知识库……
人工智能
机器之心9 分钟前
OpenAI官宣自研造芯,联手博通开发10吉瓦规模的AI加速器
人工智能·openai
机器之心13 分钟前
100美元、8000行代码手搓ChatGPT,Karpathy最新开源项目爆火,一夜近5k star
人工智能·openai
RTC老炮15 分钟前
webrtc弱网-BitrateEstimator类源码分析与算法原理
网络·人工智能·算法·机器学习·webrtc
星期天要睡觉19 分钟前
计算机视觉(opencv)——基于 MediaPipe 的手势识别系统
人工智能·opencv·计算机视觉
三年呀23 分钟前
指纹技术深度剖析:从原理到实践的全方位探索
图像处理·人工智能·计算机视觉·指纹识别·生物识别技术·安全算法
学习的周周啊1 小时前
一人AI自动化开发体系(Cursor 驱动):从需求到上线的全流程闭环与实战清单
运维·人工智能·自动化·ai编程·全栈·devops·cursor
后端小肥肠1 小时前
明星漫画总画不像?用 Coze +即梦 4 工作流,素描风漫画3分钟搞定,小白也能上手
人工智能·aigc·coze
flay2 小时前
5个Claude实战项目从0到1:自动化、客服机器人、代码审查
人工智能
flay2 小时前
Claude API完全指南:从入门到实战
人工智能