【深度学习实战(1)】如何使用argparse模块设置自己的训练参数

一、argparse模块用法

1、argparse是一个python模块,用途是:命令行选项、参数和子命令的解释。

2、argparse库下载:pip install argparse

3、使用步骤:

导入argparse模块,并创建解释器

添加所需参数

解析参数

二、代码

cpp 复制代码
import argparse


def add_common_arguments(parser):
    """Add common arguments for training and inference."""
    parser.add_argument('--save_best_weights',
                        default='model_data/best.pth',
                        help="save best weights name.")
    parser.add_argument('--phi', type=str, default='s')
    parser.add_argument('--num_classes', type=int, default=10)

def get_parser_for_training():
    """Return argument parser for training."""
    # -------------------------------------------#
    #   Step 1. 构造解析器 argparse.ArgumentParser()
    # -------------------------------------------#
    parser = argparse.ArgumentParser("Training args")
    # -------------------------------------------#
    #   Step 2. 添加参数 .add_argument()
    # -------------------------------------------#
    parser.add_argument('--train_path',default='/data/train',help="The location of dataset.")
    parser.add_argument('--sync_bn', type=bool,default=False,help='use SyncBatchNorm, only available in DDP mode')
    parser.add_argument('--Cuda', type=bool,default=True)
    parser.add_argument('--fp16', type=bool,default=False)
    parser.add_argument('--num_workers', type=int, default=8,help="Number of workers for data loading.")
    parser.add_argument('--Total_epoch', type=int, default=300,help='Total Epoch')
    parser.add_argument('--Batch_size', type=int, default=64,help='Batch_size')
    # -------------------------------------------#
    #   Step 2. 添加参数 .add_argument()
    # -------------------------------------------#
    add_common_arguments(parser)
    return parser


if __name__=='__main__':
    # -------------------------------------------#
    #   Step 3. 解析参数 .parse_args()
    # -------------------------------------------#
    train_parser = get_parser_for_training()
    train_args = train_parser.parse_args()
    print(train_args)
    # -------------------------------------------#
    #   training args
    # -------------------------------------------#
    print("training data path:",train_args.train_path)
    print("training batch size:",train_args.Batch_size)
    print("Cuda:",train_args.Cuda)
    # -------------------------------------------#
    #   common args
    # -------------------------------------------#
    print("num classes:",train_args.num_classes)
    print("phi:",train_args.phi)
    print("save model path:",train_args.save_best_weights)

运行结果

用命令行查看parser的所有参数选项

用命令行修改parser的特定参数

相关推荐
2的n次方_44 分钟前
CANN ascend-transformer-boost 架构解析:融合注意力算子管线、长序列分块策略与图引擎协同机制
深度学习·架构·transformer
人工智能培训1 小时前
具身智能视觉、触觉、力觉、听觉等信息如何实时对齐与融合?
人工智能·深度学习·大模型·transformer·企业数字化转型·具身智能
肖永威1 小时前
macOS环境安装/卸载python实践笔记
笔记·python·macos
TechWJ1 小时前
PyPTO编程范式深度解读:让NPU开发像写Python一样简单
开发语言·python·cann·pypto
枷锁—sha1 小时前
【SRC】SQL注入WAF 绕过应对策略(二)
网络·数据库·python·sql·安全·网络安全
abluckyboy2 小时前
Java 实现求 n 的 n^n 次方的最后一位数字
java·python·算法
喵手2 小时前
Python爬虫实战:构建各地统计局数据发布板块的自动化索引爬虫(附CSV导出 + SQLite持久化存储)!
爬虫·python·爬虫实战·零基础python爬虫教学·采集数据csv导出·采集各地统计局数据发布数据·统计局数据采集
pp起床2 小时前
Gen_AI 补充内容 Logit Lens 和 Patchscopes
人工智能·深度学习·机器学习
天天爱吃肉82183 小时前
跟着创意天才周杰伦学新能源汽车研发测试!3年从工程师到领域专家的成长秘籍!
数据库·python·算法·分类·汽车
m0_715575343 小时前
使用PyTorch构建你的第一个神经网络
jvm·数据库·python