Day37 深入理解SHAP图

SHAP值的解读

对于信贷问题,我们除了希望知道是否存在风险,还希望知道每个特征贡献了多少,比如年收入+0.15,收入高,加分;负债率-0.30负债太高,减分;工作年限+0.05工作稳定,小加分;信用评分-0.25 信用不好,减分;年龄+0.02影响很小,把模型的决策分解到每个特征上

到这里你可能心想,这样是不是很类似于线性回归的特征前的系数 y=ax1+bx2+cx3,那既然如此,我直接选择用线性回归的系数作解释,那岂不是更好?

线性回归的系数:y=0.15x收入 +(-0.30)x负债率 +0.05x工作年限 +(-0.25)x信用评分+0.02x年龄

  • 系数是固定的:不管是谁,"收入"的系数永远是 0.15
  • 全局解释:一个系数解释所有样本
  • 简单,但假设特征和目标是线性关系

比如线性回归会说收入每增加1万,贡献固定增加 x。但现实中:收入从 5万→30万,影响很大,收入从 100 万→500 万,影响就没那么大了(边际效应递减),SHAP能捕捉这种非线性关系

核心差异就在于SHAP值是因人而异的:张三的"收入"贡献可能是+0.15,李四的可能是+0.08

  • 局部解释:每个样本有自己的一组SHAP值
  • 复杂,但能捕捉非线性关系,同一个特征,不同样本贡献不同

这种非线性如何呈现:

  1. 特征本身的非线性关系:比如边际效应
  2. 特征之间存在交互效应:男性(性别)+年龄(25)发生质变

Shapley 值的核心就是当特征之间有交互作用时,如何公平地把"功劳"分给每个特征

SHAP 的原理来自博弈论,但我们用一个更简单的例子来理解:想象三个人合伙开了一家奶茶店,年底赚了 100 万。问题来了:这 100 万怎么分?

小王负责研发配方、小李负责营销推广、小张负责店面运营

直接三等分?不公平!因为每个人的贡献不一样。经济学家 Shapley 提出了一个方法:

数学家 做了数学假设:

博弈论基础上有4条规则,满足这4个客观规则的只有 shap 值,很自洽

一般来说基准值不是-500 +500

100 个样本,对这个样本的预测取平均(训练)=基准值

shap 值 本质上是解释模型在训练集学习的东西 加入什么都没学 直接取平均 最好的
SHAP 在机器学习中的应用

开店=机器学习

合伙人 = 特征

总收入 = 预测值 - 基准值

每人贡献 = 每个特征的 SHAP 值

基准值(Base Value):模型在所有样本上的平均预测值

三人合伙前,收入是 0,三人合伙后,收入是 100万,要分配的"蛋糕"就是 100万-0=100 万
基准值 = 没有任何特征信息时的"默认"预测(相当于"0"的起点),这个值一般就是平均值,把训练集的所有样本都输入模型,得到所有预测值取平均值,在没有关于这个特定样本的区分性信息时,最合理的猜测就是平均值
预测值 = 加入所有特征后的预测(相当于"100 万"的终点)

要分配的"蛋糕"=预测值 -基准值

核心公式:

模型预测值 = 基准值+SHAP(特征 1)+SHAP(特征 2)+..+SHAP(特征 N)

SHAP 值加起来 =预测值与基准值的差!
那么如何实现 shaply 值动态变化呢?上面说的是平均这一家店是一个样本,对于多家店每个店都是样本,所以特征贡献不同那么对于一个样本,如果控制变量计算特征贡献呢?真实在做的时候肯定是没法实现控制其他特征不动,检测单个特征的贡献,其实还是多样本比对了,shap 值本质上也是一个近似值。
虽然不完美,但 SHAP 是目前理论最完善、应用最广泛的解释方法

SHAP 值的计算用训练还是测试集?

先说结论,两者均可,但是为了图好看一般都是选择训练集。

  1. 做机器学习的专业大多都是交叉学科,本身你的研究多是针对私有数据集,别人不会关注你的泛化性。所以不必因为这个纠结。
  2. shap 值是每个样本的每个特征都会得到对应类别的值,所以如果你的数据量本身就不大,用训练集来绘制,点多会让图美观很多,或者也可以对测试集插值也可以起到一样的效果
  3. 补充一个 shap图美观的小技巧,可以绘制出每一个类别 shap 曲线的置信区间,因为机器学习多是点估计,而区间估计会让你的结果更加具有信服力,利用bootstrap 重采样思想可以绘制出置信区间,自己写一下 shap 图函数,不用借助 shap 库的接口。
python 复制代码
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_diabetes
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
import warnings
warnings.filterwarnings('ignore')

# --- 1. 全局设置 ---
plt.rcParams['font.sans-serif'] = ['SimHei']  # 中文字体
plt.rcParams['axes.unicode_minus'] = False   # 正常显示负号
sns.set(style="whitegrid", font='SimHei')

# ==========================================
# 2. 加载糖尿病数据集
# ==========================================
print("正在加载糖尿病数据集...")

# 加载数据
diabetes = load_diabetes()
X, y = diabetes.data, diabetes.target

# 转换为DataFrame以便更好地显示
df = pd.DataFrame(X, columns=diabetes.feature_names)
df['target'] = y

print("="*30 + " 数据集概览 " + "="*30)
print(f"数据形状: {df.shape}")
print(f"特征数量: {len(diabetes.feature_names)}")
print(f"目标变量: 一年后疾病进展的定量测量")

print("\n特征说明:")
for i, (name, desc) in enumerate(zip(diabetes.feature_names, diabetes.DESCR.split('\n')[10:20])):
    print(f"{i+1:2d}. {name:15s} - {desc.strip()}")

print("\n前5行数据:")
print(df.head())

# ==========================================
# 3. 数据探索性分析
# ==========================================
print("\n" + "="*30 + " 数据探索分析 " + "="*30)

# 创建可视化
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

# 1. 目标变量分布
axes[0].hist(y, bins=30, color='skyblue', edgecolor='black', alpha=0.7)
axes[0].set_title('目标变量分布', fontsize=12, fontweight='bold')
axes[0].set_xlabel('疾病进展')
axes[0].set_ylabel('频数')
axes[0].grid(True, alpha=0.3)

# 2. 特征相关性热图
corr_matrix = df.corr()
sns.heatmap(corr_matrix, annot=False, cmap='coolwarm', center=0, 
            ax=axes[1], cbar_kws={'shrink': 0.8})
axes[1].set_title('特征相关性热图', fontsize=12, fontweight='bold')

# 3. 目标变量与重要特征的关系
important_features = ['bmi', 's5', 'bp', 's3']
for i, feature in enumerate(important_features[:3]):
    axes[i+2].scatter(df[feature], y, alpha=0.5, s=20)
    axes[i+2].set_title(f'目标 vs {feature}', fontsize=12)
    axes[i+2].set_xlabel(feature)
    axes[i+2].set_ylabel('疾病进展')
    axes[i+2].grid(True, alpha=0.3)

plt.suptitle('糖尿病数据集探索性分析', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

# 数据统计
print("\n数据统计信息:")
print(df.describe().round(3))

# ==========================================
# 4. 训练随机森林模型
# ==========================================
print("\n" + "="*30 + " 模型训练 " + "="*30)

# 划分数据集
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

print(f"训练集形状: {X_train.shape}")
print(f"测试集形状: {X_test.shape}")

# 训练随机森林模型
rf_model = RandomForestRegressor(
    n_estimators=100,
    max_depth=10,
    random_state=42,
    n_jobs=-1
)

rf_model.fit(X_train, y_train)
y_pred = rf_model.predict(X_test)

# 评估模型
mse = mean_squared_error(y_test, y_pred)
rmse = np.sqrt(mse)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print("\n模型性能评估:")
print(f"均方误差 (MSE): {mse:.3f}")
print(f"均方根误差 (RMSE): {rmse:.3f}")
print(f"平均绝对误差 (MAE): {mae:.3f}")
print(f"决定系数 (R²): {r2:.3f}")

# 预测结果可视化
plt.figure(figsize=(8, 6))
plt.scatter(y_test, y_pred, alpha=0.5, color='blue')
plt.plot([y.min(), y.max()], [y.min(), y.max()], 'r--', lw=2)
plt.xlabel('真实值')
plt.ylabel('预测值')
plt.title('随机森林模型预测结果', fontsize=14, fontweight='bold')
plt.grid(True, alpha=0.3)
plt.show()

# ==========================================
# 5. SHAP可解释性分析
# ==========================================
print("\n" + "="*30 + " SHAP可解释性分析 " + "="*30)

# 安装shap库(如果未安装)
try:
    import shap
    print("SHAP库已安装,开始分析...")
except ImportError:
    print("正在安装SHAP库...")
    import subprocess
    import sys
    subprocess.check_call([sys.executable, "-m", "pip", "install", "shap"])
    import shap

# 初始化SHAP解释器
explainer = shap.TreeExplainer(rf_model)
shap_values = explainer.shap_values(X_test)

# 转换为DataFrame以便更好地处理
shap_df = pd.DataFrame(shap_values, columns=diabetes.feature_names)

print("\nSHAP分析完成!")
print(f"SHAP值形状: {shap_df.shape}")

# ==========================================
# 6. SHAP可视化分析
# ==========================================
print("\n" + "="*30 + " SHAP可视化分析 " + "="*30)

# 创建可视化图表
fig, axes = plt.subplots(2, 2, figsize=(16, 12))

# 1. 特征重要性总结图
shap.summary_plot(shap_values, X_test, feature_names=diabetes.feature_names, 
                  show=False, plot_size=None, max_display=10)
plt.title('SHAP特征重要性总结', fontsize=14, fontweight='bold', y=1.02)
fig1 = plt.gcf()
fig1.set_size_inches(10, 6)
plt.tight_layout()
plt.show()

# 2. 特征重要性条形图
plt.figure(figsize=(10, 6))
shap.summary_plot(shap_values, X_test, feature_names=diabetes.feature_names, 
                  plot_type="bar", show=False, max_display=10)
plt.title('SHAP特征重要性(平均绝对影响)', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.show()

# 3. 单个特征的SHAP依赖图(选择最重要的3个特征)
important_indices = np.argsort(np.abs(shap_values).mean(0))[-3:][::-1]
important_features = [diabetes.feature_names[i] for i in important_indices]

print(f"\n最重要的3个特征:")
for i, feature in enumerate(important_features, 1):
    shap_mean_abs = np.abs(shap_df[feature]).mean()
    print(f"{i}. {feature}: 平均绝对SHAP值 = {shap_mean_abs:.3f}")

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for idx, (ax, feature) in enumerate(zip(axes, important_features)):
    feature_idx = list(diabetes.feature_names).index(feature)
    
    # 使用shap的partial_dependence_plot
    shap.dependence_plot(
        feature_idx, shap_values, X_test, 
        feature_names=diabetes.feature_names,
        ax=ax, show=False
    )
    ax.set_title(f'{feature}的SHAP依赖图', fontsize=12, fontweight='bold')
    ax.set_xlabel(feature)
    ax.set_ylabel('SHAP值')
    ax.grid(True, alpha=0.3)

plt.suptitle('重要特征的SHAP依赖图', fontsize=14, fontweight='bold', y=1.05)
plt.tight_layout()
plt.show()

# 4. 单个样本的SHAP解释(选择3个样本)
print("\n" + "="*30 + " 单样本解释 " + "="*30)

# 选择3个有代表性的样本
sample_indices = [0, 50, 100]  # 可以根据需要调整

fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for i, (ax, sample_idx) in enumerate(zip(axes, sample_indices)):
    # 创建force plot(瀑布图)
    shap.force_plot(
        explainer.expected_value, 
        shap_values[sample_idx, :], 
        X_test[sample_idx, :],
        feature_names=diabetes.feature_names,
        matplotlib=True,
        show=False,
        text_rotation=30
    )
    
    # 获取当前图形
    temp_fig = plt.gcf()
    
    # 手动设置标题和标签
    ax.set_title(f'样本 {sample_idx} 的SHAP解释', fontsize=12, fontweight='bold')
    ax.set_xlabel('特征')
    ax.set_ylabel('SHAP贡献')
    
    # 清理并显示
    plt.close(temp_fig)

plt.suptitle('单个样本的SHAP解释(Force Plot)', fontsize=14, fontweight='bold', y=1.05)
plt.tight_layout()
plt.show()

# 显示样本的具体数值
for i, sample_idx in enumerate(sample_indices):
    print(f"\n样本 {sample_idx} 详情:")
    print(f"真实值: {y_test[sample_idx]:.2f}")
    print(f"预测值: {y_pred[sample_idx]:.2f}")
    print(f"预测偏差: {y_pred[sample_idx] - y_test[sample_idx]:.2f}")
    
    # 显示特征值和SHAP值
    sample_shap = shap_values[sample_idx, :]
    important_features_idx = np.argsort(np.abs(sample_shap))[-5:][::-1]
    
    print("最重要的5个特征贡献:")
    for j, feat_idx in enumerate(important_features_idx):
        feat_name = diabetes.feature_names[feat_idx]
        feat_value = X_test[sample_idx, feat_idx]
        shap_value = sample_shap[feat_idx]
        print(f"  {feat_name:10s}: 值={feat_value:6.3f}, SHAP={shap_value:7.3f}")

# 5. 特征交互分析
print("\n" + "="*30 + " 特征交互分析 " + "="*30)

# 寻找最重要的交互特征
interaction_feature = 'bmi'  # 选择一个重要特征
feature_idx = list(diabetes.feature_names).index(interaction_feature)

plt.figure(figsize=(10, 6))
shap.dependence_plot(
    feature_idx, shap_values, X_test, 
    feature_names=diabetes.feature_names,
    interaction_index='auto',  # 自动检测交互特征
    show=False
)
plt.title(f'{interaction_feature}的特征交互分析', fontsize=14, fontweight='bold')
plt.xlabel(interaction_feature)
plt.ylabel('SHAP值')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

# ==========================================
# 7. 与传统特征重要性对比
# ==========================================
print("\n" + "="*30 + " 特征重要性对比 " + "="*30)

# 传统特征重要性(基于基尼不纯度)
traditional_importance = pd.DataFrame({
    '特征': diabetes.feature_names,
    '传统重要性': rf_model.feature_importances_
}).sort_values('传统重要性', ascending=False)

# SHAP特征重要性(基于平均绝对SHAP值)
shap_importance = pd.DataFrame({
    '特征': diabetes.feature_names,
    'SHAP重要性': np.abs(shap_values).mean(0)
}).sort_values('SHAP重要性', ascending=False)

print("\n传统特征重要性(基尼不纯度):")
print(traditional_importance.head(10))

print("\nSHAP特征重要性(平均绝对SHAP值):")
print(shap_importance.head(10))

# 可视化对比
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# 传统特征重要性
axes[0].barh(range(10), traditional_importance['传统重要性'].head(10)[::-1], 
            color='skyblue', alpha=0.7)
axes[0].set_yticks(range(10))
axes[0].set_yticklabels(traditional_importance['特征'].head(10)[::-1])
axes[0].set_xlabel('重要性得分')
axes[0].set_title('传统特征重要性', fontsize=12, fontweight='bold')
axes[0].grid(True, alpha=0.3, axis='x')

# SHAP特征重要性
axes[1].barh(range(10), shap_importance['SHAP重要性'].head(10)[::-1], 
            color='lightcoral', alpha=0.7)
axes[1].set_yticks(range(10))
axes[1].set_yticklabels(shap_importance['特征'].head(10)[::-1])
axes[1].set_xlabel('平均绝对SHAP值')
axes[1].set_title('SHAP特征重要性', fontsize=12, fontweight='bold')
axes[1].grid(True, alpha=0.3, axis='x')

plt.suptitle('特征重要性方法对比', fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

# ==========================================
# 8. 模型对比分析
# ==========================================
print("\n" + "="*30 + " 模型对比分析 " + "="*30)

# 训练线性回归模型进行对比
lr_model = LinearRegression()
lr_model.fit(X_train, y_train)
y_pred_lr = lr_model.predict(X_test)

# 计算线性回归的SHAP值(使用KernelExplainer)
lr_explainer = shap.KernelExplainer(lr_model.predict, X_train[:100])  # 使用子集提高速度
lr_shap_values = lr_explainer.shap_values(X_test[:100])  # 使用子集

print("\n模型性能对比:")
print(f"随机森林 R²: {r2_score(y_test, y_pred):.3f}")
print(f"线性回归 R²: {r2_score(y_test[:100], y_pred_lr[:100]):.3f}")

# 对比特征重要性
if lr_shap_values is not None:
    lr_shap_importance = pd.DataFrame({
        '特征': diabetes.feature_names,
        '线性回归SHAP': np.abs(lr_shap_values).mean(0)
    }).sort_values('线性回归SHAP', ascending=False)
    
    print("\n线性回归SHAP重要性:")
    print(lr_shap_importance.head(10))

# ==========================================
# 9. 总结报告
# ==========================================
print("\n" + "="*30 + " SHAP分析总结 " + "="*30)

print("\n关键发现:")
print("1. 最重要的预测特征:")
for i, row in shap_importance.head(3).iterrows():
    print(f"   {row['特征']}: 平均绝对SHAP值 = {row['SHAP重要性']:.3f}")

print("\n2. 模型解释性:")
print("   • SHAP提供了局部和全局的解释")
print("   • 可以理解每个特征对单个预测的贡献")
print("   • 揭示了特征之间的相互作用")

print("\n3. 临床应用建议:")
print("   • 重点关注BMI、s5和血压等关键指标")
print("   • 这些特征对疾病进展预测有最大影响")
print("   • 可以为个性化医疗提供数据支持")

print("\n4. 技术要点:")
print("   • 随机森林模型表现良好 (R² = {:.3f})".format(r2))
print("   • SHAP分析揭示了特征的非线性关系")
print("   • 与传统特征重要性方法相比,SHAP更准确")

# 保存结果
results = {
    '模型性能': {
        'R2': r2,
        'RMSE': rmse,
        'MAE': mae
    },
    '特征重要性': shap_importance.to_dict('records')[:5]
}

import json
with open('diabetes_shap_results.json', 'w', encoding='utf-8') as f:
    json.dump(results, f, ensure_ascii=False, indent=2)

print(f"\n分析结果已保存到: diabetes_shap_results.json")
print("\n" + "="*30 + " 分析完成 " + "="*30)

@浙大疏锦行

相关推荐
李星星BruceL6 小时前
Pytest第三章(参考指南1)
python·自动化·pytest
哎呀呦呵6 小时前
pytest基本使用
python·pytest
阿关@6 小时前
Vscode中Python无法将pip/pytest”项识别为 cmdlet、函数、脚本文件或可运行程序的名称
vscode·python·pip
Kristen_YXQDN6 小时前
PyCharm 中 pytest 运行 python 测试文件报错:D:\Python_file\.venv\Scripts\python.exe: No module named pytest
运维·开发语言·python·pycharm·pytest
Low--Key6 小时前
pytest框架快速入门
python·自动化·pytest
IMPYLH6 小时前
Lua 的 Debug(调试) 模块
开发语言·笔记·python·单元测试·lua·fastapi
普通网友6 小时前
更优雅的测试:Pytest框架入门
jvm·数据库·python
Beaman10246 小时前
pytest框架
python·pytest