参数的读取

argparse函数的读取

这个是以函数的形式嵌入到脚本中的

python 复制代码
def common_args():
    parser = argparse.ArgumentParser(description='common config')
    parser.add_argument('--test', action='store_true', help="test mode (load model and test dataset)")
    parser.add_argument('--iters', type=int, default=200000, help="training iters")
    parser.add_argument('--lr', type=float, default=1e-2, help="initial learning rate")
    parser.add_argument('--lr_net', type=float, default=1e-3, help="initial learning rate")
    parser.add_argument('--ckpt', type=str, default='latest')
    args = parser.parse_args()
    return args

py文件的读取

python 复制代码
import os
from pathlib import Path
from easydict import EasyDict as edict

FILE_PATH = Path(__file__).resolve()
ROOT_DIR = FILE_PATH.parents[1]

proj_conf = edict()

# 基本路径的设置
proj_conf.path = edict()
proj_conf.path.root_dir = str(ROOT_DIR)

# 其他参数的设置,比如网络模型dim
proj_conf.model = edict()
proj_conf.model.hidden_dim = 512

yaml文件的读取

python 复制代码
# coding:utf-8
import yaml
import os

# 获取当前脚本所在文件夹路径
curPath = os.path.dirname(os.path.realpath(__file__))
# 获取yaml文件路径
yamlPath = os.path.join(curPath, "cfgyaml.yaml")

# open方法打开直接读出来
f = open(yamlPath, 'r', encoding='utf-8')
cfg = f.read()
print(type(cfg))  # 读出来是字符串
print(cfg)

d = yaml.load(cfg)  # 用load方法转字典
print(d)
print(type(d))
# dict

@dataclass装饰器读取

python 复制代码
import json
import numpy as np
from dataclasses import dataclass
from typing import Optional, Tuple

@dataclass
class ModelArgs:
    channel: int = 128
    input_shape: tuple = (32, 32)
    schedule: str = "linear"
    num_timesteps: int = 1000
    schedule_low: float = 1e-4
    schedule_high: float = 0.02
    norm_eps: float = 1e-5
    cuda: bool = True
    max_batch_size: int = 32
    max_seq_len: int = 2048

    ffn_dim_multiplier: Optional[float] = None  # python 3.10 可以这么写: ffn_dim_multiplier: int | None = None

# 用法如下: 创建的时候传入就可以了,然后在主函数里面进行定义
class Diffusion:
    def __init__(self, args: ModelArgs):
        super(Diffusion, self).__init__()
        self.model_args = args
        
if __name__ == "__main__":
    with open("params.json", "r") as f:
        params = json.loads(f.read())

    max_seq_len = 2048
    max_batch_size = 16
    model_args: ModelArgs = ModelArgs(
        max_seq_len=max_seq_len,
        max_batch_size=max_batch_size,
        **params,
    )
相关推荐
思则变1 小时前
[Pytest] [Part 2]增加 log功能
开发语言·python·pytest
漫谈网络2 小时前
WebSocket 在前后端的完整使用流程
javascript·python·websocket
try2find3 小时前
安装llama-cpp-python踩坑记
开发语言·python·llama
博观而约取4 小时前
Django ORM 1. 创建模型(Model)
数据库·python·django
精灵vector5 小时前
构建专家级SQL Agent交互
python·aigc·ai编程
Zonda要好好学习6 小时前
Python入门Day2
开发语言·python
Vertira6 小时前
pdf 合并 python实现(已解决)
前端·python·pdf
太凉6 小时前
Python之 sorted() 函数的基本语法
python
项目題供诗6 小时前
黑马python(二十四)
开发语言·python
晓13137 小时前
OpenCV篇——项目(二)OCR文档扫描
人工智能·python·opencv·pycharm·ocr