#机器学习--模型可解释性框架SHAP的使用
引言
本系列博客旨在为机器学习(深度学习)提供数学理论基础。因此内容更为精简,适合二次学习的读者快速学习或查阅。
安装
shap依赖于tensorflow,即使我们使用的是pytorch。
shell
pip install tensorflow tf-keras shap
截至笔者这篇博客,shap的最新版本是0.50.0,我们需要重新再安装以下版本的 shap,官方版本有 bug。
shell
pip install git+https://github.com/maciejskorski/shap.git@fix/shap_text_colors --no-deps
基本原理
理论上我们可以使用shap来对任意模型进行可解释性分析。假设我们有一个模型 P P P,它的输入是向量 X X X。可以将 shap 的工作原理简单地理解为:现在我们需要分析向量 X X X 中每一个值对结果的贡献度,因此我们需要不断改变 X X X 中的值,来观察模型的输出变化情况。
基准值 :模型 P P P 在输入 X X X 上的输出值。
使用方法
官方示例都是使用的transformer.pipeline,和我们自己的模型有点差异,以下为笔者手写代码,逐步给读者介绍分析如何可视化自己的模型。对于其他任意模型大致逻辑都基本一致。
1、加载 BERT 预处理模型。非必须,读者可以有自己的数据处理逻辑。
python
model = BertModel.from_pretrained('bert-base-chinese', local_files_only=True).to(device)
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese', local_files_only=True)
# 设置最大长度
tokenizer.model_max_length = 100_000
2、定义模型,我这里是二分类任务,因此只需要输出一个logits即可。
python
class MatchingModel(nn.Module):
def __init__(self) -> None:
super().__init__()
...
self.output_mlp = nn.Sequential(
nn.RMSNorm(FEATURE_DIM * 2),
nn.Linear(FEATURE_DIM * 2, 512),
nn.GELU(),
nn.Linear(512, 128),
nn.GELU(),
nn.Linear(128, 1)
)
def forward(self, inputs: torch.Tensor):
...
logits = self.output_mlp(x.float())
return logits
3、定义数据预处理函数,如果读者使用的是其他数据或处理方法也类似。
python
def encode_text(text: str, chunk_size: int, batch_size: int):
"""
将文本按 token 数切分,每段最多 chunk_size 个 token,批量编码后拼接所有 token 向量。
返回整体向量:[1, total_tokens, hidden_dim]
"""
if len(text) == 0:
return None
inputs = tokenizer(text, return_tensors='pt', add_special_tokens=False)
input_ids = inputs['input_ids'][0] # [seq_len]
total_len = input_ids.size(0)
# 切分 token
chunks = []
for i in range(0, total_len, chunk_size):
chunks.append(input_ids[i:i + chunk_size])
token_vectors = []
# 批量处理
for i in range(0, len(chunks), batch_size):
batch_chunks = chunks[i:i + batch_size]
# padding
batch_padded = torch.nn.utils.rnn.pad_sequence(batch_chunks, batch_first=True, padding_value=0)
attention_mask = (batch_padded != 0).long().to(device)
batch_padded = batch_padded.to(device)
# 前向传播
with torch.no_grad():
outputs = model(input_ids=batch_padded, attention_mask=attention_mask)
last_hidden = outputs.last_hidden_state # [B, L, hidden_dim]
# 按原始长度切分
for j, chunk_tensor in enumerate(batch_chunks):
valid_len = chunk_tensor.size(0)
token_vectors.append(last_hidden[j, :valid_len].cpu().numpy())
return torch.unsqueeze(torch.tensor(np.concatenate(token_vectors, axis=0), device='cuda'), 0) # [1, total_tokens, hidden_dim]
4、定义模型与数据之间的适配器。注意,shap.Explainer只能解释单输入和单输出模型,我们这里定义适配器的目的之一就是适配多输入和多输出模型。
python
def model_adapter(matching_model: MatchingModel, masked_text_list: List[str]):
retval = []
for text in masked_text_list:
ext = encode_text(text, 512, 20)
if ext is None:
retval.append(0.)
continue
retval.append(stable_sigmoid_tensor(matching_model(ext)))
return torch.tensor(retval, device='cuda')
5、创建解释器,进行解释,并进行可视化。注意,plots.text需要在juypter的环境下运行。
python
explainer = shap.Explainer(
lambda masked_text_list: model_adapter(matching_model, masked_text_list),
maskers.Text(tokenizer),
max_evals=2048
)
with torch.no_grad():
ext_shap = explainer([ext_text_list[0]])
# 生成交互式 HTML 字符串
plots.text(ext_shap[0], grouping_threshold=0.1)
注意:如果使用的是官方的shap,输出可能是黑白色,没有任何颜色。更换后才能正常显示如下输出:
