【python因果推断库15】使用 sci-kit learn 模型进行回归断点分析

目录

导入数据

线性模型和主效应模型

线性模型、主效应模型和交互作用模型

使用bandwidth


python 复制代码
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ExpSineSquared, WhiteKernel
from sklearn.linear_model import LinearRegression

import causalpy as cp
%config InlineBackend.figure_format = 'retina'

导入数据

python 复制代码
data = cp.load_data("rd")
data.head()

线性模型和主效应模型

python 复制代码
result = cp.skl_experiments.RegressionDiscontinuity(
    data,
    formula="y ~ 1 + x + treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
)
fig, ax = result.plot()
python 复制代码
result.summary(round_to=3)
复制代码
Difference in Differences experiment
Formula: y ~ 1 + x + treated
Running variable: x
Threshold on running variable: 0.5

Results:
Discontinuity at threshold = 0.19
Model coefficients:
  Intercept      	         0
  treated[T.True]	      0.19
  x              	      1.23

线性模型、主效应模型和交互作用模型

python 复制代码
result = cp.skl_experiments.RegressionDiscontinuity(
    data,
    formula="y ~ 1 + x + treated + x:treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
)
result.plot();

虽然我们可以看到这样做并不能很好地拟合数据,几乎肯定高估了阈值处的不连续性。

python 复制代码
result.summary(round_to=3)
复制代码
Difference in Differences experiment
Formula: y ~ 1 + x + treated + x:treated
Running variable: x
Threshold on running variable: 0.5

Results:
Discontinuity at threshold = 0.92
Model coefficients:
  Intercept        	         0
  treated[T.True]  	      2.47
  x                	      1.32
  x:treated[T.True]	     -3.11

使用bandwidth

我们处理这个问题的一种方法是使用 `bandwidth` 参数。这将只对阈值附近的一定带宽内的数据进行拟合。如果 x 是连续变量,那么模型将只对满足 的数据进行拟合。

python 复制代码
result = cp.skl_experiments.RegressionDiscontinuity(
    data,
    formula="y ~ 1 + x + treated + x:treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
    bandwidth=0.3,
)

result.plot();

我们甚至可以走得更远,只为接近阈值的数据拟合截距。但很明显,这将涉及更多的估计误差,因为我们使用的数据较少。

python 复制代码
result = cp.skl_experiments.RegressionDiscontinuity(
    data,
    formula="y ~ 1 + treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
    bandwidth=0.3,
)

result.plot();
相关推荐
kyle~5 分钟前
Opencv---深度学习开发
人工智能·深度学习·opencv·计算机视觉·机器人
运器12318 分钟前
【一起来学AI大模型】PyTorch DataLoader 实战指南
大数据·人工智能·pytorch·python·深度学习·ai·ai编程
音元系统21 分钟前
Copilot 在 VS Code 中的免费替代方案
python·github·copilot
超龄超能程序猿33 分钟前
(5)机器学习小白入门 YOLOv:数据需求与图像不足应对策略
人工智能·python·机器学习·numpy·pandas·scipy
卷福同学34 分钟前
【AI编程】AI+高德MCP不到10分钟搞定上海三日游
人工智能·算法·程序员
mit6.82435 分钟前
[Leetcode] 预处理 | 多叉树bfs | 格雷编码 | static_cast | 矩阵对角线
算法
帅次41 分钟前
系统分析师-计算机系统-输入输出系统
人工智能·分布式·深度学习·神经网络·架构·系统架构·硬件架构
皮卡蛋炒饭.1 小时前
数据结构—排序
数据结构·算法·排序算法
AndrewHZ1 小时前
【图像处理基石】如何入门大规模三维重建?
人工智能·深度学习·大模型·llm·三维重建·立体视觉·大规模三维重建
5G行业应用1 小时前
【赠书福利,回馈公号读者】《智慧城市与智能网联汽车,融合创新发展之路》
人工智能·汽车·智慧城市