pytorch中项目配置文件的管理与导入方式

1.yaml文件

在 PyTorch 深度学习项目中,使用 YAML(Yet Another Markup Language)作为配置文件是非常主流的做法。相比 JSON 或 XML,YAML 的可读性更强,非常适合用来管理复杂的超参数(Hyperparameters)、模型结构参数和文件路径。

1.1 为什么是yaml?

在深度学习中,我们经常需要调整 batch_size, learning_rate, optimizer 等参数。

  • 如果不使用配置: 你需要反复修改代码中的变量,容易出错且难以版本控制。
  • 使用 YAML: 将代码(逻辑)与参数(配置)分离。修改参数只需改动 YAML 文件,无需触碰核心代码。

1.2 文件的编写语法

YAML 的核心规则是依靠缩进(Indentation)来表示层级关系。基本的语法概括如下:

  • 缩进: 必须使用空格,不能使用 Tab 键(通常是 2 个或 4 个空格)。
  • 键值对: key: value(冒号后面必须有一个空格)。
  • 注释: 使用 #

细致总结一下:

1.大小写敏感 :True 和 true 是不同的(YAML 对"布尔值的关键字"不区分大小写,但 YAML 对"字符串内容"是区分大小写的)。

2.缩进表示层级关系:

  • 只能使用空格(Space)绝对不能用 Tab 键 (这是 YAML 最常见的错误来源)。

  • 缩进空格数不固定(可以是 2 个或 4 个),但同一层级必须对齐,子层级必须比父层级多缩进。

  • 示例(你的文件用 2 个空格):

YAML 复制代码
paths:                  # 第 0 级
  data_dir: "./data/cifar10"  # 第 1 级(缩进 2 空格)
  log_dir: "./logs/experiment_1"
  #如果缩进不一致(如一个 2 空格、一个 4 空格),解析器会报错。

3.键值对:格式为 key: value(冒号后必须有一个空格)。如果没空格,如 key:value,会解析失败。

4.注释:用 # 开头,从 #到行尾都被忽略。可以放在行首、行尾或单独一行。示例:

YAML 复制代码
use_gpu: true # 布尔值(注释在行尾)
# 路径配置(单独一行注释)
paths:
 ...

5.文档分隔:一个文件中可以有多个 YAML 文档,用 --- 分隔。例如

文档分隔的作用:
逻辑上将一个文件拆分成多个独立的配置对象:每个 --- 之前的部分是一个完整的、独立的 YAML 文档(相当于一个独立的字典、配置或数据结构)。

允许在同一个文件中存储多个相关或不相关的配置,而不需要拆分成多个物理文件。

方便某些工具一次性处理多个配置,比如批量导入、流水线处理等。

yaml 复制代码
#YAML 文件的标准规范允许一个物理文件中包含多个独立的 YAML 文档(相当于多个独立的配置对象),它们之间用 ---(三个连字符)来分隔。
# 第一个 YAML 文档
name: Alice
age: 30
hobbies:
 - reading
 - hiking

---   #用三个连字符或者三个点···来显示结束一个文档,通常不需要。如果在yaml文件中如果有和yaml没关系的内容,必须有结束符号。
# 第二个 YAML 文档
name: Bob
age: 25
hobbies:
 - gaming
 - cooking

---
# 第三个 YAML 文档
server:
 host: localhost
 port: 8080

6.数据类型详解--YAML 支持三种基本结构:

  • 标量(Scalars):单个值(如字符串、数字、布尔)。
  • 映射(Mappings):键值对集合(相当于字典/dict)。
  • 序列(Sequences):有序列表(相当于数组/list)。

(1) 字符串(String)最常见类型。**YAML 里的"字符串",本质就是:一段文字。不同写法,只是"YAML 怎么把这段文字当成什么样子来理解"。**可以不加引号(plain style):如果不含特殊字符(如 : { } [ ] , #),推荐不加引号,更简洁。

  • 示例:data_dir: "./data/cifar10"(路径通常加引号,避免解析问题 )。路径里可能有特殊字符,YAML 解析器容易误解

  • 单引号 '...':内容原样输出,双单引号 '' 表示单个 '。

单引号 '...'(原样保存)

python 复制代码
msg: 'hello\nworld' #不是换行实际上结果是 "hello\\nworld"------Python 用 \\ 来表示"字符串里有一个反斜杠"
# 单引号 `'...'`(原样保存) 不做任何的转义。

当连续出现两个单引号的时候 '' 表示单个单引号;

python 复制代码
msg: 'it''s good'

#等价于
"it's good"
  • 双引号 "...":支持转义(如 \n 换行、\t Tab)。支持转义符(和 Python 字符串一样)
python 复制代码
msg: "hello\nworld"

#双引号支持转义所以结果是
"hello
world"
  • 多行字符串:

  • | :保留换行(literal block)。| = "我写几行,你就给我几行"

    python 复制代码
    description: |
     this is line one
     this is line two
     this is line three
    
    
    #在python里边
    "this is line one\nthis is line two\nthis is line three\n"
  • >:折叠换行成空格(folded block)。> = "写的时候换行,读的时候当一行"

    python 复制代码
    description: >
     this is line one
     this is line two
     this is line three
    
    #实际上
    "this is line one this is line two this is line three\n"

(2)数字在深度学习 YAML 里,大多数只需要会这三种数字写法:

python 复制代码
epochs: 100          # 整数
lr: 0.001            # 小数
weight_decay: 1.0e-4 # 科学计数法----1e-4 = 1 × 10⁻⁴ = 0.0001


#其他格式:十六进制 0xFF、八进制 0o777。

(3)布尔值标准写法:true / false(小写推荐)。YAML 也支持变体:True、TRUE、Yes、No、On、Off(不区分大小写)。但是注意,不能加引号,不然会变成字符串

(4) Null(空值): 用 ~ 或 null 表示。

python 复制代码
#例如
optional: ~

(5)映射(字典/Dict):YAML 的"映射(Mapping)"= Python 的"字典(dict)", 本质就是:键 → 值 的对应关系

python 复制代码
#缩进只能用空格,不能用 Tab
# 对
train:
 batch_size: 64

# 错(Tab)
train:
↹batch_size: 64  #不能用tab键


#同一层级,缩进必须对齐
# 对
train:
 batch_size: 64
 epochs: 100

# 错
train:
 batch_size: 64
   epochs: 100

   
   
#必须唯一(同一层里)
train:
 batch_size: 64
 batch_size: 128   # 覆盖 / 非法

   
#冒号分左右,缩进分里外,对齐是同级,一切都是键值对
#YAML 的映射不是"复杂",而是"把 Python dict 写得更好看"

(6)序列:序列 = 一堆有顺序的元素,类似于python里边的list。

python 复制代码
#block风格
transform_list:  #transform_list: → 一个键
 - "RandomCrop"   # - → 一个列表元素,每个 - 表示一项
 - "RandomHorizontalFlip"
 - "Normalize"
#看到 -,就要想到"列表的一项" ,'-'后边一定要有空格

#flow风格,行内写法
transform_list: ["RandomCrop", "RandomHorizontalFlip", "Normalize"]
#一般在:列表很短,不嵌套,不需要注释。时使用

YAML = 映射(dict) + 序列(list) + 标量(string / number / bool)

yaml 复制代码
# config.yaml

project_name: "ResNet_Classification"
use_gpu: true  # 布尔值

# 路径配置
paths:
  data_dir: "./data/cifar10"
  log_dir: "./logs/experiment_1"

# 模型参数
model:
  type: "resnet18"
  num_classes: 10
  pretrained: true

# 训练超参数
train:
  batch_size: 64
  epochs: 100
  learning_rate: 0.001
  weight_decay: 1.0e-4  # 支持科学计数法
  optimizer: "Adam"
  
# 列表/数组写法
transform_list:
  - "RandomCrop"
  - "RandomHorizontalFlip"
  - "Normalize"

1.3 yaml文件的使用

1。使用 yaml

python 复制代码
import yaml

# 读取函数
def get_config(path):
    with open(path, 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)     #把yaml文件内容转换为python字典。safe_load 推荐用,不会执行 YAML 文件里潜在的危险命令。

  
cfg = get_config("config.yaml")   #输入路径

# 使用方式:像查字典一样
print(cfg['learning_rate'])  # 输 出结果
# 缺点:如果层级很深,代码会变成 config['train']['params']['lr'],很难看且容易写错字符串
  1. 封装为对象

在日常的项目中,我们不希望在代码里写满 ['key']。我们更习惯用 . 来访问属性,比如 config.lr

python 复制代码
#利用SimpleNamespace实现---SimpleNamespace 是 Python 标准库(types 模块)中的一个非常轻量的类。它的作用:允许你动态地给一个对象添加属性,并用点号访问这些属性。相当于一个"可随意扩展属性的空对象"。
import yaml   #yaml 是 import 的 PyYAML 库。
from types import SimpleNamespace

def load_config_as_obj(yaml_path):
    """
    读取 yaml 并将字典递归转换为对象,方便用 . 属性访问
    """
    with open(yaml_path, 'r', encoding='utf-8') as f:   #open打开文件,返回一个文件对象f
        config_dict = yaml.safe_load(f)   #加载为python字典

    # 递归转换函数
    def dict_to_obj(d):
        if not isinstance(d, dict):  
            '''
             #isinstance(object, class_or_tuple):--判断这个对象是不是某种类型
             object:你要检查的变量.
             class_or_tuple:你想检查的类型(或者类型元组)
             返回值:布尔值 True / False
            '''
            return d
        # 将字典转为 SimpleNamespace 对象
        obj = SimpleNamespace()   #创建一个空的SimpleNamespace对象。调用 types.SimpleNamespace 类,创建一个空的、可动态加属性的对象。此时 obj 里面什么属性都没有。
        for k, v in d.items():  #d.items返回的是一个元组,for循环可以多个变量,但是要求可迭代对象的每个元素是元组或列表,元素的长度必须和变量数一致
            # 递归处理嵌套的字典
            setattr(obj, k, dict_to_obj(v))    #这里递归调用dict_to_obj函数。如果不是字典,则返回d(也就是v)。如果是字典在进来再进行调用,直到不是字典未知。----给对象 obj 动态增加一个属性,名字是 k,值是 dict_to_obj(v)。
            #setattr(object, name, value)---把 name 当作属性名,把 value 赋值给对象
            '''
            object:要操作的对象
		   name:属性名(字符串)
            value:要赋给属性的值
            '''
        return obj  #可调用对象

    return dict_to_obj(config_dict)  #返回值

# --- 使用演示 ---
# 假设 yaml 内容是:
# train:
#   lr: 0.01
#   device: "cuda"

cfg = load_config_as_obj("config.yaml")  #给一个yaml文件路径

# 现在的调用方式非常优雅:
print(cfg.train.lr)      # 0.01
print(cfg.train.device)  # cuda

其次可以使用 argparse 读取命令行参数,如果有输入,就覆盖 yaml 里的默认值。

argparse 是 Python 内置模块,用来 解析命令行参数。"命令行参数" = 你运行脚本时输入的参数,比如:

python 复制代码
python train.py --lr 0.001 --epochs 50

argparse 可以把这些字符串参数 转换成 Python 对象 ,方便在代码中使用使用 argparse 通常有三个步骤:

  1. 创建解析器
python 复制代码
parser = argparse.ArgumentParser()
  • ArgumentParser() 创建一个解析器对象。这个解析器负责定义你想接受哪些参数,以及解析命令行输入
  1. 添加参数定义

add_argumentargparse 模块里 ArgumentParser 对象的方法 ,作用是:告诉解析器你的程序可以接收哪些命令行参数,以及这些参数的类型、默认值和说明。

python 复制代码
parser.add_argument('--lr', type=float, default=None, help='学习率')
  • --lr → 命令行参数名
  • type=float → 解析后转换为浮点数
  • default=None → 如果命令行没提供,默认值是 None
  • help='学习率' → 提示信息(python train.py --help 会显示)

你可以添加多个参数:

python 复制代码
parser.add_argument('--epochs', type=int, default=None, help='训练轮数')
parser.add_argument('--config', type=str, default='./configs/resnet_train.yaml', help='配置文件路径')
  1. 解析命令行输入
python 复制代码
args = parser.parse_args() #`parse_args()` 会读取运行脚本时的命令行参数,返回一个对象 `args`,里面每个参数都是 **对象属性**

例如运行:

python 复制代码
python train.py --lr 0.001 --epochs 50

得到:

python 复制代码
args.lr      # 0.001
args.epochs  # 50
args.config  # './configs/resnet_train.yaml'

#如果命令行不输入某个参数,它就用你定义的 `default` 值。
python 复制代码
import argparse  #Python 内置模块,用来解析命令行参数(命令行参数也就是python运行脚本的时候输入的参数:python train.py --lr 0.001 --epochs 50)。argparse 可以把这些字符串参数转换成 Python 对象,方便在代码中使用

def get_args_and_config():   #读取 YAML 配置 + 解析命令行参数 + 覆盖默认值。返回最终的 cfg 对象,用于训练脚本中直接访问参数
    parser = argparse.ArgumentParser()  #ArgumentParser() 创建一个解析器对象,知道你程序允许哪些命令行参数,并解析这些参数
    parser.add_argument('--config', type=str, default='./configs/resnet_train.yaml', help='配置文件路径')  #拿到yaml文件的路径。
    parser.add_argument('--lr', type=float, default=None, help='临时修改学习率')
    parser.add_argument('--epochs', type=int, default=None, help='临时修改轮数')  #help是提示信息用于--help的时候显示
    args = parser.parse_args()   #当运行python train.py --lr 0.001 --epochs 50之后,可以用args.lr调取这个值是多少。
    
    # 1. 先加载 yaml 为对象
    cfg = load_config_as_obj(args.config)   #还是之前的SimpleNamespace。变为一个对象,可以用 点 调用。
    
    # 2. 如果命令行有指定参数,覆盖 yaml 中的值
    if args.lr is not None:  #如果通过命令行传递进来参数了。
        cfg.training.lr = args.lr   #重新赋值,进而覆盖Yaml的默认值。-这里不会修改yaml文件,只是会修改内存里的配置对象cfg
        print(f"注意:学习率被命令行参数覆盖为 {cfg.training.lr}")
        
    if args.epochs is not None:
        cfg.training.epochs = args.epochs

    return cfg

# 在 main 中调用:
# cfg = get_args_and_config()

2.json文件

2.1 json文件编写语法

JSON 是目前互联网最通用的数据格式。具有语法严格,不能注释,兼容性较好的特点。

  • 语法严格 :键值对必须用双引号 ""
  • 无注释 :不能写 #//,这是它不适合做配置文件的最大原因。
  • 兼容性好:网页、后端、Python 都能直接读写。

在深度学习中可以存日志 & 存结果 因为 JSON 机器读取速度快且格式标准,我们通常用它来保存训练过程中的各项指标 (Loss, Accuracy),或者数据集的标注信息(如 COCO 数据集)。

json的格式要求更为严格。

  • 严谨的键值对:类似于 Python 的字典,但要求更严格。
  • 双引号 :所有的键(Key)和字符串值(Value)必须用双引号 "",不能用单引号。
  • 不支持注释 :这是它最大的特点(也是作为配置文件的缺点),你不能在文件里写 //#
  • 数据类型 :支持 字符串、数字、布尔值 (true/false)、列表 []、字典 {}
元素 写法要求 示例
键(key) 必须是字符串,必须用双引号包裹 "batch_size": 128
值(value) 可以是: • 字符串(双引号) • 数字 • 布尔值 • null • 对象 • 数组 "resnet18" 0.001 true null
字符串 必须用双引号(不能用单引号) "data_dir": "./data"
数字 直接写,不需要引号,支持小数和科学计数法 "lr": 0.001 "weight_decay": 1e-4
布尔值 只能写 true 或 false(小写!) "pretrained": true
空值 只能写 null(小写) "optional": null
数组 用 [ ],元素之间用逗号分隔 "transforms": ["RandomCrop", "Normalize"]
嵌套 对象里面可以套对象或数组 见下面的完整例子
逗号 每个键值对或数组元素后面(除最后一个)必须有逗号 "batch_size": 128,
注释 不支持任何形式的注释(这是和 YAML 最大的区别!) 不能写 // 或 # 开头的注释
json 复制代码
{
    "experiment_id": "exp_2024",
    "metrics": {
        "accuracy": 0.95,
        "loss": 0.045
    },
    "classes": ["cat", "dog", "car"],  
    "is_finished": true
}

2.2 json的用法-----类似于yaml

python 复制代码
import json
from types import SimpleNamespace  # 可选:用来转成点号访问对象

def load_json_config(path="config.json"):
    with open(path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)  # 注意:是 json.load(f),不是 json.loads()------返回对象也是一个字典
    
    return config_dict  #一个字典

# 使用
cfg = load_json_config("config.json")

# 字典方式访问
print(cfg["data"]["batch_size"])      # 128
print(cfg["optimizer"]["lr"])         # 0.001
python 复制代码
#转化为对象访问
def load_json_as_obj(path="config.json"):  #给一个默认值config.json,不传参数的时候就用默认值。
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    
    def dict_to_obj(d):
        if isinstance(d, dict):  #判断是不是字典类型
            return SimpleNamespace(**{k: dict_to_obj(v) for k, v in d.items()}) #**字典 的意思是:把字典"拆开"成关键字参数
        '''
       { key_expression : value_expression  for 变量 in 可迭代对象 } --- {k : dict_to_obj(v)  for k, v in d.items()}
       d 是一个字典(比如 {'lr': 0.001, 'device': 'cuda'})
      d.items() 返回所有键值对:[('lr', 0.001), ('device', 'cuda')]
	for k, v in d.items():依次取出键(k)和值(v)
	k : dict_to_obj(v):新字典的键还是原来的 k,但值要先经过 dict_to_obj(v) 处理(如果 v 是字典,就递归转成对象;如果不是,就原样返回)
       ---------------------------
        {k: dict_to_obj(v) for k, v in d.items()}-一个经典的字典推导式。它本身就等价于"先创建一个空字典,再用 for 循环往里塞数据"。
        等价于:
        new_dict = {}              # 1 先创建空字典
	   for k, v in d.items():     # 2 遍历原字典
          new_dict[k] = dict_to_obj(v)   # 3 赋值
        '''
        elif isinstance(d, list):  #json可以做嵌套,因此可能包含列表的情况。"classes": ["cat", "dog", "car"],  
            return [dict_to_obj(i) if isinstance(i, dict) else i for i in d]
        '''
        new_list = []
        for i in d:
            if isinstance(i, dict):
                new_list.append(dict_to_obj(i))
            else:
                new_list.append(i)	
        '''
        '''
        [表达式 for 变量 in 可迭代对象 if 条件]
        表达式 → 每次循环计算出的值,会成为新列表的元素
		变量 → 循环中取出的每个元素
		可迭代对象 → 任何可遍历的对象,如列表、字典的 keys、range() 等
		if 条件 → 可选,对循环元素做过滤
        '''
        else:
            return d  #普通值
    
    return dict_to_obj(data)

cfg = load_json_as_obj("config.json")

# 现在可以用点号访问了!
print(cfg.data.batch_size)      # 128
print(cfg.optimizer.lr)         # 0.001
print(cfg.model.name)           # resnet18
python 复制代码
#写入json文件
import json

config = {                                  #创建一个字典
    "project_name": "MyExperiment",
    "final_accuracy": 92.5,
    "best_epoch": 87
}

with open("result.json", "w", encoding="utf-8") as f:  #with open自动打卡文件,用w 模式,with打开不用手动 f.close。with当代码结束会自动调用f.close()
    json.dump(config, f, indent=4, ensure_ascii=False)
    # indent=4:美化输出,方便阅读
    # ensure_ascii=False:支持中文等非ASCII字符
    
    

dump函数讲解

python 复制代码
json.dump(obj, fp, *, skipkeys=False, ensure_ascii=True, check_circular=True,     
   allow_nan=True, cls=None, indent=None, separators=None, default=None, sort_keys=False)
参数 类型 默认值 说明 推荐用法
obj Python对象 必填 要写入文件的 Python 数据(通常是 dict、list、str、int、float、bool、None 等) 你的配置字典
fp 文件对象 必填 已打开的、可写的文件对象(通常用 open(..., 'w')) with open(...) as f
indent int 或 None None 缩进空格数。如果设置(如 2 或 4),生成的 JSON 会格式化(美化),方便阅读。每层嵌套增加的空格数,例如每一层嵌套增加 4 个空格 indent=4(强烈推荐)
ensure_ascii bool True 如果为 True,非 ASCII 字符(如中文)会转成 \uXXXX 转义。如果为 False,直接保留原字符 ensure_ascii=False(有中文时必设)
sort_keys bool False 是否对字典的键进行排序(按字母顺序) sort_keys=True(调试时方便对比)
separators tuple (', ', ': ') 控制项分隔符和键值分隔符,通常不用改 一般不改
default callable None 如果对象有无法序列化的类型(如 set、datetime),可以用这个函数自定义转换 高级用法

3.py文件---实现"代码即配置"

把配置从"纯数据"(data)升级成"可执行代码"(code) 。 简单说,就是直接用一个 Python 文件(通常叫 config.py、models_config.py 等)来定义所有配置和逻辑,而不是用 YAML/JSON 只存静态值。这在深度学习项目中非常常见,尤其是当配置需要包含复杂逻辑时(比如动态构建模型、条件判断、计算路径等)。

python 复制代码
# config.py - 所有配置和逻辑集中在这里

import torch
import torch.nn as nn
from torchvision import models, transforms

# ================== 数据配置 ==================
DATA_DIR = "./data"
BATCH_SIZE = 128
NUM_WORKERS = 4

# ================== 模型配置 ==================
MODEL_NAME = "resnet18"   # 改这里就能换模型
NUM_CLASSES = 10
PRETRAINED = True

def get_model():
    if MODEL_NAME == "resnet18":
        base = models.resnet18(pretrained=PRETRAINED)
    elif MODEL_NAME == "resnet50":
        base = models.resnet50(pretrained=PRETRAINED)
    elif MODEL_NAME == "mobilenet_v2":
        base = models.mobilenet_v2(pretrained=PRETRAINED)
    else:
        raise ValueError(f"Unknown model: {MODEL_NAME}")
    
    # 统一修改最后一层
    if hasattr(base, 'fc'):  # ResNet 系列
        base.fc = nn.Linear(base.fc.in_features, NUM_CLASSES)
    elif hasattr(base, 'classifier'):  # MobileNet
        base.classifier[1] = nn.Linear(base.classifier[1].in_features, NUM_CLASSES)
    
    return base

# ================== 训练配置 ==================
LR = 0.001
EPOCHS = 100
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ================== 数据增强 ==================
def get_transforms():
    return transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
python 复制代码
# train.py
from config import get_model, get_transforms, BATCH_SIZE, NUM_WORKERS, LR, EPOCHS, DEVICE, DATA_DIR

model = get_model().to(DEVICE)
transform = get_transforms()

# 数据加载、优化器、训练循环...

4.XML

XML(eXtensible Markup Language) 是一种 可扩展标记语言,用于存储和传输数据,类似 JSON/YAML。具有一下特点:

  1. 可读性强,层级清晰
  2. 支持嵌套和属性

但是比 JSON/YAML 冗长,而且使用起来有点复杂。

在目标检测(Object Detection)领域,尤其是经典的 Pascal VOC 数据集 (2007/2012)和很多自定义数据集,标注信息都是用 XML 文件 来存储的。 每一个图像对应一个 .xml 文件,里面记录了图像中所有目标的类别、边界框坐标(bounding box)等信息。

xml 复制代码
<?xml version="1.0" encoding="UTF-8"?>
<config>
    <project_name>MyExperiment</project_name>
    <model>
        <type>resnet18</type>
        <num_classes>10</num_classes>
        <pretrained>true</pretrained>
    </model>
    <training>
        <batch_size>64</batch_size>
        <epochs>100</epochs>
        <learning_rate>0.001</learning_rate>
        <optimizer>Adam</optimizer>
    </training>
</config>
  • <config> → 根节点(root element)
  • <model> / <training> → 子节点
  • <type>resnet18</type> → 标签 + 内容
  • XML 支持 嵌套层级,适合复杂配置

4.1读取

python 复制代码
#Python 内置库 xml.etree.ElementTree 可以解析 XML。解析、创建、操作 XML 文件
import xml.etree.ElementTree as ET

# 1. 读取 XML 文件
tree = ET.parse("config.xml")  # 返回 ElementTree 对象 --tree → XML 的整个树形结构
root = tree.getroot()          # 根节点 <config>----从 ElementTree 中获取 根节点

# 2. 访问数据
project_name = root.find("project_name").text  #root.find--查找 <config> 下的第一个 <project_name> 子节点,返回一个 Element 对象
										 #.text 获取该节点的文本内容 "MyExperiment"。
print(project_name)  # MyExperiment---project_name → 字符串类

model_type = root.find("model/type").text
num_classes = int(root.find("model/num_classes").text)
pretrained = root.find("model/pretrained").text == "true"  #== "true" → 转成布尔值

batch_size = int(root.find("training/batch_size").text)
learning_rate = float(root.find("training/learning_rate").text)

print(model_type, num_classes, pretrained, batch_size, learning_rate)

5.TOML

TOML(Tom's Obvious, Minimal Language)是一种现代、简洁、人性化的配置文件格式,由 GitHub 联合创始人 Tom Preston-Werner 创建。它的设计目标是尽可能明显、直观 ,比 JSON 可读性更强(支持注释),比 YAML 更简单(缩进不敏感)。在深度学习项目中,TOML 的最主流、最核心用法 不是在代码里读写超参数,而是用于项目依赖管理和构建配置------即 pyproject.toml 文件。

5.1 TOML的语法格式

元素类型 写法要求 示例 说明
键值对 key = value(等号两边有空格) batch_size = 128 最基本的配置方式
字符串 单引号 '...' 或双引号 "..." name = "resnet18" 可使用转义字符
数字 直接写,支持整数、浮点数、科学计数法 lr = 0.001weight_decay = 1e-4 默认是数字类型
布尔值 truefalse(小写) pretrained = true XML/JSON 没有布尔类型要特别注意
数组/列表 [elem1, elem2, ...] transforms = ["crop", "flip"] 支持不同类型混合元素
表(Table) [table_name] 或点号嵌套 table.subtable [model]model.name = "resnet18" 用于分组或嵌套配置
注释 # 开头 # 这是注释 注释不会被解析
多行字符串 三个引号 """...""" desc = """多行文本""" 支持换行
嵌套表 [table.subtable]table.subtable.key = value [data.train] 支持多层嵌套结构
日期/时间 ISO 8601 格式 start_date = 2025-12-24T22:00:00Z TOML 内置日期时间类型
toml 复制代码
# config.toml
project_name = "CIFAR10_Classification"
seed = 42

[data]
dataset = "CIFAR10"
data_dir = "./data"
batch_size = 128
num_workers = 4

[model]  #TOML 使用 表(table) 来表示 嵌套结构或命名空间,[model] 表示 一个名为 model 的表,表下面的键值对都属于这个表的 子空间
name = "resnet18"
pretrained = true
num_classes = 10

[optimizer]
name = "Adam"
lr = 0.001
weight_decay = 1e-4

[train]
epochs = 100
device = "cuda"

5.2 toml的用法

python 复制代码
#可以做项目依赖管理
'''
TOML 的 最常见用途不是存超参,而是 管理 Python 项目的依赖和构建配置
在现代 Python 项目中,它已经取代了:
requirements.txt(老式依赖列表)
setup.py(旧版打包配置)
存放位置:项目根目录
'''

#例如
[project]
name = "cifar10-resnet"
version = "0.1.0"
description = "CIFAR-10 分类实验"
authors = [{name = "张三", email = "zhangsan@example.com"}]
requires-python = ">=3.9"
dependencies = [
    "torch>=2.0.0",
    "torchvision>=0.15.0",
    "pyyaml>=6.0",
    "matplotlib>=3.5",
    "tqdm"
]

[project.optional-dependencies]
dev = ["black", "flake8", "pytest"]
train = ["wandb", "tensorboard"]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"


#可以这么用

# 安装项目依赖
pip install .

# 安装开发依赖
pip install -e ".[dev]"

# 用 poetry 管理
poetry install          # 自动安装所有依赖
poetry add torch==2.1.0 # 添加新依赖,自动更新 toml


'''
环境可复现:别人 clone 你的代码后,只需 pip install . 就能装好相同版本的包
版本锁定:精确控制 torch、torchvision 等版本
现代标准:pip、poetry、pdm 等工具都支持 TOML
分组依赖:区分运行、开发、训练依赖

在工程中,TOML 最重要的用途是 依赖管理和项目构建,而不是超参配置
'''
python 复制代码
#作为超参配置--代码读取 TOML 作为超参配置
#虽然不常用,但可以把 TOML 当作 YAML/JSON 的替代品,存超参。需要安装toml库。pip install toml

#读取toml为对象
import toml
from types import SimpleNamespace

def load_toml_config(path="config.toml"):
    data = toml.load(path)   #返回字典类型
    
    def dict_to_obj(d):
        if isinstance(d, dict):
            return SimpleNamespace(**{k: dict_to_obj(v) for k, v in d.items()})
        elif isinstance(d, list):  #可能有泪飙类型  transforms = ["crop", "flip", "normalize"]   让列表里的字典也能用 点号访问。
            return [dict_to_obj(i) for i in d]
        else:
            return d
    
    return dict_to_obj(data)

cfg = load_toml_config("config.toml")
print(cfg.data.batch_size)   # 128
print(cfg.optimizer.lr)      # 0.001



#写入toml
import toml

config = {"train": {"epochs": 100, "lr": 0.001}}
with open("config.toml", "w") as f:
    toml.dump(config, f)
相关推荐
CodeCraft Studio2 小时前
Stimulsoft报表与仪表板产品重磅发布2026.1版本:进一步强化跨平台、数据可视化、合规及 AI 辅助设计等
人工智能·信息可视化·报表开发·数据可视化·stimulsoft·仪表板·报表工具
AndrewHZ2 小时前
【图像处理基石】[特殊字符]圣诞特辑:10+经典图像处理算法,让你的图片充满节日氛围感!
图像处理·人工智能·opencv·算法·计算机视觉·stable diffusion·节日氛围感
千匠网络2 小时前
千匠大宗电商系统:赋能煤炭能源行业产业升级
大数据·人工智能·区块链·大宗电商·大宗电商系统
シ風箏2 小时前
Ascend C 异构编程环境搭建全流程指南
人工智能
Ama_tor2 小时前
Obsidian + Ollama本地AI集成|把每日日记自动归类成主题笔记
人工智能
冰西瓜6002 小时前
通俗易懂讲解马尔可夫模型
人工智能·机器学习
飞Link2 小时前
【论文笔记】A Survey on Data Synthesis and Augmentation for Large Language Models
论文阅读·人工智能·语言模型·自然语言处理
想用offer打牌2 小时前
Reasoning + Acting: ReAct范式与ReAct Agent
人工智能·后端·llm
BBB努力学习程序设计2 小时前
Python模块与包:构建可维护的代码结构
python