机器学习中的Data Leakage(数据泄漏)即target leakage和train-test contamination

目录

  1. Target leakage
  2. Train-Test Contamination
  3. 如何避免?
  4. 具体例子

kaggle教程

数据泄漏概念: 当您的训练数据包含有关目标的信息时,会发生数据泄漏(或泄漏),但当模型用于预测时,将无法获得类似的数据。这导致训练集(甚至可能还有验证数据)的高性能,但该模型在生产中表现不佳。

换句话说,换句话说,泄漏导致模型看起来准确,直到您开始用模型做出决策,然后模型变得非常不准确。

1Target leakage

一个例子会有所帮助。想象一下,你想预测谁会得肺炎。原始数据的前几行看起来像这样

是否得肺炎 是否服用抗生素 年龄 体重 性别 年龄
False False 100 65
True True 130 72
True True 100 58

这里就能看出点问题了,原始数据中,得肺炎和服用抗生素有很强的相关性,人们在患肺炎后服用抗生素是为了恢复健康,很明显可以看出,基本上没病的都不会去喝抗生素,用是否服用抗生素来这个特征用来训练结果会特别准。这就是所谓的target leakage.

再来一个例子 生病与吃药 用是否吃药,来预测是否会得病。 生病的才会吃药,所以用是否吃药来预测是否得病特别准。

2Train-Test Contamination

简单的来讲,就是使用train_test_split()之前,就对数据进行预处理了,如fit_transformed, imputer

比如在切分训练集 / 测试集之前,就使用了如均值插值法处理缺失值,那么训练集中其实就已经包含了测试集的信息。

3如何避免?

排除我们预测点之后的任何变量,打个比方,服用抗生素应该是在确定患肺炎之后,那我们预测患者是否得肺炎的时候,是否服用抗生素这一变量我们就不能使用。

1.统计分析与目标相关的列;

2.如果你建立一个模型并发现它非常精确(比如大于98%以上),可能有一个数据泄漏问题;

3.在交叉验证折叠中使用原始没进过预处理的数据;

4.使用Pipelines(一个典型的机器学习过程从数据收集开始,要经历多个步骤,才能得到需要的输出。这非常类似于流水线式工作,即通常会包含源数据ETL(抽取、转化、加载),数据预处理,指标提取,模型训练与交叉验证,新数据预测等步骤)。比如: scikit-learn Pipelines;

5. 使用Holdout Dataset。在使用模型之前,保留一个未使用过的的验证数据集作为对模型的最终健全性检查。

4具体例子

数据地址

py 复制代码
#target leakage之前,准确值太高,可能有异常
Cross-validation accuracy: 0.981810
#处理之后
Cross-val accuracy: 0.824096
  • 导入数据
  • 预测是否有卡card
py 复制代码
import pandas as pd

# Read the data
data = pd.read_csv('../input/aer-credit-card-data/AER_credit_card_data.csv', 
                   true_values = ['yes'], false_values = ['no'])

# Select target
y = data.card

# Select predictors
X = data.drop(['card'], axis=1)

print("Number of rows in the dataset:", X.shape[0])
X.head()
  • 用交叉验证
py 复制代码
from sklearn.pipeline import make_pipeline
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

# Since there is no preprocessing, we don't need a pipeline (used anyway as best practice!)
my_pipeline = make_pipeline(RandomForestClassifier(n_estimators=100))
cv_scores = cross_val_score(my_pipeline, X, y, 
                            cv=5,
                            scoring='accuracy')

print("Cross-validation accuracy: %f" % cv_scores.mean())

预测结果, Cross-validation accuracy: 0.981810

有了经验,你会发现很少能找到98%准确率的模型。它发生了,但我们应该更仔细地检查数据是否有目标泄漏,这很不寻常。

以下是数据摘要,您也可以在数据选项卡下找到:

卡:如果接受信用卡申请,则为1,如果不接受,则为0

报告:主要贬损报告的数量

年龄:年龄n岁加十二分之一

收入:年收入(除以10,000)

份额:每月信用卡支出与年收入的比率

支出:每月平均信用卡支出

业主:1如果拥有房子,0如果租金

自营职业:自营职业者为1,非自营职业者为0

受抚养人:1 + 受抚养人数量

月份:在当前地址居住的月份

主要信用卡:持有的主要信用卡数量

活跃:活跃信用账户数量

一些变量看起来很可疑。例如,支出是指这张卡的支出还是申请前使用的卡的支出?

py 复制代码
expenditures_cardholders = X.expenditure[y]
expenditures_noncardholders = X.expenditure[~y]

print('没有收到卡片且没有支出的人中的一小部分: %.2f' \
      %((expenditures_noncardholders == 0).mean()))
print('收到卡片且没有支出的人中占一小部分: %.2f' \
      %(( expenditures_cardholders == 0).mean()))

没有收到卡片且没有支出的人中的一小部分:1.00

收到卡片且没有支出的人中占一小部分:0.02

如上所述,每个没有收到卡片的人都没有支出,而收到卡片的人中只有2%没有支出。我们的模型似乎具有很高的准确性,这并不奇怪。但这似乎也是一个目标泄漏的情况,支出可能意味着他们申请的卡上的支出。

由于份额部分由支出决定,因此也应该排除在外。变量活动卡和主卡不太清晰,但从描述上来说,它们听起来令人担忧。在大多数情况下,如果您无法追踪创建数据的人以了解更多信息,那么安全总比后悔好。

py 复制代码
# Drop leaky predictors from dataset
potential_leaks = ['expenditure', 'share', 'active', 'majorcards']
X2 = X.drop(potential_leaks, axis=1)

# Evaluate the model with leaky predictors removed
cv_scores = cross_val_score(my_pipeline, X2, y, 
                            cv=5,
                            scoring='accuracy')

print("Cross-val accuracy: %f" % cv_scores.mean())

Cross-val accuracy: 0.824096

这种准确性要低得多,这可能令人失望。然而,当用于新应用程序时,我们可以预计它大约80%的时间是正确的,而泄漏模型可能会比这差得多(尽管它在交叉验证中的表观分数更高)。

相关推荐
松果财经21 小时前
千亿级赛道,Robobus 赛道中标新加坡自动驾驶巴士项目的“确定性机会”
人工智能·机器学习·自动驾驶
Blossom.11821 小时前
用一颗MCU跑通7B大模型:RISC-V+SRAM极致量化实战
人工智能·python·单片机·嵌入式硬件·opencv·机器学习·risc-v
ARM+FPGA+AI工业主板定制专家1 天前
基于GPS/PTP/gPTP的自动驾驶数据同步授时方案
人工智能·机器学习·自动驾驶
lisw051 天前
SolidWorks:现代工程设计与数字制造的核心平台
人工智能·机器学习·青少年编程·软件工程·制造
学Linux的语莫1 天前
机器学习数据处理
java·算法·机器学习
递归不收敛1 天前
吴恩达机器学习课程(PyTorch适配)学习笔记:1.3 特征工程与模型优化
pytorch·学习·机器学习
B站_计算机毕业设计之家1 天前
机器学习实战项目:Python+Flask 汽车销量分析可视化系统(requests爬车主之家+可视化 源码+文档)✅
人工智能·python·机器学习·数据分析·flask·汽车·可视化
lucky_syq2 天前
解锁特征工程:机器学习的秘密武器
人工智能·机器学习
CM莫问2 天前
推荐算法之粗排
深度学习·算法·机器学习·数据挖掘·排序算法·推荐算法·粗排
rengang662 天前
10-支持向量机(SVM):讲解基于最大间隔原则的分类算法
人工智能·算法·机器学习·支持向量机