#机器学习--模型可解释性框架SHAP的使用

#机器学习--模型可解释性框架SHAP的使用

引言

本系列博客旨在为机器学习(深度学习)提供数学理论基础。因此内容更为精简,适合二次学习的读者快速学习或查阅。


安装

shap依赖于tensorflow,即使我们使用的是pytorch。

shell 复制代码
pip install tensorflow tf-keras shap

截至笔者这篇博客,shap的最新版本是0.50.0,我们需要重新再安装以下版本的 shap,官方版本有 bug。

shell 复制代码
pip install git+https://github.com/maciejskorski/shap.git@fix/shap_text_colors --no-deps

基本原理

理论上我们可以使用shap来对任意模型进行可解释性分析。假设我们有一个模型 P P P,它的输入是向量 X X X。可以将 shap 的工作原理简单地理解为:现在我们需要分析向量 X X X 中每一个值对结果的贡献度,因此我们需要不断改变 X X X 中的值,来观察模型的输出变化情况。

基准值 :模型 P P P 在输入 X X X 上的输出值。


使用方法

官方示例都是使用的transformer.pipeline,和我们自己的模型有点差异,以下为笔者手写代码,逐步给读者介绍分析如何可视化自己的模型。对于其他任意模型大致逻辑都基本一致。

1、加载 BERT 预处理模型。非必须,读者可以有自己的数据处理逻辑。

python 复制代码
model = BertModel.from_pretrained('bert-base-chinese', local_files_only=True).to(device)
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese', local_files_only=True)
# 设置最大长度
tokenizer.model_max_length = 100_000

2、定义模型,我这里是二分类任务,因此只需要输出一个logits即可。

python 复制代码
class MatchingModel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        ...
        self.output_mlp = nn.Sequential(
            nn.RMSNorm(FEATURE_DIM * 2),
            nn.Linear(FEATURE_DIM * 2, 512),
            nn.GELU(),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Linear(128, 1)
        )

    def forward(self, inputs: torch.Tensor):
        ...
        logits = self.output_mlp(x.float())
        return logits

3、定义数据预处理函数,如果读者使用的是其他数据或处理方法也类似。

python 复制代码
def encode_text(text: str, chunk_size: int, batch_size: int):
    """
    将文本按 token 数切分,每段最多 chunk_size 个 token,批量编码后拼接所有 token 向量。
    返回整体向量:[1, total_tokens, hidden_dim]
    """
    if len(text) == 0:
        return None
    inputs = tokenizer(text, return_tensors='pt', add_special_tokens=False)
    input_ids = inputs['input_ids'][0]  # [seq_len]
    total_len = input_ids.size(0)

    # 切分 token
    chunks = []
    for i in range(0, total_len, chunk_size):
        chunks.append(input_ids[i:i + chunk_size])

    token_vectors = []

    # 批量处理
    for i in range(0, len(chunks), batch_size):
        batch_chunks = chunks[i:i + batch_size]

        # padding
        batch_padded = torch.nn.utils.rnn.pad_sequence(batch_chunks, batch_first=True, padding_value=0)
        attention_mask = (batch_padded != 0).long().to(device)
        batch_padded = batch_padded.to(device)

        # 前向传播
        with torch.no_grad():
            outputs = model(input_ids=batch_padded, attention_mask=attention_mask)
            last_hidden = outputs.last_hidden_state  # [B, L, hidden_dim]

        # 按原始长度切分
        for j, chunk_tensor in enumerate(batch_chunks):
            valid_len = chunk_tensor.size(0)
            token_vectors.append(last_hidden[j, :valid_len].cpu().numpy())

    return torch.unsqueeze(torch.tensor(np.concatenate(token_vectors, axis=0), device='cuda'), 0)  # [1, total_tokens, hidden_dim]

4、定义模型与数据之间的适配器。注意,shap.Explainer只能解释单输入和单输出模型,我们这里定义适配器的目的之一就是适配多输入和多输出模型。

python 复制代码
def model_adapter(matching_model: MatchingModel, masked_text_list: List[str]):
    retval = []
    for text in masked_text_list:
        ext = encode_text(text, 512, 20)
        if ext is None:
            retval.append(0.)
            continue
        retval.append(stable_sigmoid_tensor(matching_model(ext)))
    return torch.tensor(retval, device='cuda')

5、创建解释器,进行解释,并进行可视化。注意,plots.text需要在juypter的环境下运行。

python 复制代码
explainer = shap.Explainer(
    lambda masked_text_list: model_adapter(matching_model, masked_text_list),
    maskers.Text(tokenizer),
    max_evals=2048
)
with torch.no_grad():
    ext_shap = explainer([ext_text_list[0]])

# 生成交互式 HTML 字符串
plots.text(ext_shap[0], grouping_threshold=0.1)

注意:如果使用的是官方的shap,输出可能是黑白色,没有任何颜色。更换后才能正常显示如下输出:

相关推荐
Mr数据杨26 分钟前
加州房价中位数预测在房地产估值中的应用
机器学习·数据分析·kaggle
xiaotao13128 分钟前
02-机器学习基础: 监督学习——线性回归
学习·机器学习·线性回归
曦樂~37 分钟前
【机器学习】概述
人工智能·机器学习
DeniuHe42 分钟前
机器学习模型中的偏置项(bias / 截距项)到底有什么用?
人工智能·机器学习
小江的记录本1 小时前
【网络安全】《网络安全常见攻击与防御》(附:《六大攻击核心特性横向对比表》)
java·网络·人工智能·后端·python·安全·web安全
深小乐1 小时前
AI 周刊【2026.04.13-04.19】:中美差距减小、Claude Opus 4.7发布、国产算力突围
人工智能
深小乐1 小时前
从 AI Skills 学实战技能(七):让 AI 自动操作浏览器
人工智能
workflower1 小时前
人机交互部分OOD
运维·人工智能·自动化·集成测试·人机交互·软件需求
lanker就是懒蛋1 小时前
深度学习Q&A:手写反向传播与OOM排查的深层逻辑
人工智能·深度学习
Old Uncle Tom1 小时前
Claude Code 记忆系统分析2
人工智能·ai·agent