OOD分类项目训练

一、项目地址

GitHub - LooKing9218/UIOS

二、label制作

将训练、验证、测试数据的分类信息转换入.csv文件中,运行如下脚本即可:

import os
import csv
 
#要读取的训练、验证、测试文件的目录,该文件下保存着以各个类别命名的文件夹和对应的分类图片
root_path=r'/media/*********************/train' 
#类别种类
classes=['cls1','cls2']

def get_Write_file_infos(path):
    # 文件信息列表
    file_infos_list=[]
    typeclothes=os.listdir(path)
    for ii in typeclothes:
        everyfile=os.path.join(path , ii)
        for root, dirnames, filenames in os.walk(everyfile):
            for filename in filenames:
                file_infos = {}
                dirname=root
                 
                #根据自己的需求更改路径地址
                filename1 ='train/'+ii+'/'+ filename#.split('.jpg')[0]
                flag = filename1[-1]
                file_infos["ImageId"] = filename1
     
                file_infos["Flag"] = classes.index(ii)
                #将数据追加字典到列表中
                file_infos_list.append(file_infos)
                
    return file_infos_list
 
 
#写入csv文件
def write_csv(file_infos_list):
    with open('train_label.csv','a+',newline='') as csv_file_train:
        csv_writer = csv.DictWriter(csv_file_train,fieldnames=['ImageId','Flag'])
        csv_writer.writeheader()
        for each in file_infos_list:
            print(each)
            csv_writer.writerow(each)
            
def main():
    file_infos_list =get_Write_file_infos(root_path)
    write_csv(file_infos_list)
 
 
if __name__ == '__main__':
    main()
    print('The End!')

生成情况如下:

三、运行程序

(1)修改参数文件 utils/config.py

# -*- coding: utf-8 -*-
class DefaultConfig(object):
    net_work = 'ResUnNet50'
    num_classes = 2
    num_epochs = 100
    batch_size = 256
    validation_step = 1
    root = "/media/code/"
    train_file = "train_label.csv"
    val_file = "val_label.csv"
    test_file = "test_label.csv"
    lr = 1e-4
    lr_mode = 'poly'
    momentum = 0.9
    weight_decay = 1e-4
    save_model_path = './Model_Saved'.format(net_work,lr)
    log_dirs = './Logs_Adam_0304'
    pretrained =True# False
    pretrained_model_path ='/media/code/UIOS-master/Trained/archive/data/99843712' #None
    cuda = 0
    num_workers = 4
    use_gpu = True
    trained_model_path = ''
    predict_fold = 'predict_mask'

(2)运行

命令:

python train.py

(3)运行界面

四、踩坑记录

问题原因:ValueError: Only one class present in y_true. ROC AUC score is not defined in that case.

解决方法:

(1)网上看了很多:

方法1:添加 try-except

        try:
            epoch_train_auc = metrics.roc_auc_score(labels, outputs)

            writer.add_scalar('Train/train_auc', float(epoch_train_auc),
                          epoch)
            print('loss for train : {},{}'.format(loss_train_mean,round(epoch_train_auc,6)))

        except ValueError:
            pass

方法2:DataLoader的参数设置shuffle=True

   train_loader = DataLoader(DatasetCFP(
        root=args.root,
        mode='train',
        data_file=args.train_file,
    ),
        batch_size=args.batch_size, shuffle=True, pin_memory=True)
    val_loader = DataLoader(DatasetCFP(
        root=args.root,
        mode='val',
        data_file=args.val_file,
    ),
        batch_size=args.batch_size, shuffle=True, pin_memory=True)
    test_loader = DataLoader(DatasetCFP(
        root=args.root,
        mode='test',
        data_file=args.test_file,
    ),
        batch_size=args.batch_size, shuffle=True, pin_memory=True)

方法3:增大batch_size

(2)我的方法:

其实是我马虎大意

修改好config.py中的num_classes参数就行了,

见谅(不好意思~( ̄▽ ̄)~*)

相关推荐
小陈phd42 分钟前
OpenCV从入门到精通实战(九)——基于dlib的疲劳监测 ear计算
人工智能·opencv·计算机视觉
Guofu_Liao2 小时前
大语言模型---LoRA简介;LoRA的优势;LoRA训练步骤;总结
人工智能·语言模型·自然语言处理·矩阵·llama
秀儿还能再秀4 小时前
神经网络(系统性学习三):多层感知机(MLP)
神经网络·学习笔记·mlp·多层感知机
ZHOU_WUYI6 小时前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt
如若1236 小时前
主要用于图像的颜色提取、替换以及区域修改
人工智能·opencv·计算机视觉
老艾的AI世界6 小时前
AI翻唱神器,一键用你喜欢的歌手翻唱他人的曲目(附下载链接)
人工智能·深度学习·神经网络·机器学习·ai·ai翻唱·ai唱歌·ai歌曲
DK221516 小时前
机器学习系列----关联分析
人工智能·机器学习
Robot2516 小时前
Figure 02迎重大升级!!人形机器人独角兽[Figure AI]商业化加速
人工智能·机器人·微信公众平台
FreedomLeo17 小时前
Python数据分析NumPy和pandas(四十、Python 中的建模库statsmodels 和 scikit-learn)
python·机器学习·数据分析·scikit-learn·statsmodels·numpy和pandas
浊酒南街7 小时前
Statsmodels之OLS回归
人工智能·数据挖掘·回归