参数的读取

argparse函数的读取

这个是以函数的形式嵌入到脚本中的

python 复制代码
def common_args():
    parser = argparse.ArgumentParser(description='common config')
    parser.add_argument('--test', action='store_true', help="test mode (load model and test dataset)")
    parser.add_argument('--iters', type=int, default=200000, help="training iters")
    parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate")
    parser.add_argument('--lr_net', type=float, default=1e-3, help="initial learning rate")
    parser.add_argument('--ckpt', type=str, default='latest')
    args = parser.parse_args()
    return args

py文件的读取

python 复制代码
import os
from pathlib import Path
from easydict import EasyDict as edict

FILE_PATH = Path(__file__).resolve()
ROOT_DIR = FILE_PATH.parents[1]

proj_conf = edict()

# 基本路径的设置
proj_conf.path = edict()
proj_conf.path.root_dir = str(ROOT_DIR)

# 其他参数的设置,比如网络模型dim
proj_conf.model = edict()
proj_conf.model.hidden_dim = 512

yaml文件的读取

python 复制代码
# coding:utf-8
import yaml
import os

# 获取当前脚本所在文件夹路径
curPath = os.path.dirname(os.path.realpath(__file__))
# 获取yaml文件路径
yamlPath = os.path.join(curPath, "cfgyaml.yaml")

# open方法打开直接读出来
f = open(yamlPath, 'r', encoding='utf-8')
cfg = f.read()
print(type(cfg))  # 读出来是字符串
print(cfg)

d = yaml.load(cfg)  # 用load方法转字典
print(d)
print(type(d))
# dict

@dataclass装饰器读取

python 复制代码
import json
import numpy as np
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class ModelArgs:
    channel: int = 128
    input_shape: tuple = (32, 32)
    schedule: str = "linear"
    num_timesteps: int = 1000
    schedule_low: float = 1e-4
    schedule_high: float = 0.02
    norm_eps: float = 1e-5
    cuda: bool = True
    max_batch_size: int = 32
    max_seq_len: int = 2048

    ffn_dim_multiplier: Optional[float] = None  # python 3.10 可以这么写: ffn_dim_multiplier: int | None = None

# 用法如下: 创建的时候传入就可以了,然后在主函数里面进行定义
class Diffusion:
    def __init__(self, args: ModelArgs):
        super(Diffusion, self).__init__()
        self.model_args = args
        
if __name__ == "__main__":
    with open("params.json", "r") as f:
        params = json.loads(f.read())

    max_seq_len = 2048
    max_batch_size = 16
    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        **params,
    )
相关推荐
qq_330037995 分钟前
如何配置ASM元数据备份_md_backup与md_restore重建磁盘组结构
jvm·数据库·python
昭昭日月明9 分钟前
前端仔速通 Python
javascript·python
a95114164229 分钟前
SQL触发器实现自动生成流水号_配合序列对象实现递增逻辑
jvm·数据库·python
哦哦~92134 分钟前
FDTD 与 Python 联合仿真的超表面智能设计技术与应用
python·fdtd·超表面
财经资讯数据_灵砚智能34 分钟前
基于全球经济类多源新闻的NLP情感分析与数据可视化(夜间-次晨)2026年4月21日
人工智能·python·信息可视化·自然语言处理·ai编程
解救女汉子40 分钟前
mysql如何配置元数据锁超时_mysql lock_wait_timeout设置
jvm·数据库·python
21439651 小时前
SQL注入防御技术方案_基于正则表达式的输入清洗
jvm·数据库·python
2401_832365521 小时前
SQL窗口函数与递归查询的区别_如何根据场景选择
jvm·数据库·python
u0109147601 小时前
c++如何处理文件路径中由于不规范的连续斜杠导致的路径解析错误【避坑】
jvm·数据库·python
2301_796588501 小时前
PHP源码开发用二手硬件划算吗_性价比与稳定性权衡【操作】
jvm·数据库·python