【python因果推断库15】使用 sci-kit learn 模型进行回归断点分析

目录

导入数据

线性模型和主效应模型

线性模型、主效应模型和交互作用模型

使用bandwidth


python 复制代码
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ExpSineSquared, WhiteKernel
from sklearn.linear_model import LinearRegression

import causalpy as cp
%config InlineBackend.figure_format = 'retina'

导入数据

python 复制代码
data = cp.load_data("rd")
data.head()

线性模型和主效应模型

python 复制代码
result = cp.skl_experiments.RegressionDiscontinuity(
    data,
    formula="y ~ 1 + x + treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
)
fig, ax = result.plot()
python 复制代码
result.summary(round_to=3)
复制代码
Difference in Differences experiment
Formula: y ~ 1 + x + treated
Running variable: x
Threshold on running variable: 0.5

Results:
Discontinuity at threshold = 0.19
Model coefficients:
  Intercept      	         0
  treated[T.True]	      0.19
  x              	      1.23

线性模型、主效应模型和交互作用模型

python 复制代码
result = cp.skl_experiments.RegressionDiscontinuity(
    data,
    formula="y ~ 1 + x + treated + x:treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
)
result.plot();

虽然我们可以看到这样做并不能很好地拟合数据,几乎肯定高估了阈值处的不连续性。

python 复制代码
result.summary(round_to=3)
复制代码
Difference in Differences experiment
Formula: y ~ 1 + x + treated + x:treated
Running variable: x
Threshold on running variable: 0.5

Results:
Discontinuity at threshold = 0.92
Model coefficients:
  Intercept        	         0
  treated[T.True]  	      2.47
  x                	      1.32
  x:treated[T.True]	     -3.11

使用bandwidth

我们处理这个问题的一种方法是使用 `bandwidth` 参数。这将只对阈值附近的一定带宽内的数据进行拟合。如果 x 是连续变量,那么模型将只对满足 的数据进行拟合。

python 复制代码
result = cp.skl_experiments.RegressionDiscontinuity(
    data,
    formula="y ~ 1 + x + treated + x:treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
    bandwidth=0.3,
)

result.plot();

我们甚至可以走得更远,只为接近阈值的数据拟合截距。但很明显,这将涉及更多的估计误差,因为我们使用的数据较少。

python 复制代码
result = cp.skl_experiments.RegressionDiscontinuity(
    data,
    formula="y ~ 1 + treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
    bandwidth=0.3,
)

result.plot();
相关推荐
数据大魔方6 分钟前
【期货量化实战】豆粕期货量化交易策略(Python完整代码)
开发语言·数据库·python·算法·github·程序员创富
memmolo8 分钟前
【3D传感技术系列博客】
算法·计算机视觉·3d
不爱编程爱睡觉8 分钟前
代码随想录算法训练营第四十三天 | 图论理论基础、深搜理论基础、98. 所有可达路径、广搜理论基础
算法·leetcode·图论·代码随想录
六毛的毛9 分钟前
冗余连接II
算法
@汤圆酱11 分钟前
【无标题】
python·jmeter
雷焰财经16 分钟前
务实深耕,全栈赋能:宇信科技引领金融AI工程化落地新范式
人工智能·科技·金融
西柚小萌新17 分钟前
【计算机视觉CV:标注工具】--ISAT
人工智能·计算机视觉
内存不泄露19 分钟前
基于 Spring Boot 的医院预约挂号系统(全端协同)设计与实现
java·vue.js·spring boot·python·flask
三万棵雪松20 分钟前
【AI小智硬件程序(八)】
c++·人工智能·嵌入式·esp32·ai小智
永远都不秃头的程序员(互关)21 分钟前
【K-Means深度探索(二)】K值之谜:肘部法则与轮廓系数,如何选出你的最佳K?
算法·机器学习·kmeans