双重机器学习之因果推断 | CATE条件平均处理效应估计:五大方法原理详解与模拟数据实战(python版)

家人们我又更新了,代码和科研绘图在论文末尾,欢迎大家评论点赞和收藏,你们的认可是我坚持的动力,祝大家科研顺利。

因果推断 | CATE条件平均处理效应估计:五大方法原理详解与模拟数据实战

本文是因果推断系列文章。本篇聚焦 CATE(Conditional Average Treatment Effect,条件平均处理效应) 的估计,从ATE的局限性讲起,深入介绍S-Learner、T-Learner、X-Learner、因果森林DML和线性DML五种主流方法的原理,并在模拟数据上进行完整的代码实操与效果对比。


1 从ATE到CATE:为什么需要异质性处理效应?

1.1 ATE只能回答"平均有没有用"

ATE(Average Treatment Effect)回答的是:干预措施对整个群体的平均效果是什么

但在实际业务中,我们更想知道的是:对于不同的个体或子群,干预效果有什么不同

举几个例子:

  • 精准营销:给所有人发满减券ATE为正,但拆开看,高消费用户根本不需要券,低消费用户反而是增量用户------CATE帮你找到真正的增量人群。
  • 个性化医疗:某新药对年轻患者效果显著,对老年患者副作用大于疗效------只看ATE可能得出"有效"的结论,但对老年患者施加干预反而有害。
  • 政策评估:教育补贴对低收入家庭的效果远大于高收入家庭------CATE可以指导财政资源的精细化分配。

1.2 CATE的数学定义

CATE(条件平均处理效应)定义为:

其中:

  • Y(1):个体接受处理时的潜在结果
  • Y(0):个体未接受处理时的潜在结果
  • X:协变量/特征向量
  • τ(x):给定特征x条件下的处理效应

CATE是协变量X的函数,刻画了处理效应的异质性(Heterogeneity)。ATE只是CATE的期望:

1.3 估计CATE的核心难点

根本问题(Fundamental Problem of Causal Inference):对同一个个体,我们只能观察到一种潜在结果。要么看到Y(1),要么看到Y(0),不可能同时观察到两者。

因此,τ(x) = E[Y(1) - Y(0) | X=x] 不能直接通过数据计算,需要借助统计方法来估计。


2 五大CATE估计方法原理详解

下面介绍五种最常用的CATE估计方法,按照"从简单到复杂"的顺序排列。

2.1 S-Learner(Single Model Learner)

核心思想 :把处理变量T当作一个普通特征,训练一个统一的模型

步骤

  1. 将T拼接到协变量X中,训练模型 μ̂(X, T) 拟合结果Y
  2. CATE估计:

优点:实现最简单,只需训练一个模型。

缺点 :当使用正则化模型(如Lasso、随机森林)时,模型可能认为T的贡献不重要而将其效应缩小甚至忽略,导致CATE估计偏向于零(regularization bias)。

复制代码
def s_learner(X, T, Y):
    XT = np.column_stack([X, T])
    model = GradientBoostingRegressor(n_estimators=200, max_depth=4)
    model.fit(XT, Y)
    
    tau_hat = model.predict(np.c_[X, np.ones(len(X))]) \
            - model.predict(np.c_[X, np.zeros(len(X))])
    return tau_hat

2.2 T-Learner(Two Model Learner)

核心思想 :处理组和控制组各自训练一个模型,然后做差。

步骤

  1. 用处理组数据训练 μ̂₁(x) 拟合 E[Y|X=x, T=1]
  2. 用控制组数据训练 μ̂₀(x) 拟合 E[Y|X=x, T=0]
  3. CATE估计:

优点:允许处理组和控制组的结果模型完全不同,灵活性高。

缺点:当处理组或控制组样本很少时,对应的模型估计不准确。两个独立模型的误差会叠加。

复制代码
def t_learner(X, T, Y):
    model_1 = GradientBoostingRegressor().fit(X[T==1], Y[T==1])
    model_0 = GradientBoostingRegressor().fit(X[T==0], Y[T==0])
    tau_hat = model_1.predict(X) - model_0.predict(X)
    return tau_hat

2.3 X-Learner(Cross Learner)

核心思想:利用"反事实插补"的思想,用一个组的模型去预测另一个组的反事实结果,弥补T-Learner在样本不均衡时的不足。出自 Künzel et al. (2019) 的经典论文。

步骤(三阶段)

阶段1:分别拟合结果模型(同T-Learner),训练控制组模型 μ̂₀(x) 和处理组模型 μ̂₁(x)。

阶段2:计算伪处理效应(Imputed Treatment Effects)

对处理组个体(观测结果 - 预测的反事实):

对控制组个体(预测的反事实 - 观测结果):

阶段3:拟合CATE模型并加权

  • 处理组CATE模型 τ̂₁(x) 拟合 D̃¹ ~ X
  • 控制组CATE模型 τ̂₀(x) 拟合 D̃⁰ ~ X
  • 最终CATE:

其中e(x)是倾向性得分。直觉是:哪个组样本多,就更多地依赖那个组的CATE估计

优点:在处理组/控制组样本严重不均衡时表现远优于T-Learner。

2.4 因果森林DML(CausalForestDML)

核心思想 :结合双重机器学习(DML)的去偏框架因果森林来捕捉非线性异质性效应。

步骤(两阶段)

第一阶段------去偏/正交化(Orthogonalization)

使用机器学习模型去除协变量对Y和T的影响:

关键操作是交叉拟合(Cross-Fitting):将数据分为K折,每折用其余K-1折训练的模型来预测当前折,避免过拟合偏差。

第二阶段------CATE估计

用因果森林(Generalized Random Forest)在残差Ỹ和T̃上估计异质性效应τ(X)。因果森林的分裂准则:最大化子节点间处理效应差异(而非预测误差)。

优点:能捕捉复杂的非线性CATE模式,同时DML去偏保证了估计的无偏性。

2.5 线性DML(LinearDML)

核心思想 :与因果森林DML共享相同的第一阶段去偏过程,但第二阶段假设CATE是协变量的线性函数

优点

  • 可解释性极强:β的每个分量直接告诉你对应特征对CATE的边际效应
  • 提供统计推断:可以输出系数的置信区间和p值
  • 适合特征影响近似线性的场景

缺点:无法捕捉非线性异质性(如我们模拟数据中X₂²的效应)。

2.6 方法对比总结

方法 模型数 能否捕捉非线性CATE 是否去偏 是否有置信区间 适用场景
S-Learner 1 快速基线
T-Learner 2 样本均衡
X-Learner 4+ 部分 样本不均衡
CausalForestDML 通用首选
LinearDML 需要可解释性

3 模拟数据设计(DGP)

为了公平评估各方法,我们设计了如下的数据生成过程(DGP):

3.1 变量设计

符号 说明
X(5维) 协变量,服从标准正态分布
T(0或1) 二值处理变量
Y(连续) 结果变量

3.2 数据生成公式

真实CATE函数

设计意图:

  • X₁ :对CATE有线性影响
  • X₂ :对CATE有**非线性(二次)**影响
  • X₃, X₄, X₅ :对CATE无影响(噪声特征)

倾向性得分(处理分配机制)

X₁和X₃同时影响处理分配和结果(混淆变量),这使得朴素估计存在偏差。

基线结果函数

观测结果

3.3 数据生成代码

复制代码
def generate_data(n=5000, seed=42):
    np.random.seed(seed)
    X = np.random.randn(n, 5)
    X1, X2, X3 = X[:, 0], X[:, 1], X[:, 2]
    
    # 真实CATE
    tau_true = X1 + 0.5 * X2**2 - 1.0
    
    # 倾向性得分(存在混淆)
    e_x = 1.0 / (1.0 + np.exp(-(0.5*X1 + 0.3*X3)))
    T = np.random.binomial(1, e_x)
    
    # 基线结果 + 处理效应 + 噪声
    Y = (2.0*X1 - X2 + 0.5*X3) + T * tau_true + np.random.randn(n)*0.5
    
    return X, T, Y, tau_true

4 实验结果

4.1 评估指标

MSE(均方误差)

Bias(平均偏差)

指标 含义 越小/大越好
MSE 均方误差 越小越好
MAE 平均绝对误差 越小越好
Bias 平均偏差 越接近0越好
对真实CATE变异的解释比例 越大越好
Corr 与真实CATE的相关系数 越大越好

4.2 各方法评估结果

以n=5000的模拟数据运行结果:

方法 MSE MAE Bias Corr
S-Learner 0.1003 0.2168 0.0469 0.9362 0.9788
T-Learner 0.1311 0.2652 0.0474 0.9166 0.9593
X-Learner 0.0516 0.1539 0.0226 0.9672 0.9840
DML-CausalForest 0.1329 0.2220 0.0666 0.9155 0.9648
DML-Linear 0.7658 0.7626 0.5035 0.5132 0.8220

4.3 结果分析

1. X-Learner表现最优(MSE=0.0516, R²=0.9672)

X-Learner通过反事实插补和倾向性加权,充分利用了两组数据的信息互补。在本实验中MSE最低、R²最高,是五种方法中表现最好的。

2. S-Learner表现稳健(MSE=0.1003, R²=0.9362)

S-Learner只用了一个模型,但GBT能自动捕捉T与X的交互效应,表现仅次于X-Learner。在快速验证场景下是很好的基线方法。

3. DML-CausalForest受限于简化实现

本代码中的DML-CausalForest是基于"残差比值+随机森林"的简化实现。完整的因果森林(如EconML的CausalForestDML)使用专门的因果分裂准则,效果会更好。如果安装了econml,建议直接使用CausalForestDML

4. DML-Linear的局限性(MSE=0.7658, R²=0.5132)

由于真实CATE包含X₂²项(非线性),LinearDML无法捕捉这部分异质性。但它成功识别出了X1的系数约为1.0(真实值为1.0),X3/X4/X5的系数接近0(符合真实DGP)。当你需要可解释的系数和置信区间时,LinearDML仍有不可替代的价值

5. T-Learner中规中矩

T-Learner独立训练两个模型,误差会叠加。在样本不均衡或效应较弱时劣势更明显。


5 可视化解读

5.1 CATE估计值 vs 真实值散点图

每个子图展示一种方法的估计值(纵轴)与真实CATE(横轴)的关系。理想情况下所有点应落在对角线上。

  • X-Learner的点最紧密地围绕对角线
  • LinearDML由于线性假设,在CATE极端值处偏离明显

【在此插入 fig1_cate_scatter.png】

5.2 CATE随X1变化趋势

真实CATE与X1呈线性关系(τ随X1增大而增大)。可以看到:

  • 所有方法都能捕捉到这个线性趋势
  • 差异主要体现在噪声大小和边缘区域的偏差

【在此插入 fig2_cate_by_x1.png】

5.3 CATE随X2变化趋势

真实CATE与X2呈U型关系(X₂²项)。这是对各方法的关键测试:

  • CausalForestDML和X-Learner能较好地还原U型
  • LinearDML只能拟合出线性趋势,完全丢失了非线性信息

【在此插入 fig2b_cate_by_x2.png】

5.4 误差箱线图

箱线图展示各方法估计误差的分布:

  • 中位线接近0说明无偏
  • 箱体越窄说明估计越精确
  • X-Learner的箱体最窄且最居中

【在此插入 fig4_error_boxplot.png】


6 实战建议与方法选择指南

6.1 该用哪种方法?

根据你的实际需求选择:

场景1:快速验证异质性是否存在 → 先用 S-LearnerT-Learner 作为基线。

场景2:样本不均衡(处理组远少于控制组) → 优先选择 X-Learner

场景3:需要无偏估计 + 非线性CATE → 使用 CausalForestDML(推荐首选)。

场景4:需要系数可解释性和置信区间 → 使用 LinearDML,可输出每个特征对CATE的边际效应。

场景5:高维稀疏特征 → 使用 SparseLinearDML(EconML提供),自动做特征选择。

6.2 注意事项

  1. SUTVA假设:个体间无干扰(Stable Unit Treatment Value Assumption)
  2. 无未观测混淆:所有影响T和Y的变量都被纳入X(Unconfoundedness)
  3. 正值假设:0 < e(X) < 1,每个个体都有可能被分配到任一组(Overlap / Positivity)
  4. 交叉拟合很重要:DML方法中的cross-fitting可以有效降低过拟合偏差
  5. 基学习器选择:第一阶段的ML模型不宜过于复杂,避免引入额外偏差

7 完整代码

完整代码文件 cate_simulation.py 包含:

  • 数据生成函数(可调节样本量、效应强度、混淆程度)
  • 5种CATE估计方法的实现
  • 评估指标计算
  • 6张可视化图表的生成

环境安装

复制代码
pip install numpy pandas matplotlib scikit-learn scipy

注:本代码纯sklearn实现,无需econml。如需使用完整的CausalForestDML,额外安装 pip install econml

运行方式

复制代码
python cate_simulation.py

运行后会在当前目录生成以下图表文件:

文件名 内容
fig1_cate_scatter.png 各方法 CATE 估计 vs 真实值散点图
fig2_cate_by_x1.png CATE 随 X1 变化趋势对比
fig2b_cate_by_x2.png CATE 随 X2 变化趋势对比
fig3_cate_dist.png CATE 估计值分布直方图
fig4_error_boxplot.png 估计误差箱线图
fig5_dml_linear_coef.png LinearDML 系数(含置信区间)
fig6_true_cate_heatmap.png 真实 CATE 在 X1-X2 平面的热力图

8 总结

要点 说明
CATE是什么 条件平均处理效应,刻画处理效应的异质性
为什么需要CATE ATE是"一刀切",CATE支持个性化决策
核心难点 反事实不可观测,需要统计方法估计
推荐方法 CausalForestDML(通用首选)+ LinearDML(需要可解释性时)
关键假设 无未观测混淆 + SUTVA + Positivity

下一篇我们将介绍DML在真实数据集上的应用实战,敬请期待!


参考文献

  1. Künzel S R, Sekhon J S, Bickel P J, et al. Metalearners for estimating heterogeneous treatment effects using machine learning[J]. PNAS, 2019.
  2. Chernozhukov V, Chetverikov D, Demirer M, et al. Double/debiased machine learning for treatment and structural parameters[J]. The Econometrics Journal, 2018.
  3. Athey S, Tibshirani J, Wager S. Generalized random forests[J]. The Annals of Statistics, 2019.
  4. EconML Documentation: https://econml.azurewebsites.net/
python 复制代码
"""
因果推断 | CATE(条件平均处理效应)估计方法:模拟数据下的完整实战
=================================================================
包含以下方法:
1. S-Learner  2. T-Learner  3. X-Learner
4. DML + 因果森林  5. DML + 线性模型
全部基于 numpy / sklearn 实现,无需 econml。
"""

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib import rcParams
from scipy.ndimage import uniform_filter1d
import warnings
warnings.filterwarnings('ignore')

rcParams['font.sans-serif'] = ['DejaVu Sans']
rcParams['axes.unicode_minus'] = False
plt.style.use('seaborn-v0_8-whitegrid')
COLORS = ['#2196F3', '#FF5722', '#4CAF50', '#9C27B0', '#FF9800']

from sklearn.ensemble import GradientBoostingRegressor, RandomForestRegressor
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.model_selection import KFold

# ==================== DGP ====================
def generate_data(n=5000, seed=42):
    np.random.seed(seed)
    X = np.random.randn(n, 5)
    X1, X2, X3 = X[:,0], X[:,1], X[:,2]
    tau_true = X1 + 0.5*X2**2 - 1.0
    logit_e = 0.5*X1 + 0.3*X3
    e_x = 1.0/(1.0+np.exp(-logit_e))
    T = np.random.binomial(1, e_x)
    mu_0 = 2.0*X1 - X2 + 0.5*X3
    Y = mu_0 + T*tau_true + np.random.randn(n)*0.5
    print("="*60)
    print("  Data Summary")
    print("="*60)
    print(f"  n={n}  dim={X.shape[1]}  treat_rate={T.mean():.3f}  ATE={tau_true.mean():.3f}")
    print(f"  CATE range: [{tau_true.min():.2f}, {tau_true.max():.2f}]")
    print("="*60)
    return X, T, Y, tau_true

# ==================== Methods ====================
def s_learner(X, T, Y):
    XT = np.column_stack([X, T])
    m = GradientBoostingRegressor(n_estimators=200, max_depth=4, learning_rate=0.1, random_state=42)
    m.fit(XT, Y)
    return m.predict(np.c_[X, np.ones(len(X))]) - m.predict(np.c_[X, np.zeros(len(X))])

def t_learner(X, T, Y):
    m1 = GradientBoostingRegressor(n_estimators=200, max_depth=4, learning_rate=0.1, random_state=42)
    m0 = GradientBoostingRegressor(n_estimators=200, max_depth=4, learning_rate=0.1, random_state=42)
    m1.fit(X[T==1], Y[T==1]); m0.fit(X[T==0], Y[T==0])
    return m1.predict(X) - m0.predict(X)

def x_learner(X, T, Y):
    m1 = GradientBoostingRegressor(n_estimators=200, max_depth=4, learning_rate=0.1, random_state=42)
    m0 = GradientBoostingRegressor(n_estimators=200, max_depth=4, learning_rate=0.1, random_state=42)
    m1.fit(X[T==1], Y[T==1]); m0.fit(X[T==0], Y[T==0])
    D1 = Y[T==1] - m0.predict(X[T==1])
    D0 = m1.predict(X[T==0]) - Y[T==0]
    tm1 = GradientBoostingRegressor(n_estimators=200, max_depth=4, learning_rate=0.1, random_state=42)
    tm0 = GradientBoostingRegressor(n_estimators=200, max_depth=4, learning_rate=0.1, random_state=42)
    tm1.fit(X[T==1], D1); tm0.fit(X[T==0], D0)
    ps = LogisticRegression(random_state=42, max_iter=1000); ps.fit(X, T)
    e = ps.predict_proba(X)[:,1]
    return e*tm0.predict(X) + (1-e)*tm1.predict(X)

def dml_cross_fit(X, T, Y, n_splits=5):
    kf = KFold(n_splits=n_splits, shuffle=True, random_state=42)
    Y_res, T_res = np.zeros(len(Y)), np.zeros(len(T), dtype=float)
    for tr, te in kf.split(X):
        my = GradientBoostingRegressor(n_estimators=100, max_depth=3, learning_rate=0.1, random_state=42)
        mt = GradientBoostingRegressor(n_estimators=100, max_depth=3, learning_rate=0.1, random_state=42)
        my.fit(X[tr], Y[tr]); mt.fit(X[tr], T[tr].astype(float))
        Y_res[te] = Y[te] - my.predict(X[te])
        T_res[te] = T[te] - mt.predict(X[te])
    return Y_res, T_res

def dml_causal_forest(X, T, Y):
    Y_res, T_res = dml_cross_fit(X, T, Y)
    T_clip = np.clip(np.abs(T_res), 0.01, None)*np.sign(T_res)
    T_clip[T_clip==0] = 0.01
    pseudo = Y_res / T_clip
    q_lo, q_hi = np.percentile(pseudo, [2, 98])
    pseudo = np.clip(pseudo, q_lo, q_hi)
    rf = RandomForestRegressor(n_estimators=300, max_depth=6, min_samples_leaf=20, random_state=42)
    rf.fit(X, pseudo)
    return rf.predict(X)

def dml_linear(X, T, Y):
    Y_res, T_res = dml_cross_fit(X, T, Y)
    Z = T_res.reshape(-1,1) * X
    reg = LinearRegression(fit_intercept=True); reg.fit(Z, Y_res)
    tau_hat = X @ reg.coef_ + reg.intercept_
    resid = Y_res - reg.predict(Z)
    n, p = Z.shape
    sigma2 = np.sum(resid**2)/(n-p-1)
    se = np.sqrt(np.diag(sigma2 * np.linalg.inv(Z.T@Z + 1e-8*np.eye(p))))
    return tau_hat, reg.coef_, se

# ==================== Evaluation ====================
def evaluate(tau_true, tau_hat, name):
    mse = np.mean((tau_true-tau_hat)**2)
    mae = np.mean(np.abs(tau_true-tau_hat))
    bias = np.mean(tau_hat-tau_true)
    ss_res = np.sum((tau_true-tau_hat)**2)
    ss_tot = np.sum((tau_true-tau_true.mean())**2)
    r2 = 1 - ss_res/ss_tot
    corr = np.corrcoef(tau_true, tau_hat)[0,1]
    return {'Method':name, 'MSE':round(mse,4), 'MAE':round(mae,4),
            'Bias':round(bias,4), 'R2':round(r2,4), 'Corr':round(corr,4)}

# ==================== Plots ====================
def plot_scatter(tau_true, res, path='fig1_cate_scatter.png'):
    n = len(res)
    fig, axes = plt.subplots(1, n, figsize=(3.8*n, 4), dpi=130)
    if n==1: axes=[axes]
    for i,(name,th) in enumerate(res.items()):
        ax=axes[i]
        ax.scatter(tau_true, th, alpha=0.12, s=6, c=COLORS[i%5])
        lims=[min(tau_true.min(),th.min())-0.5, max(tau_true.max(),th.max())+0.5]
        ax.plot(lims,lims,'k--',lw=1,alpha=0.5)
        c=np.corrcoef(tau_true,th)[0,1]; m=np.mean((tau_true-th)**2)
        ax.set_title(f'{name}\nCorr={c:.3f} MSE={m:.3f}',fontsize=10)
        ax.set_xlabel('True CATE',fontsize=9)
        if i==0: ax.set_ylabel('Estimated CATE',fontsize=9)
    plt.tight_layout(); plt.savefig(path,bbox_inches='tight'); plt.close()
    print(f"  [saved] {path}")

def plot_by_feature(X, tau_true, res, fi=0, fn='X1', path='fig2_cate_by_x1.png'):
    fig,ax=plt.subplots(1,1,figsize=(10,6),dpi=130)
    si=np.argsort(X[:,fi]); xs=X[si,fi]; w=100
    ax.scatter(xs,tau_true[si],alpha=0.06,s=4,c='gray',label='True CATE')
    ax.plot(xs,uniform_filter1d(tau_true[si],w),'k-',lw=2.5,label='True (smoothed)')
    for i,(name,th) in enumerate(res.items()):
        ax.plot(xs,uniform_filter1d(th[si],w),lw=2,alpha=0.85,color=COLORS[i%5],label=name)
    ax.set_xlabel(fn,fontsize=13); ax.set_ylabel('CATE',fontsize=13)
    ax.set_title(f'CATE vs {fn}',fontsize=14); ax.legend(fontsize=9,loc='upper left')
    plt.tight_layout(); plt.savefig(path,bbox_inches='tight'); plt.close()
    print(f"  [saved] {path}")

def plot_dist(tau_true, res, path='fig3_cate_dist.png'):
    fig,ax=plt.subplots(1,1,figsize=(10,6),dpi=130)
    ax.hist(tau_true,bins=50,alpha=0.3,color='gray',density=True,label='True CATE')
    for i,(name,th) in enumerate(res.items()):
        ax.hist(th,bins=50,alpha=0.4,color=COLORS[i%5],density=True,label=name,histtype='step',lw=2)
    ax.set_xlabel('CATE',fontsize=13); ax.set_ylabel('Density',fontsize=13)
    ax.set_title('CATE Distribution',fontsize=14); ax.legend(fontsize=9)
    plt.tight_layout(); plt.savefig(path,bbox_inches='tight'); plt.close()
    print(f"  [saved] {path}")

def plot_error_box(tau_true, res, path='fig4_error_boxplot.png'):
    fig,ax=plt.subplots(1,1,figsize=(10,6),dpi=130)
    errs=[th-tau_true for th in res.values()]; labs=list(res.keys())
    bp=ax.boxplot(errs,labels=labs,patch_artist=True,showfliers=False)
    for p,c in zip(bp['boxes'],COLORS[:len(labs)]): p.set_facecolor(c); p.set_alpha(0.5)
    ax.axhline(0,color='red',ls='--',lw=1,alpha=0.7)
    ax.set_ylabel('Error',fontsize=12); ax.set_title('CATE Estimation Error',fontsize=14)
    plt.tight_layout(); plt.savefig(path,bbox_inches='tight'); plt.close()
    print(f"  [saved] {path}")

def plot_coef(coef, se, fnames, path='fig5_dml_linear_coef.png'):
    fig,ax=plt.subplots(1,1,figsize=(8,5),dpi=130)
    y=np.arange(len(fnames)); ci=1.96*se
    ax.barh(y,coef,xerr=ci,color=COLORS[0],alpha=0.7,capsize=5)
    ax.set_yticks(y); ax.set_yticklabels(fnames,fontsize=12)
    ax.axvline(0,color='red',ls='--',lw=1)
    ax.set_xlabel('Coefficient',fontsize=12); ax.set_title('DML-Linear Coefficients (95% CI)',fontsize=13)
    for i,(c,s) in enumerate(zip(coef,se)):
        ax.text(c+ci[i]+0.02, i, f'{c:.3f}', va='center', fontsize=10)
    plt.tight_layout(); plt.savefig(path,bbox_inches='tight'); plt.close()
    print(f"  [saved] {path}")

def plot_heatmap(tau_true, X, path='fig6_true_cate_heatmap.png'):
    fig,ax=plt.subplots(1,1,figsize=(8,6),dpi=130)
    sc=ax.scatter(X[:,0],X[:,1],c=tau_true,cmap='RdYlBu_r',alpha=0.4,s=5,vmin=-3,vmax=5)
    plt.colorbar(sc,ax=ax,label='True CATE')
    ax.set_xlabel('X1',fontsize=13); ax.set_ylabel('X2',fontsize=13)
    ax.set_title(r'True CATE: $\tau(X) = X_1 + 0.5 X_2^2 - 1$',fontsize=14)
    plt.tight_layout(); plt.savefig(path,bbox_inches='tight'); plt.close()
    print(f"  [saved] {path}")

# ==================== Main ====================
def main():
    print("\n"+"="*60)
    print("   CATE Estimation - Simulation Study")
    print("="*60+"\n")

    X, T, Y, tau_true = generate_data(n=5000, seed=42)
    fnames = ['X1','X2','X3','X4','X5']
    res = {}; metrics = []

    print("\n>>> [1/5] S-Learner...")
    res['S-Learner'] = s_learner(X,T,Y)
    metrics.append(evaluate(tau_true, res['S-Learner'], 'S-Learner'))

    print(">>> [2/5] T-Learner...")
    res['T-Learner'] = t_learner(X,T,Y)
    metrics.append(evaluate(tau_true, res['T-Learner'], 'T-Learner'))

    print(">>> [3/5] X-Learner...")
    res['X-Learner'] = x_learner(X,T,Y)
    metrics.append(evaluate(tau_true, res['X-Learner'], 'X-Learner'))

    print(">>> [4/5] DML-CausalForest...")
    res['DML-CausalForest'] = dml_causal_forest(X,T,Y)
    metrics.append(evaluate(tau_true, res['DML-CausalForest'], 'DML-CausalForest'))

    print(">>> [5/5] DML-Linear...")
    tau_l, coef_l, se_l = dml_linear(X,T,Y)
    res['DML-Linear'] = tau_l
    metrics.append(evaluate(tau_true, tau_l, 'DML-Linear'))

    # Results table
    print("\n"+"="*60)
    print("  Evaluation Results")
    print("="*60)
    df = pd.DataFrame(metrics)
    print(df.to_string(index=False))

    # DML-Linear coefficients
    print("\n"+"="*60)
    print("  DML-Linear Coefficients (true: X1=1.0, X2=nonlinear, others=0)")
    print("="*60)
    for i,f in enumerate(fnames):
        lo=coef_l[i]-1.96*se_l[i]; hi=coef_l[i]+1.96*se_l[i]
        sig = "*" if lo>0 or hi<0 else ""
        print(f"  {f:>4}: {coef_l[i]:>7.4f}  SE={se_l[i]:.4f}  95%CI=[{lo:.4f}, {hi:.4f}] {sig}")

    # Plots
    print("\n>>> Generating plots...")
    plot_scatter(tau_true, res)
    plot_by_feature(X, tau_true, res, 0, 'X1', 'fig2_cate_by_x1.png')
    plot_by_feature(X, tau_true, res, 1, 'X2', 'fig2b_cate_by_x2.png')
    plot_dist(tau_true, res)
    plot_error_box(tau_true, res)
    plot_coef(coef_l, se_l, fnames)
    plot_heatmap(tau_true, X)

    print("\n"+"="*60)
    print("  All done!")
    print("="*60)
    return res, df

if __name__ == '__main__':
    results, df_metrics = main()

如果觉得有帮助,请点赞+收藏+关注,你的支持是我持续输出的动力!

相关推荐
希潭实验室2 小时前
第136篇:美国NSA的苹果手机“三角测量“后门的窃密模块分析 | 机器学习引擎识别照片信息
人工智能·机器学习·智能手机
绝不原创的飞龙2 小时前
使用机器学习和数学预测美国 GDP
人工智能·机器学习
Zero2 小时前
机器学习微积分--(1)核心思想
人工智能·算法·机器学习
competes2 小时前
学生需求 交易累计积分,积分兑换奖品
java·大数据·开发语言·人工智能·java-ee
大势智慧2 小时前
使用低空大师进行无人机巡查一段时间后,如何确定重点整治区域?
人工智能·ai·无人机·实景三维·低空经济·事件预警
小小数媒成员2 小时前
Shader中的光照模型
人工智能·计算机视觉
忧郁的橙子.2 小时前
12-大模型压缩预训练、知识蒸馏
人工智能·量化·模型蒸馏
科学创新前沿2 小时前
物理信息神经网络:从数据驱动到物理嵌入的科学计算新范式!
人工智能·深度学习·dft·pinn·流体力学·固体力学·断裂力学
高光视点2 小时前
海康存储亮相CFMS 2026全场景解决方案赋能AI存储新生态
人工智能