SHAP中使用shap.summary_plot对多分类任务模型中特征重要性绘图

在文心一言中输入:

使用shap.summary_plot展示各个特征对模型输出类别的重要性

其输出的代码为(不正确):

python 复制代码
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import svm
import shap

# 加载数据集(这里使用iris数据集作为例子)
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化SVM模型并进行训练
clf = svm.SVC(kernel='linear', probability=True, random_state=42)
clf.fit(X_train, y_train)

# 初始化SHAP解释器
explainer = shap.Explainer(clf.predict_proba, X_train)

# 计算测试集上每个预测的SHAP值
# 注意:这里我们使用predict_proba方法,因为它返回了每个类别的概率
shap_values = explainer(X_test)

# 使用summary_plot可视化特征重要性
shap.summary_plot(shap_values, X_test, feature_names=iris.feature_names)

以上代码没有静态错误,但是运行报错:

Traceback (most recent call last):

File "D:\Ethan\Projects\fattyLiver\test_shap_iris.py", line 27, in <module>

shap.summary_plot(shap_values, X_test, feature_names=iris.feature_names)

File "D:\Ethan\Projects\fattyLiver\venv\lib\site-packages\shap\plots\_beeswarm.py", line 605, in summary_legacy

feature_names=feature_names[sort_inds],

TypeError: only integer scalar arrays can be converted to a scalar index

修改为如下代码(正确):

python 复制代码
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import svm
import shap

# 加载数据集(这里使用iris数据集作为例子)
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化SVM模型并进行训练
clf = svm.SVC(kernel='linear', probability=True, random_state=42)
clf.fit(X_train, y_train)

# 初始化SHAP解释器
explainer = shap.Explainer(clf.predict_proba, X_train)

# 计算测试集上每个预测的SHAP值
# 注意:这里我们使用predict_proba方法,因为它返回了每个类别的概率
shap_values = explainer(X_test)

# 使用summary_plot可视化特征重要性
# shap.summary_plot(shap_values, X_test, feature_names=iris.feature_names)
list_of_2d_arrays = [shap_values.values[:, :, i] for i in range(3)]
shap.summary_plot(list_of_2d_arrays, X_test, feature_names=iris.feature_names, class_names=iris.target_names)

输出图片:

相关推荐
miao6497255340 分钟前
智能猫砂盆到底哪家好用?自费实测聚宠、糯雪、CEWEY真实反馈!
大数据·运维·人工智能·安全·自动化
spark-aixin42 分钟前
昇思学习打卡-5-基于Mindspore实现BERT对话情绪识别
人工智能·学习·bert
地瓜伯伯1 小时前
HandlerMethodArgumentResolver :深入spring mvc参数解析机制
大数据·人工智能·spring boot·spring·语言模型
WarghostWu1 小时前
游戏AI的创造思路-技术基础-计算机视觉
人工智能·游戏·计算机视觉
处女座_三月1 小时前
视觉图像面积计算
人工智能·算法·计算机视觉
Czi.2 小时前
Build a Large Language Model (From Scratch)附录B(gpt-4o翻译版)
人工智能·语言模型·自然语言处理
龙的爹23332 小时前
论文翻译 | (DSP)展示-搜索-预测:为知识密集型自然语言处理组合检索和语言模型
人工智能·学习·语言模型·自然语言处理·nlp
新加坡内哥谈技术2 小时前
顶级10大AI测试工具
人工智能·语言模型·自然语言处理
躺柒2 小时前
读人工智能全传03分治策略
人工智能·机器学习·ai·组件·分治策略
大霸王龙3 小时前
深度学习中,模型的构建和训练过程中会用到多种函数
人工智能·深度学习