argparse函数的读取
这种方式把参数解析封装成一个函数, 直接嵌入到脚本中调用
python
def common_args(argv=None):
    """Parse the common command-line options and return them as a namespace.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, which is what the
            parameterless call ``common_args()`` has always done.

    Returns:
        argparse.Namespace with the attributes ``test``, ``iters``, ``lr``,
        ``lr_net`` and ``ckpt``.
    """
    # Imported locally because the snippet has no top-of-file argparse import.
    import argparse

    parser = argparse.ArgumentParser(description='common config')
    # Boolean flag: False unless --test is given on the command line.
    parser.add_argument('--test', action='store_true', help="test mode (load model and test dataset)")
    parser.add_argument('--iters', type=int, default=200000, help="training iters")
    parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate")
    parser.add_argument('--lr_net', type=float, default=1e-3, help="initial learning rate")
    parser.add_argument('--ckpt', type=str, default='latest')
    return parser.parse_args(argv)
py文件的读取
python
import os
from pathlib import Path
from easydict import EasyDict as edict
# Absolute path of this file, and the directory two levels above it.
FILE_PATH = Path(__file__).resolve()
ROOT_DIR = FILE_PATH.parents[1]

# Project configuration, assembled in a single expression; each section is
# its own EasyDict so values are reachable as proj_conf.<section>.<key>.
proj_conf = edict(
    # Base paths.
    path=edict(root_dir=str(ROOT_DIR)),
    # Other parameters, e.g. the network model dimension.
    model=edict(hidden_dim=512),
)
yaml文件的读取
python
# coding:utf-8
import yaml
import os
# Directory that contains this script.
curPath = os.path.dirname(os.path.realpath(__file__))
# Path of the YAML file, looked up next to this script.
yamlPath = os.path.join(curPath, "cfgyaml.yaml")
# Read the raw text inside a context manager so the handle is always closed.
with open(yamlPath, 'r', encoding='utf-8') as f:
    cfg = f.read()
print(type(cfg))  # the raw content is a str
print(cfg)
# safe_load only builds plain Python objects. Calling yaml.load without a
# Loader is unsafe on untrusted input and is rejected by PyYAML >= 6.
d = yaml.safe_load(cfg)
print(d)
print(type(d))
# prints <class 'dict'> when the top level of the document is a mapping
@dataclass装饰器读取
python
import json
import numpy as np
from dataclasses import dataclass
from typing import Optional, Tuple
@dataclass
class ModelArgs:
    """Container for model/run settings; every field has a default.

    Instances can be created with no arguments, or with keyword overrides
    for any subset of the fields (e.g. values unpacked from a dict).
    """

    # NOTE(review): the meaning of the individual fields is not documented in
    # this file; names and default values are kept exactly as they were.
    channel: int = 128
    input_shape: Tuple[int, int] = (32, 32)
    schedule: str = "linear"
    num_timesteps: int = 1000
    schedule_low: float = 1e-4
    schedule_high: float = 0.02
    norm_eps: float = 1e-5
    cuda: bool = True
    max_batch_size: int = 32
    max_seq_len: int = 2048
    # On Python 3.10+ this can be written as: ffn_dim_multiplier: float | None = None
    ffn_dim_multiplier: Optional[float] = None
# 用法如下: 先在主函数里构造 ModelArgs, 然后在创建模型对象时把它作为参数传入即可
class Diffusion:
    """Example consumer that keeps the configuration it is constructed with."""

    def __init__(self, args: ModelArgs):
        """Store ``args`` on the instance as ``self.model_args``.

        Args:
            args: The ModelArgs configuration object for this instance.
        """
        # Zero-argument super() is the Python 3 spelling of
        # super(Diffusion, self).
        super().__init__()
        self.model_args = args
if __name__ == "__main__":
    # Load extra field values from params.json in the current working
    # directory. JSON text is UTF-8, so do not rely on the locale encoding.
    with open("params.json", "r", encoding="utf-8") as f:
        params = json.load(f)
    max_seq_len = 2048
    max_batch_size = 16
    # NOTE: params.json must only contain ModelArgs field names and must not
    # define max_seq_len or max_batch_size; otherwise this call raises
    # TypeError (unexpected or duplicated keyword argument).
    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        **params,
    )