机器学习——K 折交叉验证(K-Fold Cross Validation),案例:逻辑回归 交叉寻找最佳惩罚因子C


什么是交叉验证?

交叉验证是一种将原始数据集划分为若干个子集,反复训练和验证模型的策略。

交叉验证(Cross-Validation) 适用于你在模型调参(如逻辑回归中的 C

最常用的:K 折交叉验证(K-Fold Cross Validation)

将数据集平均分成 K 份,每次取其中 1 份做验证,剩下的 K-1 份做训练,重复 K 次,最终将 K 次的结果取平均。

图示流程(以 K=4 举例)

轮次 训练集 验证集
1 [2,3,4] [1]
2 [1,3,4] [2]
3 [1,2,4] [3]
4 [1,2,3] [4]

最后将 4 次的验证结果平均,得到模型在未见数据上的稳定表现。


为什么要使用交叉验证?

作用 说明
✅ 稳定评估模型表现 解决只依赖单一测试集带来的评估波动问题
✅ 防止过拟合 多次训练验证,有助于检测模型是否泛化能力不足
✅ 用于超参数选择 常用于网格搜索、正则化参数调优(如逻辑回归中的 C)

什么时候该用交叉验证?

场景 是否推荐使用交叉验证
数据量较小 ✅ 强烈建议
不平衡分类问题 ✅ 建议配合 StratifiedKFold
模型调参(如 C、k、深度) ✅ 必用
数据量极大(上百万) ❌ 考虑分批验证或子集评估

代码使用:

复制代码
from sklearn.model_selection import cross_val_score

cross_val_score(estimator, X, y=None, *, scoring=None, cv=None,
                n_jobs=None, verbose=0, fit_params=None, pre_dispatch='2*n_jobs',
                error_score=np.nan)

参数详解:

参数名 类型 说明
estimator 模型对象 要评估的模型,例如 LogisticRegression()RandomForestClassifier() 等 'model = LogisticRegression()'后直接传入'model'即可
X ndarray / DataFrame 特征数据集
y array-like 目标标签(监督学习必须)
scoring str 或 callable 指定评估指标(如 accuracy, recall, f1, roc_auc 等)
cv int 或 交叉验证对象 交叉验证折数,如 cv=5;或 StratifiedKFold, KFold 等对象
n_jobs int 并行执行的任务数:-1 使用所有核心,1 表示不并行
verbose int 控制打印的详细程度(0为不输出,越大越详细)
fit_params dict 要传递给 estimator.fit() 的额外参数(少用)
pre_dispatch str 控制预分发任务数,默认 '2*n_jobs',通常无需改动
error_score 'raise' 或 float 出错时返回分数,或抛异常。默认返回 NaN

实战案例:用交叉验证寻找最优惩罚因子 C

信用卡欺诈检测数据集 creditcard.csv

  • 数据来源 信用卡欺诈检测实战数据集_数据集-阿里云天池https://tianchi.aliyun.com/dataset/101562?accounttraceid=c1258603818f44d6a57fe125248cc969rkgu

  • 样本总数:284,807 条

  • 特征数 :30(28个匿名特征 + 金额 Amount + 时间 Time

  • 目标变量Class(0=正常交易,1=欺诈交易)

    import numpy as np
    import pandas as pd
    from sklearn.linear_model import LogisticRegression # 导入逻辑回归模型
    from sklearn.model_selection import train_test_split, cross_val_score # 用于数据拆分和交叉验证
    from sklearn.preprocessing import StandardScaler # 用于数据标准化处理
    from sklearn import metrics # 用于模型评估指标计算

    data = pd.read_csv('creditcard.csv')

    初始化标准化器,对交易金额(Amount)进行标准化处理

    scaler = StandardScaler()
    data['Amount'] = scaler.fit_transform(data[['Amount']])

    准备特征数据X(排除时间和目标变量)和目标变量y(欺诈标签,1表示欺诈,0表示正常)

    X = data.drop(["Time","Class"], axis=1)
    y = data.Class

    将数据拆分为训练集(70%)和测试集(30%),设置随机种子保证结果可复现

    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=100)

    #----------------------------------------------------------------------------------------

    以下部分用于寻找最优的正则化参数C

    c_range = [0.01, 0.1, 1, 10, 100] # 定义要尝试的正则化参数C的取值范围(C越小,正则化强度越大)
    scores = [] # 存储不同C值对应的交叉验证平均召回率
    for c in c_range:
    model = LogisticRegression(C=c, penalty='l2', solver='lbfgs', max_iter=1000) # 初始化逻辑回归模型,指定正则化参数C、L2正则化、求解器和最大迭代次数
    score = cross_val_score(model, X_train, y_train, cv=8, scoring='recall') # 使用8折交叉验证,计算模型在训练集上的召回率recall
    score_mean = sum(score) / len(score) # 计算交叉验证召回率的平均值
    scores.append(score_mean) # 将平均召回率添加到列表中
    print(score_mean)

    找到最大平均召回率对应的C值,作为最优惩罚因子

    best_c = c_range[np.argmax(scores)] #argmax返回数组中最大值所在的索引位置
    print(f'最优惩罚因子为:{best_c}')
    #----------------------------------------------------------------------------------------

    使用最优惩罚因子训练最终的逻辑回归模型

    model = LogisticRegression(C=best_c, penalty='l2', solver='lbfgs')
    model.fit(X_train, y_train)


K-Fold Cross Validation 背后的原理(做了什么)

cross_val_score(model, X, y, cv=8) 等价于以下操作:

  1. 将数据按 8 等份分割

  2. 第一次拿前 7 份训练,第 8 份验证 → 计算指标

  3. 第二次拿 1,2,3,4,5,6,8 训练,第 7 份验证 → 计算指标

  4. ...

  5. 得到 8 个指标结果,返回组成数组

自动完成了分割、训练、预测和评分


常见扩展:StratifiedKFold(保持类别比例)

对于不平衡数据 (如欺诈检测),StratifiedKFold 是更合适的选择,它能在每一折中保持正负样本比例一致。

复制代码
from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(model, X_train, y_train, cv=skf, scoring='recall')

相关推荐
何双新2 小时前
Odoo AI 智能查询系统
前端·人工智能·python
生命是有光的4 小时前
【机器学习】机器学习算法
人工智能·机器学习
Blossom.1184 小时前
把 AI 塞进「自行车码表」——基于 MEMS 的 3D 地形预测码表
人工智能·python·深度学习·opencv·机器学习·计算机视觉·3d
小鹿的工作手帐7 小时前
有鹿机器人:为城市描绘清洁新图景的智能使者
人工智能·科技·机器人
TechubNews8 小时前
香港数字资产交易市场蓬勃发展,监管与创新并驾齐驱
人工智能·区块链
DogDaoDao9 小时前
用PyTorch实现多类图像分类:从原理到实际操作
图像处理·人工智能·pytorch·python·深度学习·分类·图像分类
小和尚同志9 小时前
450 star 的神级提示词管理工具 AI-Gist,让提示词不再吃灰
人工智能·aigc
金井PRATHAMA10 小时前
大脑的藏宝图——神经科学如何为自然语言处理(NLP)的深度语义理解绘制新航线
人工智能·自然语言处理
Y|10 小时前
GBDT(Gradient Boosting Decision Tree,梯度提升决策树)总结梳理
决策树·机器学习·集成学习·推荐算法·boosting
大学生毕业题目10 小时前
毕业项目推荐:28-基于yolov8/yolov5/yolo11的电塔危险物品检测识别系统(Python+卷积神经网络)
人工智能·python·yolo·cnn·pyqt·电塔·危险物品