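No More Black-Box Models! XGBoost + SHAP: Generate 10 Publication-Quality Model Explanation Figures in One Run
When you run machine learning these days, the biggest headache often isn't pushing R² higher. It's being grilled by your advisor or business stakeholders: "Why did your model produce this result? Which feature is driving it?"
Ensemble tree models such as XGBoost and random forests often beat traditional regression on accuracy, but they are hard to interpret: the classic "black box" problem. This is where SHAP (SHapley Additive exPlanations) comes to the rescue.
Today I'm sharing a Python automation script I've been keeping in my back pocket. It trains and evaluates an XGBoost model and, more importantly, generates 10 good-looking, high-resolution (400 DPI) figures in a single run: nine SHAP visualizations (violin plot, heatmaps, waterfall plot, dependence plots, and more) plus a predicted-vs-actual scatter plot. That covers most of what you need for a paper or a presentation.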
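🛠️ Core Code Logic, Step by Step
The script is built to be end-to-end: data goes in, polished figures come out, all in one pass. To make it easier to follow, let's break the core operations apart.
1. Chinese Label Mapping and Model Training
In many real-world projects (urban planning or economic geography analysis, for example), the variables we want to show are Chinese terms such as "人均GDP" (GDP per capita) or "交通可达性" (transport accessibility), while the dataset columns carry short codes. The script therefore has a built-in field-mapping dictionary that renames the columns to readable Chinese labels, and it sets a Chinese-capable Matplotlib font so those labels don't come out garbled in the figures. Data cleaning and the XGBoost fit both happen before any plotting: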
python
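import xgboost as xgb

# Map raw column names to Chinese display labels
FEATURE_CN_MAP = {
    "Feat1": "人均GDP",
    "Feat2": "专利/万人",
    # ... other features
}
TARGET_CN_NAME = "韧性指数"
# Train the XGBoost model (X_train and y_train come from the train/test split in the full script)
model = xgb.XGBRegressor(
    n_estimators=1000,
    learning_rate=0.05,
    max_depth=6,
    random_state=42,
)
model.fit(X_train, y_train)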
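Before renaming, it can help to confirm that the mapping actually matches the dataset's columns, because `DataFrame.rename` silently ignores dictionary keys that match no column. The sketch below is an illustrative helper only: the name `check_mapping` and the sample column list are hypothetical and not part of the original script.

```python
def check_mapping(columns, mapping):
    """Report mapping keys absent from the data and columns left unmapped."""
    column_set = set(columns)
    missing_keys = [key for key in mapping if key not in column_set]
    unmapped_columns = [col for col in columns if col not in mapping]
    return missing_keys, unmapped_columns


# Hypothetical column names, for illustration only.
columns = ["Feat1", "Feat2", "Feat9", "XHM"]
mapping = {"Feat1": "人均GDP", "Feat2": "专利/万人", "Feat3": "对外开放度"}

missing_keys, unmapped_columns = check_mapping(columns, mapping)
print("Mapping keys not found in the data:", missing_keys)
print("Columns that keep their original name:", unmapped_columns)
```

In this made-up case, "Feat3" has no matching column and "Feat9" would appear in the figures under its raw code, which is usually a sign that the dictionary needs updating. The target column ("XHM" here) showing up as unmapped is expected, since it is renamed separately.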
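2. Computing the SHAP Values
Once the model is trained, the next step is to hand it to a SHAP explainer. This step is the foundation for every visualization: for each sample, it computes how much each feature contributes to the final prediction (the SHAP value).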
python
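import numpy as np
import shap

# Create the SHAP explainer and compute SHAP values for the test set
explainer = shap.Explainer(model)
shap_values_test = explainer(X_test)
shap_mat = shap_values_test.values  # array of shape (n_samples, n_features)
# Rank features by importance (mean absolute SHAP value, descending) for the plots that follow
feature_order = np.argsort(np.abs(shap_mat).mean(axis=0))[::-1]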
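To make "contribution" concrete, here is a minimal sketch of what a Shapley value is, computed by brute force on a toy function. Everything in it is an assumption for illustration: `toy_model`, the input, and the all-zero baseline are made up, and it uses a single baseline input with full enumeration of feature orderings, whereas the `shap` library relies on a much faster tree-specific algorithm for models like XGBoost. The point is only to show the additivity property: the baseline prediction plus the sum of the contributions reproduces the model's prediction.

```python
from itertools import permutations


def toy_model(x):
    """Hypothetical model with an interaction term, for illustration only."""
    return 2.0 * x[0] + x[1] * x[2]


def exact_shapley(model, x, baseline):
    """Average each feature's marginal contribution over all feature orderings."""
    n = len(x)
    contributions = [0.0] * n
    orderings = list(permutations(range(n)))
    for order in orderings:
        current = list(baseline)  # start from the baseline input
        previous_output = model(current)
        for i in order:
            current[i] = x[i]  # switch feature i to its actual value
            output = model(current)
            contributions[i] += output - previous_output
            previous_output = output
    return [c / len(orderings) for c in contributions]


x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phi = exact_shapley(toy_model, x, baseline)

# Additivity: baseline prediction + sum of contributions == prediction for x.
reconstructed = toy_model(baseline) + sum(phi)
print("contributions:", phi)
print("prediction:", toy_model(x), "reconstructed:", reconstructed)
```

The two features that only matter through their product split that product's effect evenly, which is the kind of fair attribution that makes SHAP values comparable across features.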
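3. Publication-Quality Figure Customization (Heatmap Example)
The default figures you get by calling shap.plots directly tend to look rather dull. In the script I customized Matplotlib heavily, using a consistent minimalist, bright palette (for example #5DADE2 bright blue and #C0392B bright red) and removing redundant axis borders, so the figures can go straight into a paper.
Take the "all-sample SHAP heatmap" as an example: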
python
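# Take the sorted data and set the color limit
# (sample_order and top_idx are defined in the complete code at the end)
heat_data = shap_mat[sample_order][:, top_idx].T
vmax = np.percentile(np.abs(heat_data), 98)
# Custom heatmap
fig_h, ax_h = plt.subplots(figsize=(16, 9), dpi=150)
# The RdBu_r red-blue diverging colormap makes positive and negative contributions easy to tell apart
im_h = ax_h.imshow(heat_data, aspect="auto", cmap="RdBu_r", vmin=-vmax, vmax=vmax)
ax_h.set_title("SHAP Values Heatmap", fontsize=20, pad=12, fontweight="bold")
# Remove extra grid lines to keep the figure clean
# ... (see the complete code at the end of this post)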
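The color limit in the snippet above is worth a closer look. Setting `vmin=-vmax` and `vmax=vmax` keeps zero at the center of the diverging colormap, and taking the 98th percentile of the absolute values instead of the maximum stops a handful of extreme SHAP values from washing out everything else. The sketch below illustrates the difference on synthetic data; the matrix and the injected outlier are made up and are not output from the model.

```python
import numpy as np

# Synthetic SHAP-like matrix (features x samples) with one extreme value.
rng = np.random.default_rng(0)
heat_data = rng.normal(loc=0.0, scale=0.1, size=(5, 200))
heat_data[0, 0] = 5.0  # a single outlier

# Color limit taken from the maximum: dominated by the outlier.
vmax_full = np.abs(heat_data).max()
# Color limit taken from the 98th percentile: reflects the bulk of the data.
vmax_clipped = np.percentile(np.abs(heat_data), 98)

print(f"max |value|: {vmax_full:.3f}")
print(f"98th percentile of |value|: {vmax_clipped:.3f}")
```

With the percentile-based limit, values beyond it are simply drawn in the most saturated color, so the outlier is still visible while the rest of the heatmap keeps its contrast.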
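💻 Complete Source Code (Ready to Use)
Below is the complete Python script. Before running it, make sure you have installed the dependencies: xgboost, shap, numpy, pandas, matplotlib, and scikit-learn.
You only need to point FILE_PATH at your own data, set TARGET_COL to your target column, and adjust the mapping dictionary. After it runs, the 10 high-resolution figures and a model evaluation table (R², RMSE, MAE) are written to the current working directory.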
python
# -*- coding: utf-8 -*-
"""
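Purpose: train an XGBoost model and export 10 publication-quality figures (SHAP plots plus a prediction scatter plot)
Features: high resolution (400 DPI), bright minimalist color scheme, Chinese field-name mapping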
"""
import os
import sys
import importlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import xgboost as xgb
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
from sklearn.model_selection import train_test_split
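if __name__ == "__main__":
    # Remove the script's own directory from sys.path, presumably so that a
    # local file named shap.py cannot shadow the installed shap package.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    sys.path = [p for p in sys.path if os.path.abspath(p or ".") != current_dir]

shap = importlib.import_module("shap")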
# ==================== Configuration ====================
FIG_DPI = 400
TOP_N = 15
RANDOM_HEATMAP_SAMPLES = 20
# Replace the data path with your own CSV file
FILE_PATH = "./data/dataset.csv"
TARGET_COL = "XHM"
# Field mapping (makes the figure labels easier to read)
FEATURE_CN_MAP = {
    "Feat1": "人均GDP",
    "Feat2": "专利/万人",
    "Feat3": "对外开放度",
    "Feat4": "产业高级化",
    "Feat5": "科技支出占比",
    "Feat6": "交通可达性",
    "Feat7": "普惠金融指数",
}
TARGET_CN_NAME = "韧性指数"
# Font setup so Chinese labels render correctly. SimSun ships with Windows;
# on other systems, substitute a CJK font that is installed.
plt.rcParams["font.family"] = ["SimSun", "DejaVu Sans"]
plt.rcParams["axes.unicode_minus"] = False
# =======================================================
# 1. Load and clean the data
df = pd.read_csv(FILE_PATH)
df_numeric = df.select_dtypes(include=[np.number])  # keep numeric columns only
df_numeric = df_numeric.loc[:, df_numeric.nunique(dropna=True) > 1]  # drop constant columns
df_numeric = df_numeric.replace([np.inf, -np.inf], np.nan).dropna()  # drop rows containing inf or NaN
if TARGET_COL not in df_numeric.columns:
    raise ValueError(f"Target column `{TARGET_COL}` was not found among the usable numeric columns.")
X = df_numeric.drop(columns=[TARGET_COL], errors="ignore")
y = df_numeric[TARGET_COL]
X = X.rename(columns=FEATURE_CN_MAP)
y = y.rename(TARGET_CN_NAME)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=None
)
# 2. Train and evaluate the model
model = xgb.XGBRegressor(
    n_estimators=1000, learning_rate=0.05, max_depth=6, random_state=42
)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
test_r2 = r2_score(y_test, y_pred)
test_rmse = np.sqrt(mean_squared_error(y_test, y_pred))
test_mae = mean_absolute_error(y_test, y_pred)
metrics_df = pd.DataFrame(
    {
        "Metric": ["R2", "RMSE", "MAE"],
        "Value": [test_r2, test_rmse, test_mae],
    }
)
metrics_df.to_csv("model_test_metrics.csv", index=False, encoding="utf-8-sig")
# 3. Compute SHAP values
explainer = shap.Explainer(model)
shap_values_test = explainer(X_test)
shap_mat = shap_values_test.values
feature_order = np.argsort(np.abs(shap_mat).mean(axis=0))[::-1]
top_n = min(TOP_N, X_test.shape[1])
top_idx = feature_order[:top_n]
top_feature_names = [X_test.columns[i] for i in top_idx]
mean_abs_shap = np.abs(shap_mat).mean(axis=0)
top_importance = mean_abs_shap[top_idx]
# ==================== Batch plotting ====================
# 1) Violin plot
plt.figure(figsize=(12, 8), dpi=150)
shap.summary_plot(
    shap_values_test,
    X_test,
    plot_type="violin",
    max_display=top_n,
    color="#5DADE2",
    show=False,
)
ax_v = plt.gca()
ax_v.set_title("SHAP Value Distribution (Violin Plot)", fontsize=20, pad=14, fontweight="bold")
ax_v.set_xlabel("SHAP Value", fontsize=16)
ax_v.set_ylabel("")
ax_v.tick_params(axis="both", labelsize=13)
ax_v.grid(axis="x", linestyle="--", alpha=0.2)
plt.tight_layout()
plt.savefig("shap_violin.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
plt.close()
# 2) Heatmap over all test samples
sample_order = np.argsort(np.abs(shap_mat).sum(axis=1))[::-1]
heat_data = shap_mat[sample_order][:, top_idx].T
vmax = np.percentile(np.abs(heat_data), 98)
fig_h, ax_h = plt.subplots(figsize=(16, 9), dpi=150)
im_h = ax_h.imshow(heat_data, aspect="auto", cmap="RdBu_r", vmin=-vmax, vmax=vmax)
ax_h.set_title("SHAP Values Heatmap", fontsize=20, pad=12, fontweight="bold")
ax_h.set_ylabel("Feature", fontsize=14)
ax_h.set_xlabel("Sample Index (sorted by total |SHAP|)", fontsize=14)
ax_h.set_yticks(np.arange(len(top_feature_names)))
ax_h.set_yticklabels(top_feature_names, fontsize=11)
ax_h.set_xticks(np.linspace(0, heat_data.shape[1] - 1, min(6, heat_data.shape[1])).astype(int))
ax_h.tick_params(axis="x", labelsize=10)
cbar_h = fig_h.colorbar(im_h, ax=ax_h, fraction=0.03, pad=0.02)
cbar_h.set_label("SHAP Value", fontsize=13)
cbar_h.ax.tick_params(labelsize=10)
fig_h.tight_layout()
fig_h.savefig("shap_heatmap.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
plt.close(fig_h)
# 3) Heatmap of randomly selected samples (up to RANDOM_HEATMAP_SAMPLES)
rng = np.random.default_rng(42)
sample_count = min(RANDOM_HEATMAP_SAMPLES, X_test.shape[0])
rand_idx = np.sort(rng.choice(X_test.shape[0], size=sample_count, replace=False))
top12_idx = feature_order[: min(12, len(feature_order))]
random_heat_data = shap_mat[rand_idx][:, top12_idx]
random_feature_labels = [X_test.columns[i] for i in top12_idx]
# i is the row position within X_test, not the original DataFrame index
random_sample_labels = [f"Sample {i}" for i in rand_idx]
vmax2 = np.percentile(np.abs(random_heat_data), 98)
fig_r, ax_r = plt.subplots(figsize=(13, 10), dpi=150)
im_r = ax_r.imshow(random_heat_data, aspect="auto", cmap="RdBu_r", vmin=-vmax2, vmax=vmax2)
ax_r.set_title(f"SHAP Heatmap - {sample_count} Random Samples", fontsize=20, pad=12, fontweight="bold")
ax_r.set_xlabel("Features", fontsize=14, fontweight="bold")
ax_r.set_ylabel("Samples", fontsize=14, fontweight="bold")
ax_r.set_xticks(np.arange(len(random_feature_labels)))
ax_r.set_xticklabels(random_feature_labels, rotation=40, ha="right", fontsize=11)
ax_r.set_yticks(np.arange(len(random_sample_labels)))
ax_r.set_yticklabels(random_sample_labels, fontsize=10)
cbar_r = fig_r.colorbar(im_r, ax=ax_r, fraction=0.036, pad=0.04)
cbar_r.set_label("SHAP Value", fontsize=13, fontweight="bold")
cbar_r.ax.tick_params(labelsize=10)
fig_r.tight_layout()
fig_r.savefig("shap_heatmap_20samples.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
plt.close(fig_r)
# 4) Waterfall plot (single-sample explanation)
waterfall_idx = min(5, len(shap_values_test) - 1)
plt.figure(figsize=(12, 9), dpi=150)
shap.plots.waterfall(shap_values_test[waterfall_idx], max_display=10, show=False)
ax_w = plt.gca()
ax_w.set_title(f"SHAP Waterfall Plot - Sample {waterfall_idx}", fontsize=20, pad=14, fontweight="bold")
ax_w.tick_params(axis="both", labelsize=12)
plt.tight_layout()
# Use the actual sample index in the file name so it always matches the title
plt.savefig(f"shap_waterfall_sample{waterfall_idx}.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
plt.close()
# 5) Predicted vs. actual scatter plot on the test set
fig_p, ax_p = plt.subplots(figsize=(8, 8), dpi=150)
ax_p.scatter(y_test, y_pred, s=86, color="#2E86C1", alpha=0.82, edgecolor="white", linewidth=0.8)
min_value = min(y_test.min(), y_pred.min())
max_value = max(y_test.max(), y_pred.max())
padding = (max_value - min_value) * 0.08 if max_value > min_value else 0.05
line_min = min_value - padding
line_max = max_value + padding
ax_p.plot([line_min, line_max], [line_min, line_max], color="#C0392B", linewidth=2.0, linestyle="--")
ax_p.set_xlim(line_min, line_max)
ax_p.set_ylim(line_min, line_max)
ax_p.set_title("Test Set Prediction Performance", fontsize=20, pad=14, fontweight="bold")
ax_p.set_xlabel(f"Actual {TARGET_CN_NAME}", fontsize=14)
ax_p.set_ylabel(f"Predicted {TARGET_CN_NAME}", fontsize=14)
ax_p.grid(linestyle="--", alpha=0.25)
ax_p.text(
    0.05,
    0.95,
    f"R² = {test_r2:.4f}\nRMSE = {test_rmse:.4f}\nMAE = {test_mae:.4f}",
    transform=ax_p.transAxes,
    va="top",
    fontsize=13,
    bbox={"boxstyle": "round,pad=0.35", "facecolor": "white", "edgecolor": "#D0D3D4", "alpha": 0.92},
)
fig_p.tight_layout()
fig_p.savefig("model_prediction_performance.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
plt.close(fig_p)
# 6) Feature importance bar chart
bar_order = top_idx[::-1]
bar_names = [X_test.columns[i] for i in bar_order]
bar_values = mean_abs_shap[bar_order]
fig_b, ax_b = plt.subplots(figsize=(12, 8), dpi=150)
colors = plt.cm.Blues(np.linspace(0.35, 0.95, len(bar_values)))
ax_b.barh(bar_names, bar_values, color=colors, edgecolor="white", linewidth=1.0)
ax_b.set_title("Mean |SHAP| Feature Importance", fontsize=20, pad=14, fontweight="bold")
ax_b.set_xlabel("Mean Absolute SHAP Value", fontsize=14)
ax_b.tick_params(axis="both", labelsize=12)
ax_b.grid(axis="x", linestyle="--", alpha=0.25)
for spine in ["top", "right", "left"]:
    ax_b.spines[spine].set_visible(False)
for value, name in zip(bar_values, bar_names):
    ax_b.text(value, name, f" {value:.4f}", va="center", fontsize=10)
fig_b.tight_layout()
fig_b.savefig("shap_feature_importance_bar.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
plt.close(fig_b)
# 7) Beeswarm plot
plt.figure(figsize=(12, 8), dpi=150)
shap.summary_plot(
    shap_values_test,
    X_test,
    plot_type="dot",
    max_display=top_n,
    show=False,
)
ax_s = plt.gca()
ax_s.set_title("SHAP Beeswarm Summary", fontsize=20, pad=14, fontweight="bold")
ax_s.set_xlabel("SHAP Value", fontsize=16)
ax_s.tick_params(axis="both", labelsize=12)
plt.tight_layout()
plt.savefig("shap_beeswarm.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
plt.close()
# 8-9) Dependence plots: drawn automatically for the two most important features
def save_dependence_plot(feature_idx, rank):
    feature_name = X_test.columns[feature_idx]
    feature_values = X_test.iloc[:, feature_idx]
    feature_shap_values = shap_mat[:, feature_idx]
    fig_d, ax_d = plt.subplots(figsize=(10, 7), dpi=150)
    scatter = ax_d.scatter(
        feature_values,
        feature_shap_values,
        c=feature_values,
        cmap="coolwarm",
        s=78,
        alpha=0.85,
        edgecolor="white",
        linewidth=0.7,
    )
    ax_d.axhline(0, color="#777777", linewidth=1.2, linestyle="--", alpha=0.7)
    ax_d.set_title(f"Dependence Plot - {feature_name}", fontsize=18, pad=12, fontweight="bold")
    ax_d.set_xlabel(feature_name, fontsize=14)
    ax_d.set_ylabel("SHAP Value", fontsize=14)
    ax_d.tick_params(axis="both", labelsize=11)
    ax_d.grid(linestyle="--", alpha=0.22)
    cbar_d = fig_d.colorbar(scatter, ax=ax_d, fraction=0.045, pad=0.04)
    cbar_d.set_label("Feature Value", fontsize=12)
    cbar_d.ax.tick_params(labelsize=10)
    fig_d.tight_layout()
    fig_d.savefig(f"shap_dependence_top{rank}.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
    plt.close(fig_d)


for rank, feature_idx in enumerate(feature_order[: min(2, len(feature_order))], start=1):
    save_dependence_plot(feature_idx, rank)
# 10) Contribution share with cumulative curve (shares are computed among the top-N features only)
importance_pct = top_importance / top_importance.sum() * 100
cumulative_pct = np.cumsum(importance_pct)
fig_c, ax_c = plt.subplots(figsize=(13, 7), dpi=150)
x_pos = np.arange(len(top_feature_names))
ax_c.bar(x_pos, importance_pct, color="#5DADE2", edgecolor="white", linewidth=1.0)
ax_c.set_title("SHAP Contribution Share", fontsize=20, pad=14, fontweight="bold")
ax_c.set_ylabel("Contribution Share (%)", fontsize=14)
ax_c.set_xticks(x_pos)
ax_c.set_xticklabels(top_feature_names, rotation=35, ha="right", fontsize=11)
ax_c.tick_params(axis="y", labelsize=11)
ax_c.grid(axis="y", linestyle="--", alpha=0.25)
ax_c2 = ax_c.twinx()
ax_c2.plot(x_pos, cumulative_pct, color="#D35400", marker="o", linewidth=2.6)
ax_c2.set_ylabel("Cumulative Share (%)", fontsize=14)
ax_c2.set_ylim(0, 105)
ax_c2.tick_params(axis="y", labelsize=11)
for spine in ["top"]:
    ax_c.spines[spine].set_visible(False)
    ax_c2.spines[spine].set_visible(False)
fig_c.tight_layout()
fig_c.savefig("shap_contribution_share.png", dpi=FIG_DPI, bbox_inches="tight", facecolor="white")
plt.close(fig_c)
print("Model and figures generated. Check the output files in the current working directory.")
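The three numbers written to model_test_metrics.csv are easy to sanity-check by hand. The following sketch recomputes R², RMSE, and MAE in plain Python using their standard definitions, which is what the scikit-learn functions in the script compute for a single target; the helper name `regression_metrics` and the sample values are made up for illustration.

```python
import math


def regression_metrics(y_true, y_pred):
    """Return (R2, RMSE, MAE) for two equal-length sequences of numbers."""
    n = len(y_true)
    errors = [t - p for t, p in zip(y_true, y_pred)]
    mae = sum(abs(e) for e in errors) / n
    mse = sum(e * e for e in errors) / n
    rmse = math.sqrt(mse)
    mean_true = sum(y_true) / n
    ss_res = sum(e * e for e in errors)  # residual sum of squares
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)  # total sum of squares
    r2 = 1.0 - ss_res / ss_tot
    return r2, rmse, mae


# Made-up numbers, for illustration only.
y_true = [1.0, 2.0, 3.0, 4.0]
y_pred = [1.5, 2.0, 2.5, 4.0]
r2, rmse, mae = regression_metrics(y_true, y_pred)
print(f"R2={r2:.4f}  RMSE={rmse:.4f}  MAE={mae:.4f}")
```

RMSE and MAE are in the same units as the target, so they are only comparable across models fitted to the same target, while R² is unitless and is read as the share of variance in the test targets that the predictions account for.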
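The final result (because the data is confidential, the figure shown here is AI-generated):
With this code, whenever you switch to a new dataset you only need to update FEATURE_CN_MAP (along with FILE_PATH and TARGET_COL) to get a polished set of explanation figures that makes a report look far more professional. Feel free to run it locally!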