#机器学习--模型可解释性框架SHAP的使用

#机器学习--模型可解释性框架SHAP的使用

引言

本系列博客旨在为机器学习(深度学习)提供数学理论基础。因此内容更为精简,适合二次学习的读者快速学习或查阅。


安装

shap依赖于tensorflow,即使我们使用的是pytorch。

shell 复制代码
pip install tensorflow tf-keras shap

截至笔者这篇博客,shap的最新版本是0.50.0,我们需要重新再安装以下版本的 shap,官方版本有 bug。

shell 复制代码
pip install git+https://github.com/maciejskorski/shap.git@fix/shap_text_colors --no-deps

基本原理

理论上我们可以使用shap来对任意模型进行可解释性分析。假设我们有一个模型 P P P,它的输入是向量 X X X。可以将 shap 的工作原理简单地理解为:现在我们需要分析向量 X X X 中每一个值对结果的贡献度,因此我们需要不断改变 X X X 中的值,来观察模型的输出变化情况。

基准值 :模型 P P P 在输入 X X X 上的输出值。


使用方法

官方示例都是使用的transformer.pipeline,和我们自己的模型有点差异,以下为笔者手写代码,逐步给读者介绍分析如何可视化自己的模型。对于其他任意模型大致逻辑都基本一致。

1、加载 BERT 预处理模型。非必须,读者可以有自己的数据处理逻辑。

python 复制代码
model = BertModel.from_pretrained('bert-base-chinese', local_files_only=True).to(device)
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese', local_files_only=True)
# 设置最大长度
tokenizer.model_max_length = 100_000

2、定义模型,我这里是二分类任务,因此只需要输出一个logits即可。

python 复制代码
class MatchingModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        ...
        self.output_mlp = nn.Sequential(
            nn.RMSNorm(FEATURE_DIM * 2),
            nn.Linear(FEATURE_DIM * 2, 512),
            nn.GELU(),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Linear(128, 1)
        )

    def forward(self, inputs: torch.Tensor):
        ...
        logits = self.output_mlp(x.float())
        return logits

3、定义数据预处理函数,如果读者使用的是其他数据或处理方法也类似。

python 复制代码
def encode_text(text: str, chunk_size: int, batch_size: int):
    """
    将文本按 token 数切分,每段最多 chunk_size 个 token,批量编码后拼接所有 token 向量。
    返回整体向量:[1, total_tokens, hidden_dim]
    """
    if len(text) == 0:
        return None
    inputs = tokenizer(text, return_tensors='pt', add_special_tokens=False)
    input_ids = inputs['input_ids'][0]  # [seq_len]
    total_len = input_ids.size(0)

    # 切分 token
    chunks = []
    for i in range(0, total_len, chunk_size):
        chunks.append(input_ids[i:i + chunk_size])

    token_vectors = []

    # 批量处理
    for i in range(0, len(chunks), batch_size):
        batch_chunks = chunks[i:i + batch_size]

        # padding
        batch_padded = torch.nn.utils.rnn.pad_sequence(batch_chunks, batch_first=True, padding_value=0)
        attention_mask = (batch_padded != 0).long().to(device)
        batch_padded = batch_padded.to(device)

        # 前向传播
        with torch.no_grad():
            outputs = model(input_ids=batch_padded, attention_mask=attention_mask)
            last_hidden = outputs.last_hidden_state  # [B, L, hidden_dim]

        # 按原始长度切分
        for j, chunk_tensor in enumerate(batch_chunks):
            valid_len = chunk_tensor.size(0)
            token_vectors.append(last_hidden[j, :valid_len].cpu().numpy())

    return torch.unsqueeze(torch.tensor(np.concatenate(token_vectors, axis=0), device='cuda'), 0)  # [1, total_tokens, hidden_dim]

4、定义模型与数据之间的适配器。注意,shap.Explainer只能解释单输入和单输出模型,我们这里定义适配器的目的之一就是适配多输入和多输出模型。

python 复制代码
def model_adapter(matching_model: MatchingModel, masked_text_list: List[str]):
    retval = []
    for text in masked_text_list:
        ext = encode_text(text, 512, 20)
        if ext is None:
            retval.append(0.)
            continue
        retval.append(stable_sigmoid_tensor(matching_model(ext)))
    return torch.tensor(retval, device='cuda')

5、创建解释器,进行解释,并进行可视化。注意,plots.text需要在juypter的环境下运行。

python 复制代码
explainer = shap.Explainer(
    lambda masked_text_list: model_adapter(matching_model, masked_text_list),
    maskers.Text(tokenizer),
    max_evals=2048
)
with torch.no_grad():
    ext_shap = explainer([ext_text_list[0]])

# 生成交互式 HTML 字符串
plots.text(ext_shap[0], grouping_threshold=0.1)

注意:如果使用的是官方的shap,输出可能是黑白色,没有任何颜色。更换后才能正常显示如下输出:

相关推荐
ggabb6 小时前
中文:承载文明,引领未来
大数据·人工智能
tobias.b6 小时前
人工智能中的基础数学概念详解
人工智能
哈罗哈皮7 小时前
trea也很强,我撸一个给你看(附教程)
前端·人工智能·微信小程序
木梯子7 小时前
大数据+AI+人|扑兔AI打造企业智慧经营,落地全域获客
大数据·人工智能·数据挖掘
maxmaxma7 小时前
ROS2 机器人 少年创客营:Day 3
人工智能·机器人·自动驾驶
AI大法师7 小时前
字标Logo设计指南:中文品牌如何用字体做出高级感与辨识度
人工智能·设计模式
跟着珅聪学java7 小时前
编写高质量 CSS 样式完全指南
人工智能·python·tensorflow
weixin_669545207 小时前
JT8166A/B电容式六按键触摸控制芯片,JT8166B具备IIC通信接口
人工智能·单片机·嵌入式硬件·硬件工程
Julia | 品牌营销观察员7 小时前
抖音小红书竞品分析用什么软件?2026 实测好用
大数据·人工智能·竞品分析·竞对监测·竞品动态监测