可解释的模型之LIME

可解释的AI

为什么需要XAI(Explainable AI) 技术

当使用机器学习用于医疗或者恐怖分子检测的时候,预测是不可盲目相信的,因为预测错误的后果是灾难性的[1],因此我们需要对模型进行解释。我们将依次学习LIMECAM等内容,这次先从LIME开始。

LIME

LIME是Local Interpretable Model-agnostic Explanations简称,是16年发表于KDD的论文,作者提出了一种范式LIME来解决深度学习可解释性的问题LIME的核心思想是一些复杂的模型,在极小领域内,都可以用可解释模型近似。这是不是听着很耳熟,模型即函数,那不就是任何连续光滑的曲面,在极小领域内,都可以用平面来近似。具体做法是在输入实例的极小邻域内,创建可解释特征,拟合可解释模型。作者不局限于解释一条实例,还提出了SP-LIME试图对模型进行解释通过对若干个有代表性的实例进行解释。

LIME方法[1]

先看一个例子(如图1所示),深度学习的模型Model吃入5个特征预测病人为Flu,经过LIME解释后,我们知道了因为sneezeheadache有利于Flu的,而no fatigue是不利于Flu的,因此模型做出Flu的决策。那医生只需要对模型做出的决策进行一个审查了。

图1 解释单个实例的预测

要了解LIME怎么做的,先得聊聊什么是解释特征。例如文本分类任务,输入是一文本序列 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ∈ R d x \in R^d </math>x∈Rd,输出是类别概率 <math xmlns="http://www.w3.org/1998/Math/MathML"> y y </math>y,那么解释特征就是取值于 <math xmlns="http://www.w3.org/1998/Math/MathML"> { 0 , 1 } \{0,1\} </math>{0,1}的文本遮掩向量 <math xmlns="http://www.w3.org/1998/Math/MathML"> z ∈ R d z\in R^d </math>z∈Rd,该向量的第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i个分量表明第 <math xmlns="http://www.w3.org/1998/Math/MathML"> i i </math>i个文字是否遮掩。不难得出,原始文本序列对应的解释特征就是零向量。

LIME怎么做的呢?它是通过对解释特征抖动,得到若干邻近的解释特征,邻近的解释特征还原到对应的文本序列,作为原始模型输入,得到相应的预测值。我们将解释特征和预测值成对的送入解释模型(例如线性模型,决策树等),从而在输入实例的极小邻域内,解释模型可以拟合原始模型。

图2 LIME的直观理解

如图2所示,红叉是输入要解释的实例,经过抖动再映射可得到若干邻近实例,解释模型需要对上述实例进行分类,注意与被解释的实例越邻近表明该实例越重要,越需要解释模型分类正确,这也是论文提到的可信(faithful)的意思。最后,通过解释模型的可解释性来对复杂的难以解释的模型进行解释,充当发言人的角色。

上面做法对应的伪码如图3所示。

图3 LIME算法

其中,相似性核为 <math xmlns="http://www.w3.org/1998/Math/MathML"> π x ( z ) = e − D ( x , z ) 2 σ 2 \pi_x(z)=e^{-\frac{D(x,z)^2}{\sigma^2}} </math>πx(z)=e−σ2D(x,z)2, <math xmlns="http://www.w3.org/1998/Math/MathML"> s a m p l e _ a r o u n d sample\_around </math>sample_around是说从特征 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x的解释版本 <math xmlns="http://www.w3.org/1998/Math/MathML"> x ′ x' </math>x′附近抖动采样解释版本 <math xmlns="http://www.w3.org/1998/Math/MathML"> z ′ z' </math>z′,而 <math xmlns="http://www.w3.org/1998/Math/MathML"> z z </math>z是 <math xmlns="http://www.w3.org/1998/Math/MathML"> z ′ z' </math>z′对应的实例表示, <math xmlns="http://www.w3.org/1998/Math/MathML"> K K </math>K是解释特征的数目,具体的细节可以看原paper[1]。

SP-LIME方法[1]

刚刚说了如何解释一个实例,那如何评估模型的全局可信情况呢?我们可以使用SP-LIME(Submodular pick (SP) algorithm),即子模选择算法。

那子模选择算法如何做的呢?它通过确定有代表性的实例,对有代表性的实例进行评估,从而推断模型的可信情况。实例评估是方便的,如LIME。那如何选择有代表性的实例呢?如图4所示,

图4 子模选择算法

其中, <math xmlns="http://www.w3.org/1998/Math/MathML"> c ( V , W , I ) = ∑ j = 1 d ′ 1 [ ∃ i ∈ V : W i j > 0 ] I j c(V,W,I)=\sum_{j=1}^{d'}\mathcal{1}{[\exist i\in V:W{ij}>0]}I_{j} </math>c(V,W,I)=∑j=1d′1[∃i∈V:Wij>0]Ij.

一开始,我们对指定大小 <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N个实例用LIME进行解释,得到 <math xmlns="http://www.w3.org/1998/Math/MathML"> N N </math>N组特征权重,如图5所示。通过对不同实例的相同特征进行聚合,我们可以得到特征的重要性。最后根据特征的重要性量化实例的代表性以选择有代表性的实例。

图5 实例与特征权重图

LIME实践[2][3]

以葡萄酒的数据集为例,我们直接蹭kaggle平台的环境。如果需要解释一个实例的话,使用LIME就可以,代码如下,

python 复制代码
import lime
import warnings
import numpy as np
import pandas as pd
from lime import lime_tabular
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
warnings.filterwarnings('ignore')
# dataset loading and spliting
# data moving
cp /kaggle/input/wine-dataset/wine.csv /kaggle/working/
df= pd.read_csv('wine.csv')
X= df.drop('quality', axis= 1)
y= df['quality']
X_train, X_test, y_train, y_test= train_test_split(X, y, test_size= 0.2, random_state= 42)
# model training and evaulating
model= RandomForestClassifier(random_state= 42)
model.fit(X_train, y_train)
score= model.score(X_test, y_test)
# lime using
# Generate a explainer.
# training_data, numpy 2d; feature_names, list; class_names, list.
explainer= lime_tabular.LimeTabularExplainer(training_data= np.array(X_train), 
                                             feature_names= list(X_train.columns), 
                                             class_names= ['bad', 'good'],
                                             mode= 'classification')
idx= 4
data_test= np.array(X_test.iloc[idx]).reshape(1, -1)
pred= model.predict(data_test)[0]
y_true= np.array(y_test)[idx]
print(f'index: {idx} in test set,model pred: {pred},true label: {y_true}')
# Generates explanations for a prediction.
expr= explainer.explain_instance(
    # 1d numpy array or scipy.sparse matrix, corresponding to a row
    data_row= np.array(X_test.iloc[idx]),
    # predict_fn, For classifiers, 
    # this should be a function that takes a numpy array
    # and outputs prediction probabilities.
    predict_fn= model.predict_proba
)
# explain one record.
expr.show_in_notebook(show_table=  True)

相应的结果如图6所示,

图6 单个实例解释的结果

如果需要获得global view of model,我们可以使用sp-lime,代码如下,

python 复制代码
# evaluating the model's interpretability or receiving Global interpretability
# generates explanations for a set of representative instances from subset of your test data.
sp_objs= submodular_pick.SubmodularPick(explainer= explainer, data= np.array(X_test),
                                    predict_fn= model.predict_proba,
                                    # sample_size, This parameter controls 
                                    # how many samples (instances) from the
                                    # dataset (X_test) are being used for 
                                    # generating explanations.
                                    #
                                    # num_exps_desired, The number of representative 
                                    # explanation objects returned
                                    sample_size= 500, num_exps_desired= 5,
                                    num_features= 6)
# output representative explanations
for explanation in sp_objs.explanations[:5]:
    explanation.show_in_notebook(show_table= True)

需要注意的是,sample_size是从数据集采样的样本数,num_exps_desired是有代表性的解释对象数目。

LIME代码和数据

关于LIME方面的代码与数据可见我的github.

致歉

在无能为力的时候遇到一个想天长地久的人,我忙于各种事项,而忽略了与你相处的种种细节。让你受委屈了,对不起,远在北方的你。

参考资料

[1] why should I Trust You?

[2] LIME机器学习可解释性分析

[3] lime package --- lime 0.1 documentation (lime-ml.readthedocs.io)

相关推荐
m0_7482329211 分钟前
DALL-M:基于大语言模型的上下文感知临床数据增强方法 ,补充
人工智能·语言模型·自然语言处理
szxinmai主板定制专家17 分钟前
【国产NI替代】基于FPGA的32通道(24bits)高精度终端采集核心板卡
大数据·人工智能·fpga开发
海棠AI实验室20 分钟前
AI的进阶之路:从机器学习到深度学习的演变(三)
人工智能·深度学习·机器学习
机器懒得学习31 分钟前
基于YOLOv5的智能水域监测系统:从目标检测到自动报告生成
人工智能·yolo·目标检测
QQ同步助手1 小时前
如何正确使用人工智能:开启智慧学习与创新之旅
人工智能·学习·百度
AIGC大时代1 小时前
如何使用ChatGPT辅助文献综述,以及如何进行优化?一篇说清楚
人工智能·深度学习·chatgpt·prompt·aigc
流浪的小新1 小时前
【AI】人工智能、LLM学习资源汇总
人工智能·学习
martian6652 小时前
【人工智能数学基础篇】——深入详解多变量微积分:在机器学习模型中优化损失函数时应用
人工智能·机器学习·微积分·数学基础
人机与认知实验室3 小时前
人、机、环境中各有其神经网络系统
人工智能·深度学习·神经网络·机器学习
黑色叉腰丶大魔王3 小时前
基于 MATLAB 的图像增强技术分享
图像处理·人工智能·计算机视觉