SHAP中使用shap.summary_plot对多分类任务模型中特征重要性绘图

在文心一言中输入:

使用shap.summary_plot展示各个特征对模型输出类别的重要性

其输出的代码为(不正确):

python 复制代码
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import svm
import shap

# 加载数据集(这里使用iris数据集作为例子)
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化SVM模型并进行训练
clf = svm.SVC(kernel='linear', probability=True, random_state=42)
clf.fit(X_train, y_train)

# 初始化SHAP解释器
explainer = shap.Explainer(clf.predict_proba, X_train)

# 计算测试集上每个预测的SHAP值
# 注意:这里我们使用predict_proba方法,因为它返回了每个类别的概率
shap_values = explainer(X_test)

# 使用summary_plot可视化特征重要性
shap.summary_plot(shap_values, X_test, feature_names=iris.feature_names)

以上代码没有静态错误,但是运行报错:

Traceback (most recent call last):

File "D:\Ethan\Projects\fattyLiver\test_shap_iris.py", line 27, in <module>

shap.summary_plot(shap_values, X_test, feature_names=iris.feature_names)

File "D:\Ethan\Projects\fattyLiver\venv\lib\site-packages\shap\plots\_beeswarm.py", line 605, in summary_legacy

feature_names=feature_names[sort_inds],

TypeError: only integer scalar arrays can be converted to a scalar index

修改为如下代码(正确):

python 复制代码
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import svm
import shap

# 加载数据集(这里使用iris数据集作为例子)
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化SVM模型并进行训练
clf = svm.SVC(kernel='linear', probability=True, random_state=42)
clf.fit(X_train, y_train)

# 初始化SHAP解释器
explainer = shap.Explainer(clf.predict_proba, X_train)

# 计算测试集上每个预测的SHAP值
# 注意:这里我们使用predict_proba方法,因为它返回了每个类别的概率
shap_values = explainer(X_test)

# 使用summary_plot可视化特征重要性
# shap.summary_plot(shap_values, X_test, feature_names=iris.feature_names)
list_of_2d_arrays = [shap_values.values[:, :, i] for i in range(3)]
shap.summary_plot(list_of_2d_arrays, X_test, feature_names=iris.feature_names, class_names=iris.target_names)

输出图片:

相关推荐
泰恒3 分钟前
双阶段目标检测是什么?有什么用?
人工智能·深度学习·机器学习
weixin_669545205 分钟前
BC915E 5V/3.6A输入 两节升压充电IC,升压充电效率95%,输入最大支持18W,ESOP8 兼容IP2325
人工智能·单片机·嵌入式硬件·硬件工程·信息与通信
新缸中之脑6 分钟前
NOMAD:战时离线智能体
人工智能
章鱼丸-6 分钟前
DAY38 Dataset 类和DataLoader 类
人工智能
QQsuccess6 分钟前
人工智能(AI)全体系学习——系列三
人工智能·python·深度学习·学习
深藏功yu名13 分钟前
Day25(高阶篇):RAG检索与重排序算法精研|从原理到参数调优,彻底攻克检索瓶颈
人工智能·算法·ai·自然语言处理·排序算法·agent
司南-704915 分钟前
claude初探- 国内镜像安装linux版claude
linux·运维·服务器·人工智能·后端
cd_9492172117 分钟前
《观澜社张庆与中信证券联手,共探金融发展新路径》
人工智能·金融
一晌小贪欢19 分钟前
【计算机科普知识】:什么是AI智能体(AI Agent)
人工智能·ai·chatgpt·ai agent·智能体·ai智能体
森诺Alyson23 分钟前
前沿技术借鉴研讨-2026.3.26(解决虚假特征x2/混合专家对比学习框架)
论文阅读·人工智能·经验分享·深度学习·学习·论文笔记