pytorch训练模板

来源:http://worthpen.top/#/home/blog?blog=pot-blog36.md

引言

本项目实现了基于PyTorch Lightning的神经网络训练和测试管道。项目除了实现PyTorch Lightning的工作流外,还实现了通过任务池在训练过程中添加任务、k折交叉验证、将训练结果保存在.cvs中、接受随机种子进行恢复训练、将模型转换为.onnx和.tflite。

项目地址: https://github.com/shenyan233/machine_learning_template

使用方法

环境配置

python version:3.7-3.10

bash 复制代码
pip install -r requirements.txt

cuda and torch need to be installed by itself. Recommendation:

bash 复制代码
pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113

配置网络架构和数据集

整个项目目录结构如下:

bash 复制代码
.
├── dataset
│   └── {dataset_name}
│        ├── test
│        │   ├── image
│        │   │   └── *.png
│        │   └── label.txt
│        ├── train
│        │   ├── image
│        │   │   └── *.png
│        │   └── label.txt
│        └── ...
├── network
│   └── {network_name}
│         ├── network.py
│         └── ...
└── ...

上述表示的文件或文件夹需要预先配置。省略号所代表的文件或文件夹保持默认即可。

数据集需要调整为自己的数据集,图像(*.png)名称为对应label.txt中的行号。您可以在此处自由调整数据集的保存格式,但Dataloder和其他类需要在'/dataset/{dataset name}/init.py'中重写。

network.py包含要训练的网络架构, 可以改为其他名称, 但是需要同步修改__init__.py。

任务流的配置参数保存在tasks.json中。

设置训练参数

在./network/{network_name}/config.json中设置参数,参数包括model_name、dataset_path、stage、max_epoch、batch_size等。训练参数包括可选参数和必选参数, 具体内容可浏览main.py内的注释。参数stage为'fit'或'test',分别表示处于训练阶段或测试阶段。

开始训练或测试

在终端或cmd内执行:

python3 main.py

相关推荐
信创DevOps先锋几秒前
开源中国全栈式AI教育解决方案亮相 破解高校科研与人才培养双重痛点
人工智能·开源
QQ676580084 分钟前
城市治理之河道污染识别 无人机河道污染巡检 塑料带识别 瓶子图像识别 深度学习垃圾识别第10384期
人工智能·深度学习·yolo·河道污染·无人机河道污染·瓶子图像·塑料袋识别
风象南4 分钟前
当技术解决了一切“怎么做”,人类还剩下什么?
人工智能
skilllite作者9 分钟前
SkillLite 多入口架构实战:CLI / Python SDK / MCP / Desktop / Swarm 一页理清
开发语言·人工智能·python·安全·架构·rust·agentskills
2501_9333295510 分钟前
技术深度剖析:Infoseek 字节探索舆情处置系统的全链路架构与核心实现
大数据·数据仓库·人工智能·自然语言处理·架构
网安情报局11 分钟前
RSAC 2026深度解析:AI对抗AI成主流,九大安全能力全面升级
人工智能·网络安全
key_3_feng11 分钟前
揭秘AI的“语言积木“:Token科普之旅
人工智能·搜索引擎·token
代码丰12 分钟前
Zero Code Studio:LangChain4j 工具调用 + LangGraph4j 工作流双模式的 AI 网站生成系统
java·人工智能
人工智能培训13 分钟前
多模态AI模型融合难?核心问题与解决思路
人工智能·机器学习·prompt·agent·智能体
FAFU_kyp13 分钟前
AP2 (Agent Payments Protocol) 技术流程详细解析
人工智能