使用sklearn函数对模型进行交叉验证

使用sklearn函数对模型进行交叉验证

交叉验证用来做什么

交叉验证(Cross-Validatio),是用于在驯良过程中对训练模型的性能和参数进行评估选择的技术。

它的意义在于能够充分利用优先的数据集,减少数据分布不均匀以及随机性带来的模型评估误差。

交叉验证的作用就是将数据集分割成多个自己进行多次训练,每次训练的训练集与测试机不完全相同。

sklearn 中的函数

python 复制代码
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
skf = KFold(n_splits=10, random_state=233, shuffle=True)

n_splits:int, default=5

表示,要分割为多少个K子集

shuffle:bool, default=False

是否打乱数据

random_state:int or RandomState instance, default=None

随机状态,需要配合shuffle参数使用

参考文章 https://blog.csdn.net/weixin_43803950/article/details/120894868

python 复制代码
# 如果有额外的标签,train_path 标签数据,如果标签是跟随train_path,第二个可不填入
skf.split(train_path, train_path)
python 复制代码
   for fold_idx, (train_idx, val_idx) in enumerate(skf.split(train_path, train_path)):
        train_loader = torch.utils.data.DataLoader(
            XunFeiDataset(np.array(train_path)[train_idx],
                          A.Compose([
                              A.RandomRotate90(),
                              A.RandomCrop(120, 120),
                              A.HorizontalFlip(p=0.5),
                              A.RandomContrast(p=0.5),
                              A.RandomBrightnessContrast(p=0.5),
                          ])
                          ), batch_size=8, shuffle=True, num_workers=0, pin_memory=False
        )

        val_loader = torch.utils.data.DataLoader(
            XunFeiDataset(np.array(train_path)[val_idx],
                          A.Compose([
                              A.RandomCrop(120, 120),
                          ])
                          ), batch_size=8, shuffle=False, num_workers=0, pin_memory=False
        )

        for epoch_item in range(30):

            # adjust_learning_rate(optimizer, epoch_item)

            train_loss = train(train_loader, model, criterion, optimizer)

            val_acc = validate(val_loader, model, criterion)

            train_acc = validate(train_loader, model, criterion)

            print(train_loss, train_acc, val_acc)
相关推荐
vx_biyesheji00011 小时前
计算机毕业设计:Python汽车数据分析系统 Django框架 requests爬虫 可视化 车辆 数据分析 大数据 机器学习(建议收藏)✅
爬虫·python·算法·机器学习·django·汽车·课程设计
承渊政道2 小时前
从n-grams到Transformer:一文读懂语言模型基础
深度学习·学习·语言模型·自然语言处理·chatgpt·transformer·机器翻译
xrgs_shz2 小时前
图像的点运算(线性点运算和非线性点运算)
人工智能·算法·机器学习
大模型实验室Lab4AI2 小时前
LlamaFactory 微调实测|Qwen3-4B现代诗风格微调
人工智能·深度学习
sin°θ_陈2 小时前
前馈式3D Gaussian Splatting 研究地图(总览篇):解构七大路线,梳理方法谱系,看懂关键分歧与未来趋势
论文阅读·深度学习·算法·3d·aigc·空间计算·3dgs
普密斯科技2 小时前
高精度车载插座多维度检测方案——基于3D线激光轮廓传感器的实践应用
大数据·人工智能·深度学习·计算机视觉·3d·测量
LingYi_02 小时前
语义分割-paddleseg
深度学习·语义分割
B站_计算机毕业设计之家2 小时前
计算机毕业设计:汽车数据可视化与后台管理平台 Django框架 requests爬虫 可视化 车辆 数据分析 大数据 机器学习(建议收藏)✅
python·算法·机器学习·信息可视化·django·汽车·课程设计
gloomyfish11 小时前
【最新认知】2026 | 深度学习工业缺陷检测三种技术路线分析与趋势
人工智能·深度学习
Lab_AI13 小时前
AI for Science应用:深度学习助力新型靶蛋白的药物从头设计(AIDD助力药物研发)
人工智能·深度学习·aidd·药物发现·新靶点药物设计