DAY 37 深入理解SHAP图

python 复制代码
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split

# --- 1. 全局绘图设置 ---
plt.rcParams['font.sans-serif'] = ['SimHei']  # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False   # 用来正常显示负号
sns.set(style="whitegrid", font='SimHei')    # Seaborn 也要同步设置字体


df = pd.read_csv('housing.csv') 

# 为了讲义可读性,我们重命名列为中文对照
# 注意:原始 CSV 的列顺序是固定的,所以直接按顺序重命名是安全的

print("="*30 + " 数据集概览 " + "="*30)
print(f"数据形状: {df.shape}")
# print(df.head()) # 打印前5行预览
print(df.info()) # 打印前5行预览

# print(df['ocean_proximity'].value_counts())

mapping = {
    'INLAND': 1,       # 内陆,房价最低
    '<1H OCEAN': 2,    # 离海近1小时以内
    'NEAR OCEAN': 3,   # 靠近海岸
    'NEAR BAY': 4,     # 靠近海湾
    'ISLAND': 5        # 小岛,通常最贵
}
df['ocean_proximity'] = df['ocean_proximity'].map(mapping)

continuous_features = df.select_dtypes(include=['int64', 'float64']).columns.tolist() 
for feature in continuous_features:     
    mode_value = df[feature].mode()[0]            #获取该列的众数。
    df[feature].fillna(mode_value, inplace=True) 

# ==========================================
# 3. 数据切分
# ==========================================
# 最后一列 'Strength' 是我们要预测的目标 (y)
X = pd.concat([df.iloc[:, :-2], df.iloc[:, -1]], axis=1)
y = df.iloc[:, -2]

# 80% 训练,20% 测试
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)


from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score

# =============================
# 4. 训练随机森林回归模型
# =============================
rf = RandomForestRegressor(
    n_estimators=200,
    max_depth=None,
    random_state=42
)

rf.fit(X_train, y_train)

# =============================
# 5. 预测
# =============================
y_pred = rf.predict(X_test)

# =============================
# 6. 评价指标(回归)
# =============================
mse = mean_squared_error(y_test, y_pred)
rmse = np.sqrt(mse)
mae = mean_absolute_error(y_test, y_pred)
r2 = r2_score(y_test, y_pred)

print("="*20, "随机森林回归:模型表现", "="*20)
print(f"MAE  (平均绝对误差): {mae:.4f}")
print(f"RMSE (均方根误差): {rmse:.4f}")
print(f"R²   (决定系数): {r2:.4f}")
python 复制代码
import shap
shap.initjs()   # Jupyter 环境自动渲染

# 1. 构造解释器
explainer = shap.TreeExplainer(rf)

# 2. 选择样本(提高计算效率)
X_sample = X_train.sample(200, random_state=42)

# 3. 计算 SHAP
shap_values = explainer.shap_values(X_sample)

# 4. Summary 散点图
shap.summary_plot(shap_values, X_sample, plot_type="dot")

# 5. Summary 条形图
shap.summary_plot(shap_values, X_sample, plot_type="bar")

# 6. 特定特征的 Dependence Plot
shap.dependence_plot("housing_median_age", shap_values, X_sample)

# 7. 单一样本 Force Plot
shap.force_plot(explainer.expected_value, shap_values[0], X_sample.iloc[0])

@浙大疏锦行

相关推荐
CryptoPP5 分钟前
对接BSE交易所获取数据。
python·金融·数据挖掘·数据分析·区块链
老歌老听老掉牙12 分钟前
PyQt5中RadioButton互斥选择的实现方法
开发语言·python·qt
Pyeako14 分钟前
Opencv计算机视觉
人工智能·python·深度学习·opencv·计算机视觉
还不秃顶的计科生20 分钟前
LeetCode 热题 100第一题:两数之和python版本
python·算法·leetcode
2401_8414956427 分钟前
【Python高级编程】2026 丙午马年元旦祝福程序
python·动画·tkinter·程序·pyinstaller·元旦·turtle
该醒醒了~28 分钟前
使用auto-py-to-exe打包python程序exe并添加图标和ico文件
python
idealzouhu30 分钟前
【Android】深入浅出 JNI
android·开发语言·python·jni
兜兜转转了多少年1 小时前
《Python 应用机器学习:代码实战指南》笔记2 从0理解机器学习 —— 核心概念全解析
笔记·python·机器学习
reasonsummer1 小时前
【教学类-70-04】20251231小2班幼儿制作折纸方镜(八卦神兽镜)
python·通义万相
IT·小灰灰1 小时前
大模型API成本优化实战指南:Token管理的艺术与科学
人工智能·python·数据分析