SHAP中使用shap.summary_plot对多分类任务模型中特征重要性绘图

在文心一言中输入:

使用shap.summary_plot展示各个特征对模型输出类别的重要性

其输出的代码为(不正确):

python 复制代码
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import svm
import shap

# 加载数据集(这里使用iris数据集作为例子)
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化SVM模型并进行训练
clf = svm.SVC(kernel='linear', probability=True, random_state=42)
clf.fit(X_train, y_train)

# 初始化SHAP解释器
explainer = shap.Explainer(clf.predict_proba, X_train)

# 计算测试集上每个预测的SHAP值
# 注意:这里我们使用predict_proba方法,因为它返回了每个类别的概率
shap_values = explainer(X_test)

# 使用summary_plot可视化特征重要性
shap.summary_plot(shap_values, X_test, feature_names=iris.feature_names)

以上代码没有静态错误,但是运行报错:

Traceback (most recent call last):

File "D:\Ethan\Projects\fattyLiver\test_shap_iris.py", line 27, in <module>

shap.summary_plot(shap_values, X_test, feature_names=iris.feature_names)

File "D:\Ethan\Projects\fattyLiver\venv\lib\site-packages\shap\plots\_beeswarm.py", line 605, in summary_legacy

feature_names=feature_names[sort_inds],

TypeError: only integer scalar arrays can be converted to a scalar index

修改为如下代码(正确):

python 复制代码
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn import svm
import shap

# 加载数据集(这里使用iris数据集作为例子)
iris = datasets.load_iris()
X = iris.data
y = iris.target

# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 初始化SVM模型并进行训练
clf = svm.SVC(kernel='linear', probability=True, random_state=42)
clf.fit(X_train, y_train)

# 初始化SHAP解释器
explainer = shap.Explainer(clf.predict_proba, X_train)

# 计算测试集上每个预测的SHAP值
# 注意:这里我们使用predict_proba方法,因为它返回了每个类别的概率
shap_values = explainer(X_test)

# 使用summary_plot可视化特征重要性
# shap.summary_plot(shap_values, X_test, feature_names=iris.feature_names)
list_of_2d_arrays = [shap_values.values[:, :, i] for i in range(3)]
shap.summary_plot(list_of_2d_arrays, X_test, feature_names=iris.feature_names, class_names=iris.target_names)

输出图片:

相关推荐
零壹AI实验室2 小时前
用AI 10分钟搭建一个监控系统:Prometheus + Grafana 实战
人工智能·grafana·prometheus
志栋智能2 小时前
超自动化巡检:量化运维成效的标尺
运维·网络·人工智能·自动化
AI科技星2 小时前
紫金山天文台与6G 超导太赫兹实验对比【乖乖数学】
人工智能·线性代数·机器学习·量子计算·agi
摩尔线程2 小时前
摩尔线程携手紫光计算机发布《语音识别全栈国产化技术实践白皮书》
人工智能·语音识别·摩尔线程
字节跳动开源2 小时前
局中局!给 Agent 装上 OpenViking,它们竟然学会了“记仇”和“伪装”?
人工智能·开源·llm
Exploring2 小时前
通过 Vibe Coding,我开发的第一款鸿蒙 App 上架了,欢迎大家下载体验
人工智能
杀生丸学AI2 小时前
【VALSE 2026】AI领域年度重要进展
人工智能
沪漂阿龙2 小时前
面试题:文本表示方法详解——One-hot、Word2Vec、上下文表示、BERT词向量全解析(NLP基础高频考点)
人工智能·神经网络·自然语言处理·bert·word2vec
Luminbox紫创测控2 小时前
氙灯太阳光模拟器加速老化测试
人工智能·测试工具·测试标准
沪漂阿龙2 小时前
面试题详解:NLP基础概念与任务——一文吃透自然语言处理、Tokenization、文本分类、文本摘要、信息抽取与大模型应用
人工智能·自然语言处理·分类