【python因果推断库15】使用 sci-kit learn 模型进行回归断点分析

目录

导入数据

线性模型和主效应模型

线性模型、主效应模型和交互作用模型

使用bandwidth


python 复制代码
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import ExpSineSquared, WhiteKernel
from sklearn.linear_model import LinearRegression

import causalpy as cp
%config InlineBackend.figure_format = 'retina'

导入数据

python 复制代码
data = cp.load_data("rd")
data.head()

线性模型和主效应模型

python 复制代码
result = cp.skl_experiments.RegressionDiscontinuity(
    data,
    formula="y ~ 1 + x + treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
)
fig, ax = result.plot()
python 复制代码
result.summary(round_to=3)
复制代码
Difference in Differences experiment
Formula: y ~ 1 + x + treated
Running variable: x
Threshold on running variable: 0.5

Results:
Discontinuity at threshold = 0.19
Model coefficients:
  Intercept      	         0
  treated[T.True]	      0.19
  x              	      1.23

线性模型、主效应模型和交互作用模型

python 复制代码
result = cp.skl_experiments.RegressionDiscontinuity(
    data,
    formula="y ~ 1 + x + treated + x:treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
)
result.plot();

虽然我们可以看到这样做并不能很好地拟合数据,几乎肯定高估了阈值处的不连续性。

python 复制代码
result.summary(round_to=3)
复制代码
Difference in Differences experiment
Formula: y ~ 1 + x + treated + x:treated
Running variable: x
Threshold on running variable: 0.5

Results:
Discontinuity at threshold = 0.92
Model coefficients:
  Intercept        	         0
  treated[T.True]  	      2.47
  x                	      1.32
  x:treated[T.True]	     -3.11

使用bandwidth

我们处理这个问题的一种方法是使用 `bandwidth` 参数。这将只对阈值附近的一定带宽内的数据进行拟合。如果 x 是连续变量,那么模型将只对满足 的数据进行拟合。

python 复制代码
result = cp.skl_experiments.RegressionDiscontinuity(
    data,
    formula="y ~ 1 + x + treated + x:treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
    bandwidth=0.3,
)

result.plot();

我们甚至可以走得更远,只为接近阈值的数据拟合截距。但很明显,这将涉及更多的估计误差,因为我们使用的数据较少。

python 复制代码
result = cp.skl_experiments.RegressionDiscontinuity(
    data,
    formula="y ~ 1 + treated",
    model=LinearRegression(),
    treatment_threshold=0.5,
    bandwidth=0.3,
)

result.plot();
相关推荐
Antonio9152 分钟前
【图像处理】图像的基础几何变换
图像处理·人工智能·计算机视觉
新加坡内哥谈技术1 小时前
Perplexity AI 的 RAG 架构全解析:幕后技术详解
人工智能
武子康1 小时前
AI研究-119 DeepSeek-OCR PyTorch FlashAttn 2.7.3 推理与部署 模型规模与资源详细分析
人工智能·深度学习·机器学习·ai·ocr·deepseek·deepseek-ocr
智驱力人工智能2 小时前
基于视觉分析的人脸联动使用手机检测系统 智能安全管理新突破 人脸与手机行为联动检测 多模态融合人脸与手机行为分析模型
算法·安全·目标检测·计算机视觉·智能手机·视觉检测·边缘计算
Mr_Xuhhh2 小时前
GUI自动化测试--自动化测试的意义和应用场景
python·集成测试
Sirius Wu2 小时前
深入浅出:Tongyi DeepResearch技术解读
人工智能·语言模型·langchain·aigc
2301_764441332 小时前
水星热演化核幔耦合数值模拟
python·算法·数学建模
循环过三天2 小时前
3.4、Python-集合
开发语言·笔记·python·学习·算法
Q_Q5110082853 小时前
python+django/flask的眼科患者随访管理系统 AI智能模型
spring boot·python·django·flask·node.js·php
忙碌5443 小时前
AI大模型时代下的全栈技术架构:从深度学习到云原生部署实战
人工智能·深度学习·架构