小样本学习(2)--LibFewShot使用

目录

一、LibFewShot安装

1、LibFewShot代码仓库

2、配置环境

3、测试安装是否正确

二、LibFewShot结构

1、config文件夹

2、core文件夹

3、reproduce文件夹

4、results文件夹

三、如何训练自己的数据集

1、调用主配置文件

2、修改主配置文件


一、LibFewShot安装

1、LibFewShot代码仓库

复制代码
cd ~
git clone https://github.com/RL-VIG/LibFewShot.git

2、配置环境

(1)创建anaconda环境

复制代码
cd <path-to-LibFewShot> # 进入clone好的LibFewShot目录
conda create -n libfewshot python=3.7
conda activate libfewshot

(2) 安装pytorch和torchvision

https://pytorch.org/get-started/locally/

(3)pip安装依赖包

复制代码
cd <path-to-LibFewShot> # cd 进入`LibFewShot` 目录
pip install -r requirements.txt

安装包依赖如下:

复制代码
numpy >= 1.19.5
pandas >= 1.1.5
Pillow >= 8.1.2
PyYAML >= 5.4.1
scikit-learn >= 0.24.1
scipy >= 1.5.4
tensorboard >= 2.4.1
torch >= 1.5.0
torchvision >= 0.6.0
python >= 3.6.0

3、测试安装是否正确

(1)修改run_trainer.py中config设置一行为

复制代码
config = Config("./config/test_install.yaml").get_config_dict()

(2) 修改config/headers/data.yaml中的data_root为当前数据集路径,博主使用阿里云GPU,数据集在根目录下,根目录为../

(3)执行run_trainer.py

复制代码
python run_trainer.py 

(4)若可以训练成功,跑通1个epoch则安装正确。

二、LibFewShot结构

1、config文件夹

config文件夹,主要为LibFewShot内置的模型的初始化模型文件yaml ,及模型中的backbone,classifier和header文件。

下属若干文件的yaml中,首先调用yaml文件中罗列的参数,如果没有再去访问includes中包括的data.yaml,device.yaml等文件。

下属classfiers文件夹中,其中基于微调的方法,如SKD,RFS是需要添加预训练的emb_func和cls_classfier路径的,这一部分路径来源于reproduce文件夹的md文件中有一定说明。

下属headers文件夹中包含五个yaml文件,分别是数据集,硬件设备,保存模型与生成日志,模型预训练、支持集参数、batchsize等,优化器和学习方式。

2、core文件夹

core文件夹中为核心模块,实现了模型架构,损失函数和优化器的内部结构。另外有train和test训练所调用的内部类架构。

简而言之,core文件实现基本函数,和基本的类,包括损失函数,神经网络,数据集的构建,而config文件夹作为core文件中若干函数,类的参数。

3、reproduce文件夹

值得一提的是,reproduce文件夹下的readme.md,这个文件介绍了不同的神经网络在预训练模型上的训练分数,对比了5-way 1-shot和5-way 5-shot,miniimagenet和tieredimagenet,conv64、resnet12和resnet18在不同网络的分数。

其中一般来说,在tieredimagenet训练集上训练的分数高,resnet比conv网络显著提升,resnet12在有些情况下甚至高于resnet18,所以要注意看一下对比实验的训练效果。

下面给出readme.md的一部分参考。

另外,在微调模型上,再加上resnet网络模型,可能会导致显存爆炸,所以需要降低batchsize。

reproduce中存在若干文件夹,这些是当时训练预训练模型时的参数,可以进行参考,但是不能照搬照抄,甚至你改了若干路径之后,也是存在一些无法修改的问题,暂时没有查出来问题。

4、results文件夹

results文件夹,显而易见,就是在模型训练之后保存模型checkpoints和日志log的地方。

三、如何训练自己的数据集

1、调用主配置文件

参考上面第一条中测试安装是否正确这一点,我们将修改run_trainer中添加config的这一行,可以先使用config文件夹下的一个网络的初始化的config.yaml调用。如博主调用skd.yaml。

2、修改主配置文件

(1)如果对于非微调方法的网络,是没有cls和emb路径的,所以不用考虑,对于微调方法来说,如果基于tiercedimage数据集的,在reproduce文件夹下的readme中会介绍这两者的预训练模型,而miniimage数据集没有处理这两者的预训练模型(所以优先考虑使用tiercedimage数据集的预训练模型)。

(2)修改神经网络预训练模型为指定路径,这个要么在config/model.yaml,要么文件里已经写出可以直接修改,(要同时修改还是修改一个,记住yaml里罗列的参数优先,如果找不到该参数才会找includes中的yaml文件)

(3)修改config/data.yaml文件夹中的数据集路径为所训练数据集路径

(4)观察主配置文件夹中的backbone,classfier是否对应预训练模型的要求,若不满足则修改

(5)观察config/model.yaml文件夹中的way_num、shot_num、query_num是否满足条件,前两者就是K way-C shot的K和C,query_num是指每次运用支持集时用了多少张测试图片来评判,test_way、test_shot、test_query一般来说跟上面相同即可

相关推荐
GIS数据转换器几秒前
城市生命线安全保障:技术应用与策略创新
大数据·人工智能·安全·3d·智慧城市
一水鉴天1 小时前
为AI聊天工具添加一个知识系统 之65 详细设计 之6 变形机器人及伺服跟随
人工智能
m0_743106464 小时前
【论文笔记】MV-DUSt3R+:两秒重建一个3D场景
论文阅读·深度学习·计算机视觉·3d·几何学
m0_743106464 小时前
【论文笔记】TranSplat:深度refine的camera-required可泛化稀疏方法
论文阅读·深度学习·计算机视觉·3d·几何学
井底哇哇7 小时前
ChatGPT是强人工智能吗?
人工智能·chatgpt
Coovally AI模型快速验证7 小时前
MMYOLO:打破单一模式限制,多模态目标检测的革命性突破!
人工智能·算法·yolo·目标检测·机器学习·计算机视觉·目标跟踪
AI浩8 小时前
【面试总结】FFN(前馈神经网络)在Transformer模型中先升维再降维的原因
人工智能·深度学习·计算机视觉·transformer
可为测控8 小时前
图像处理基础(4):高斯滤波器详解
人工智能·算法·计算机视觉
一水鉴天8 小时前
为AI聊天工具添加一个知识系统 之63 详细设计 之4:AI操作系统 之2 智能合约
开发语言·人工智能·python
倔强的石头1069 小时前
解锁辅助驾驶新境界:基于昇腾 AI 异构计算架构 CANN 的应用探秘
人工智能·架构