pytorch训练模板

来源:http://worthpen.top/#/home/blog?blog=pot-blog36.md

引言

本项目实现了基于PyTorch Lightning的神经网络训练和测试管道。项目除了实现PyTorch Lightning的工作流外,还实现了通过任务池在训练过程中添加任务、k折交叉验证、将训练结果保存在.cvs中、接受随机种子进行恢复训练、将模型转换为.onnx和.tflite。

项目地址: https://github.com/shenyan233/machine_learning_template

使用方法

环境配置

python version:3.7-3.10

bash 复制代码
pip install -r requirements.txt

cuda and torch need to be installed by itself. Recommendation:

bash 复制代码
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113

配置网络架构和数据集

整个项目目录结构如下:

bash 复制代码
.
├── dataset
│   └── {dataset_name}
│        ├── test
│        │   ├── image
│        │   │   └── *.png
│        │   └── label.txt
│        ├── train
│        │   ├── image
│        │   │   └── *.png
│        │   └── label.txt
│        └── ...
├── network
│   └── {network_name}
│         ├── network.py
│         └── ...
└── ...

上述表示的文件或文件夹需要预先配置。省略号所代表的文件或文件夹保持默认即可。

数据集需要调整为自己的数据集,图像(*.png)名称为对应label.txt中的行号。您可以在此处自由调整数据集的保存格式,但Dataloder和其他类需要在'/dataset/{dataset name}/init.py'中重写。

network.py包含要训练的网络架构, 可以改为其他名称, 但是需要同步修改__init__.py。

任务流的配置参数保存在tasks.json中。

设置训练参数

在./network/{network_name}/config.json中设置参数,参数包括model_name、dataset_path、stage、max_epoch、batch_size等。训练参数包括可选参数和必选参数, 具体内容可浏览main.py内的注释。参数stage为'fit'或'test',分别表示处于训练阶段或测试阶段。

开始训练或测试

在终端或cmd内执行:

python3 main.py

相关推荐
心本无晴.1 分钟前
RAG中的混合检索(Hybrid Search):稀疏检索与稠密检索的强强联合
人工智能·python·算法
橙露3 分钟前
Python 办公自动化:批量处理 Excel/Word/PPT 实战教程
python·word·excel
咚咚王者3 分钟前
人工智能之视觉领域 计算机视觉 第十三章 视频背景减除
人工智能·计算机视觉·音视频
你的论文学长3 分钟前
对抗知网的 N-Gram 算法:基于语义解耦的【文本重构】与【事实性核验】架构设计
人工智能·算法·重构
Java_慈祥6 分钟前
My First AI智能体!!!
python·agent·coze
一水鉴天6 分钟前
关于“整体设计定稿” 的高阶表述 20260222
人工智能·架构
美酒没故事°7 分钟前
mac电脑安装OpenClaw步骤
人工智能·macos
用户3521802454758 分钟前
RAG 做不好?可能是你的 PDF 在"捣乱" 😅
后端·python·ai编程
FL16238631298 分钟前
基于yolov11+django+deepseek的脑肿瘤检测系统带登录界面python源码+onnx模型+精美web界面
python·yolo·django
沪漂阿龙9 分钟前
大模型幻觉深度解析:成因、检测与工程缓解策略
人工智能·深度学习·机器学习