使用sklearn函数对模型进行交叉验证

使用sklearn函数对模型进行交叉验证

交叉验证用来做什么

交叉验证(Cross-Validatio),是用于在驯良过程中对训练模型的性能和参数进行评估选择的技术。

它的意义在于能够充分利用优先的数据集,减少数据分布不均匀以及随机性带来的模型评估误差。

交叉验证的作用就是将数据集分割成多个自己进行多次训练,每次训练的训练集与测试机不完全相同。

sklearn 中的函数

python 复制代码
from sklearn.model_selection import train_test_split, StratifiedKFold, KFold
skf = KFold(n_splits=10, random_state=233, shuffle=True)

n_splits:int, default=5

表示,要分割为多少个K子集

shuffle:bool, default=False

是否打乱数据

random_state:int or RandomState instance, default=None

随机状态,需要配合shuffle参数使用

参考文章 https://blog.csdn.net/weixin_43803950/article/details/120894868

python 复制代码
# 如果有额外的标签,train_path 标签数据,如果标签是跟随train_path,第二个可不填入
skf.split(train_path, train_path)
python 复制代码
   for fold_idx, (train_idx, val_idx) in enumerate(skf.split(train_path, train_path)):
        train_loader = torch.utils.data.DataLoader(
            XunFeiDataset(np.array(train_path)[train_idx],
                          A.Compose([
                              A.RandomRotate90(),
                              A.RandomCrop(120, 120),
                              A.HorizontalFlip(p=0.5),
                              A.RandomContrast(p=0.5),
                              A.RandomBrightnessContrast(p=0.5),
                          ])
                          ), batch_size=8, shuffle=True, num_workers=0, pin_memory=False
        )

        val_loader = torch.utils.data.DataLoader(
            XunFeiDataset(np.array(train_path)[val_idx],
                          A.Compose([
                              A.RandomCrop(120, 120),
                          ])
                          ), batch_size=8, shuffle=False, num_workers=0, pin_memory=False
        )

        for epoch_item in range(30):

            # adjust_learning_rate(optimizer, epoch_item)

            train_loss = train(train_loader, model, criterion, optimizer)

            val_acc = validate(val_loader, model, criterion)

            train_acc = validate(train_loader, model, criterion)

            print(train_loss, train_acc, val_acc)
相关推荐
蓝海星梦8 分钟前
GRPO 算法演进——偏差修正/鲁棒优化/架构扩展篇
论文阅读·人工智能·深度学习·算法·自然语言处理·强化学习
Dev7z9 分钟前
基于深度学习的肺音分类算法研究:从肺音识别到疾病辅助诊断
人工智能·深度学习·分类·肺音分类算法
jay神23 分钟前
基于MobileNet花卉识别系统
人工智能·深度学习·计算机视觉·毕业设计·花卉识别
zhangfeng113325 分钟前
大语言模型llm 量化模型 跑在 边缘设备小显存显卡 GGUF GGML PyTorch (.pth, .bin, SafeTensors)
人工智能·pytorch·深度学习·语言模型
纤纡.25 分钟前
深度学习环境搭建:CUDA+PyTorch+TorchVision+Torchaudio 一站式安装教程
人工智能·pytorch·深度学习
是小蟹呀^29 分钟前
图像识别/分类常见学习范式:有监督、无监督、自监督、半监督……(通俗版)
人工智能·深度学习·分类
kebijuelun29 分钟前
Towards Automated Kernel Generation in the Era of LLMs:LLM 时代的自动化 Kernel 生成全景图
人工智能·gpt·深度学习·语言模型
汉克老师35 分钟前
小学生0基础学大语言模型应用(第 19 课《字符串提示词训练(Prompt Thinking)》)
人工智能·深度学习·机器学习·语言模型·prompt·提示词
weisian15139 分钟前
进阶篇-11-数学篇-10--梯度在神经网络中的实际应用:从“猜答案”到“学会思考”的旅程
人工智能·深度学习·神经网络·梯度下降·反向传播·学习率·正向传播
狮子座明仔1 小时前
AgentScope 深度解读:多智能体开发框架的工程化实践
人工智能·深度学习·语言模型·自然语言处理