SHAP中使用shap.summary_plot对多分类任务模型中特征重要性绘图

在文心一言中输入:

使用shap.summary_plot展示各个特征对模型输出类别的重要性

其输出的代码为(不正确):

python 复制代码
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import svm
import shap

# 加载数据集(这里使用iris数据集作为例子)
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化SVM模型并进行训练
clf = svm.SVC(kernel='linear', probability=True, random_state=42)
clf.fit(X_train, y_train)

# 初始化SHAP解释器
explainer = shap.Explainer(clf.predict_proba, X_train)

# 计算测试集上每个预测的SHAP值
# 注意:这里我们使用predict_proba方法,因为它返回了每个类别的概率
shap_values = explainer(X_test)

# 使用summary_plot可视化特征重要性
shap.summary_plot(shap_values, X_test, feature_names=iris.feature_names)

以上代码没有静态错误,但是运行报错:

Traceback (most recent call last):

File "D:\Ethan\Projects\fattyLiver\test_shap_iris.py", line 27, in <module>

shap.summary_plot(shap_values, X_test, feature_names=iris.feature_names)

File "D:\Ethan\Projects\fattyLiver\venv\lib\site-packages\shap\plots\_beeswarm.py", line 605, in summary_legacy

feature_names=feature_names[sort_inds],

TypeError: only integer scalar arrays can be converted to a scalar index

修改为如下代码(正确):

python 复制代码
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import svm
import shap

# 加载数据集(这里使用iris数据集作为例子)
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化SVM模型并进行训练
clf = svm.SVC(kernel='linear', probability=True, random_state=42)
clf.fit(X_train, y_train)

# 初始化SHAP解释器
explainer = shap.Explainer(clf.predict_proba, X_train)

# 计算测试集上每个预测的SHAP值
# 注意:这里我们使用predict_proba方法,因为它返回了每个类别的概率
shap_values = explainer(X_test)

# 使用summary_plot可视化特征重要性
# shap.summary_plot(shap_values, X_test, feature_names=iris.feature_names)
list_of_2d_arrays = [shap_values.values[:, :, i] for i in range(3)]
shap.summary_plot(list_of_2d_arrays, X_test, feature_names=iris.feature_names, class_names=iris.target_names)

输出图片:

相关推荐
极新2 分钟前
生数科技商业化总监陈鹤天:视频生成破瓶颈,AI赋能漫剧产业|2025极新AIGC峰会演讲实录
人工智能·科技·aigc
u1301306 分钟前
开源版 NotebookLM:Open Notebook 深度体验与部署指南
人工智能·开源
说私域9 分钟前
基于开源AI大模型AI智能名片S2B2C商城小程序的内容价值生成与多点选择传播策略研究
人工智能·微信·小程序·开源
水龙吟啸11 分钟前
项目设计与开发:智慧校园食堂系统
python·机器学习·前端框架·c#·团队开发·visual studio·数据库系统
说私域24 分钟前
数据分析能力在开源AI智能名片链动2+1模式多商户商城小程序中的价值与应用研究
人工智能·数据分析·开源
Coder_Boy_25 分钟前
基于SpringAI企业级智能教学考试平台试卷管理模块全业务闭环方案
java·大数据·人工智能·spring boot·springboot
拾荒的小海螺27 分钟前
开源项目:Z-Image 轻量高效的开源 AI 图像生成模型
人工智能·开源
dagouaofei32 分钟前
实测!6款AI自动生成PPT工具体验分享
人工智能·python·powerpoint
newrank_kk33 分钟前
下一代品牌战略:把智汇GEO作为核心品牌AI形象管理工具
大数据·人工智能
传感器与混合集成电路33 分钟前
面向航天、深地与核工业场景的高可靠电源方案设计要点
人工智能·物联网