#机器学习--模型可解释性框架SHAP的使用

#机器学习--模型可解释性框架SHAP的使用

引言

本系列博客旨在为机器学习(深度学习)提供数学理论基础。因此内容更为精简,适合二次学习的读者快速学习或查阅。


安装

shap依赖于tensorflow,即使我们使用的是pytorch。

shell 复制代码
pip install tensorflow tf-keras shap

截至笔者这篇博客,shap的最新版本是0.50.0,我们需要重新再安装以下版本的 shap,官方版本有 bug。

shell 复制代码
pip install git+https://github.com/maciejskorski/shap.git@fix/shap_text_colors --no-deps

基本原理

理论上我们可以使用shap来对任意模型进行可解释性分析。假设我们有一个模型 P P P,它的输入是向量 X X X。可以将 shap 的工作原理简单地理解为:现在我们需要分析向量 X X X 中每一个值对结果的贡献度,因此我们需要不断改变 X X X 中的值,来观察模型的输出变化情况。

基准值 :模型 P P P 在输入 X X X 上的输出值。


使用方法

官方示例都是使用的transformer.pipeline,和我们自己的模型有点差异,以下为笔者手写代码,逐步给读者介绍分析如何可视化自己的模型。对于其他任意模型大致逻辑都基本一致。

1、加载 BERT 预处理模型。非必须,读者可以有自己的数据处理逻辑。

python 复制代码
model = BertModel.from_pretrained('bert-base-chinese', local_files_only=True).to(device)
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese', local_files_only=True)
# 设置最大长度
tokenizer.model_max_length = 100_000

2、定义模型,我这里是二分类任务,因此只需要输出一个logits即可。

python 复制代码
class MatchingModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        ...
        self.output_mlp = nn.Sequential(
            nn.RMSNorm(FEATURE_DIM * 2),
            nn.Linear(FEATURE_DIM * 2, 512),
            nn.GELU(),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Linear(128, 1)
        )

    def forward(self, inputs: torch.Tensor):
        ...
        logits = self.output_mlp(x.float())
        return logits

3、定义数据预处理函数,如果读者使用的是其他数据或处理方法也类似。

python 复制代码
def encode_text(text: str, chunk_size: int, batch_size: int):
    """
    将文本按 token 数切分,每段最多 chunk_size 个 token,批量编码后拼接所有 token 向量。
    返回整体向量:[1, total_tokens, hidden_dim]
    """
    if len(text) == 0:
        return None
    inputs = tokenizer(text, return_tensors='pt', add_special_tokens=False)
    input_ids = inputs['input_ids'][0]  # [seq_len]
    total_len = input_ids.size(0)

    # 切分 token
    chunks = []
    for i in range(0, total_len, chunk_size):
        chunks.append(input_ids[i:i + chunk_size])

    token_vectors = []

    # 批量处理
    for i in range(0, len(chunks), batch_size):
        batch_chunks = chunks[i:i + batch_size]

        # padding
        batch_padded = torch.nn.utils.rnn.pad_sequence(batch_chunks, batch_first=True, padding_value=0)
        attention_mask = (batch_padded != 0).long().to(device)
        batch_padded = batch_padded.to(device)

        # 前向传播
        with torch.no_grad():
            outputs = model(input_ids=batch_padded, attention_mask=attention_mask)
            last_hidden = outputs.last_hidden_state  # [B, L, hidden_dim]

        # 按原始长度切分
        for j, chunk_tensor in enumerate(batch_chunks):
            valid_len = chunk_tensor.size(0)
            token_vectors.append(last_hidden[j, :valid_len].cpu().numpy())

    return torch.unsqueeze(torch.tensor(np.concatenate(token_vectors, axis=0), device='cuda'), 0)  # [1, total_tokens, hidden_dim]

4、定义模型与数据之间的适配器。注意,shap.Explainer只能解释单输入和单输出模型,我们这里定义适配器的目的之一就是适配多输入和多输出模型。

python 复制代码
def model_adapter(matching_model: MatchingModel, masked_text_list: List[str]):
    retval = []
    for text in masked_text_list:
        ext = encode_text(text, 512, 20)
        if ext is None:
            retval.append(0.)
            continue
        retval.append(stable_sigmoid_tensor(matching_model(ext)))
    return torch.tensor(retval, device='cuda')

5、创建解释器,进行解释,并进行可视化。注意,plots.text需要在juypter的环境下运行。

python 复制代码
explainer = shap.Explainer(
    lambda masked_text_list: model_adapter(matching_model, masked_text_list),
    maskers.Text(tokenizer),
    max_evals=2048
)
with torch.no_grad():
    ext_shap = explainer([ext_text_list[0]])

# 生成交互式 HTML 字符串
plots.text(ext_shap[0], grouping_threshold=0.1)

注意:如果使用的是官方的shap,输出可能是黑白色,没有任何颜色。更换后才能正常显示如下输出:

相关推荐
九.九10 小时前
ops-transformer:AI 处理器上的高性能 Transformer 算子库
人工智能·深度学习·transformer
春日见10 小时前
拉取与合并:如何让个人分支既包含你昨天的修改,也包含 develop 最新更新
大数据·人工智能·深度学习·elasticsearch·搜索引擎
恋猫de小郭10 小时前
AI 在提高你工作效率的同时,也一直在增加你的疲惫和焦虑
前端·人工智能·ai编程
deephub10 小时前
Agent Lightning:微软开源的框架无关 Agent 训练方案,LangChain/AutoGen 都能用
人工智能·microsoft·langchain·大语言模型·agent·强化学习
大模型RAG和Agent技术实践11 小时前
从零构建本地AI合同审查系统:架构设计与流式交互实战(完整源代码)
人工智能·交互·智能合同审核
老邋遢11 小时前
第三章-AI知识扫盲看这一篇就够了
人工智能
互联网江湖11 小时前
Seedance2.0炸场:长短视频们“修坝”十年,不如AI放水一天?
人工智能
PythonPioneer11 小时前
在AI技术迅猛发展的今天,传统职业该如何“踏浪前行”?
人工智能
冬奇Lab11 小时前
一天一个开源项目(第20篇):NanoBot - 轻量级AI Agent框架,极简高效的智能体构建工具
人工智能·开源·agent
阿里巴巴淘系技术团队官网博客12 小时前
设计模式Trustworthy Generation:提升RAG信赖度
人工智能·设计模式