OOD分类项目训练

一、项目地址

GitHub - LooKing9218/UIOS

二、label制作

将训练、验证、测试数据的分类信息转换入.csv文件中,运行如下脚本即可:

复制代码
import os
import csv
 
#要读取的训练、验证、测试文件的目录,该文件下保存着以各个类别命名的文件夹和对应的分类图片
root_path=r'/media/*********************/train' 
#类别种类
classes=['cls1','cls2']

def get_Write_file_infos(path):
    # 文件信息列表
    file_infos_list=[]
    typeclothes=os.listdir(path)
    for ii in typeclothes:
        everyfile=os.path.join(path , ii)
        for root, dirnames, filenames in os.walk(everyfile):
            for filename in filenames:
                file_infos = {}
                dirname=root
                 
                #根据自己的需求更改路径地址
                filename1 ='train/'+ii+'/'+ filename#.split('.jpg')[0]
                flag = filename1[-1]
                file_infos["ImageId"] = filename1
     
                file_infos["Flag"] = classes.index(ii)
                #将数据追加字典到列表中
                file_infos_list.append(file_infos)
                
    return file_infos_list
 
 
#写入csv文件
def write_csv(file_infos_list):
    with open('train_label.csv','a+',newline='') as csv_file_train:
        csv_writer = csv.DictWriter(csv_file_train,fieldnames=['ImageId','Flag'])
        csv_writer.writeheader()
        for each in file_infos_list:
            print(each)
            csv_writer.writerow(each)
            
def main():
    file_infos_list =get_Write_file_infos(root_path)
    write_csv(file_infos_list)
 
 
if __name__ == '__main__':
    main()
    print('The End!')

生成情况如下:

三、运行程序

(1)修改参数文件 utils/config.py

复制代码
# -*- coding: utf-8 -*-
class DefaultConfig(object):
    net_work = 'ResUnNet50'
    num_classes = 2
    num_epochs = 100
    batch_size = 256
    validation_step = 1
    root = "/media/code/"
    train_file = "train_label.csv"
    val_file = "val_label.csv"
    test_file = "test_label.csv"
    lr = 1e-4
    lr_mode = 'poly'
    momentum = 0.9
    weight_decay = 1e-4
    save_model_path = './Model_Saved'.format(net_work,lr)
    log_dirs = './Logs_Adam_0304'
    pretrained =True# False
    pretrained_model_path ='/media/code/UIOS-master/Trained/archive/data/99843712' #None
    cuda = 0
    num_workers = 4
    use_gpu = True
    trained_model_path = ''
    predict_fold = 'predict_mask'

(2)运行

命令:

复制代码
python train.py

(3)运行界面

四、踩坑记录

问题原因:ValueError: Only one class present in y_true. ROC AUC score is not defined in that case.

解决方法:

(1)网上看了很多:

方法1:添加 try-except

复制代码
        try:
            epoch_train_auc = metrics.roc_auc_score(labels, outputs)

            writer.add_scalar('Train/train_auc', float(epoch_train_auc),
                          epoch)
            print('loss for train : {},{}'.format(loss_train_mean,round(epoch_train_auc,6)))

        except ValueError:
            pass

方法2:DataLoader的参数设置shuffle=True

复制代码
   train_loader = DataLoader(DatasetCFP(
        root=args.root,
        mode='train',
        data_file=args.train_file,
    ),
        batch_size=args.batch_size, shuffle=True, pin_memory=True)
    val_loader = DataLoader(DatasetCFP(
        root=args.root,
        mode='val',
        data_file=args.val_file,
    ),
        batch_size=args.batch_size, shuffle=True, pin_memory=True)
    test_loader = DataLoader(DatasetCFP(
        root=args.root,
        mode='test',
        data_file=args.test_file,
    ),
        batch_size=args.batch_size, shuffle=True, pin_memory=True)

方法3:增大batch_size

(2)我的方法:

其实是我马虎大意

修改好config.py中的num_classes参数就行了,

见谅(不好意思~( ̄▽ ̄)~*)

相关推荐
独孤--蝴蝶1 分钟前
AI人工智能-机器学习-第一周(小白)
人工智能·机器学习
西柚小萌新3 分钟前
【深入浅出PyTorch】--上采样+下采样
人工智能·pytorch·python
丁学文武30 分钟前
大语言模型(LLM)是“预制菜”? 从应用到底层原理,在到中央厨房的深度解析
人工智能·语言模型·自然语言处理·大语言模型·大模型应用·预制菜
fie888935 分钟前
基于MATLAB的声呐图像特征提取与显示
开发语言·人工智能
文火冰糖的硅基工坊2 小时前
[嵌入式系统-100]:常见的IoT(物联网)开发板
人工智能·物联网·架构
刘晓倩2 小时前
实战任务二:用扣子空间通过任务提示词制作精美PPT
人工智能
却道天凉_好个秋2 小时前
OpenCV(七):BGR
opencv·计算机视觉
shut up2 小时前
LangChain - 如何使用阿里云百炼平台的Qwen-plus模型构建一个桌面文件查询AI助手 - 超详细
人工智能·python·langchain·智能体
Hy行者勇哥2 小时前
公司全场景运营中 PPT 的类型、功能与作用详解
大数据·人工智能
FIN66683 小时前
昂瑞微:实现精准突破,攻坚射频“卡脖子”难题
前端·人工智能·安全·前端框架·信息与通信