Training a model without a flexible Trainer is like a documentary without MacArthur.
When people talk about Trainers, most think of PyTorch Lightning and Hugging Face, and there are plenty of discussions comparing the two. After using the Hugging Face Trainer, I think it has two drawbacks:
- It trades layers of encapsulation for ease of use, but customizing a module (e.g., giving the cosine scheduler a min_lr, or implementing layer-wise learning-rate decay for a ViT) becomes fairly painful
- It has rather a lot of parameters and features, and they are coupled together, which gets messy; for small personal projects or research, you rarely need all of them
Last month, riding the wave of Tongyi Qianwen's popularity, I wrote this post:
The project is open-sourced at:
The project was built mainly by refactoring LAVIS. LAVIS (github.com/salesforce/... ) is a very popular open-source repository in the multimodal field; many multimodal large models such as BLIP-2, InstructBLIP, and MiniGPT-4 were developed on top of it. After reading its source code carefully, I really liked its code framework, so I refactored its Trainer so that it can be adapted or migrated more flexibly to your own tasks, models, and datasets.
This clean, flexible, and not overly bloated Trainer is open-sourced at:
Suggestions are welcome via private message, Zhihu comments, or GitHub issues. If the project helps you, please give it a star! It really means a lot to me :)
Implemented features
- Registry mechanism: registries for model, dataset, processor (the pre-processing transforms), lr_scheduler, and task (the task being run, e.g., classification, segmentation, image2prompt, ...)
- Complete, flexible config files: one config file corresponds to one complete run (training), with plenty of parameters to set but nothing redundant
- De-duplication:
  - Every component type in the registries above comes with a base class, reducing repeated code
  - Some duplicated, redundant features are removed
- Extensibility/flexibility, satisfied from the top down:
  - Extensible tasks (similar to how OpenMMLab supports so many vision and multimodal tasks on top of MMEngine and MMCV): any task can be supported; this project ships image classification (with cat-vs-dog classification as the example) and Image2Prompt (added for my first Kaggle competition (www.kaggle.com/competition..., which ended with a silver medal; I was a rookie but it was memorable, so I implemented it as well; a simplified pipeline is shown in the figure below)
  - Extensible models
  - Extensible datasets (including extensible pre-processing)
  - Extensible schedulers

QuickTutorial for supporting new components
Define your dataset
In the datasets directory, implement your dataset by subclassing BaseDataset; if you need a custom collator, implement it here as well.
For example, the following code defines a new classification dataset; its directory layout can be found in example_data/classification:
python
from common.registry import registry
from .base_dataset import BaseDataset
from PIL import Image
from pathlib import Path
import os
import torch


@registry.register_dataset('example_cls_dataset')
class ExampleClsDataset(BaseDataset):
    def __init__(self, transform, data_root):
        super().__init__(transform=transform)
        self.data_root = Path(data_root)
        self.cls_dir = sorted(list(os.listdir(self.data_root)))
        self.data = []
        self.labels = []
        self.idx2cls, self.cls2idx = {}, {}
        for i, cls in enumerate(self.cls_dir):
            self.idx2cls[i] = cls
            self.cls2idx[cls] = i
            imgs = [str(self.data_root/cls/img) for img in os.listdir(self.data_root/cls) if self.is_img(img)]
            self.data.extend(imgs)
            self.labels.extend([i] * len(imgs))
        assert len(self.data) == len(self.labels)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image_path = self.data[index]
        label = self.labels[index]
        image = Image.open(image_path).convert("RGB")
        return self.transform(image), torch.tensor(label, dtype=torch.long)

    @staticmethod
    def is_img(img_path):
        return Path(img_path).suffix.lower() in ['.jpg', '.jpeg', '.png']

    def collator(self, batch):
        images, labels = zip(*batch)
        images = torch.stack(images)
        labels = torch.stack(labels)
        return {"images": images, "labels": labels}
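As a quick sanity check outside the framework (a minimal sketch assuming the class above, a hypothetical local data directory, and a simple torchvision transform), the custom collator plugs straight into a plain PyTorch DataLoader via collate_fn and yields dict-style batches:
python
from torch.utils.data import DataLoader
from torchvision import transforms

# hypothetical transform and data path, just to exercise the collator
dataset = ExampleClsDataset(
    transform=transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]),
    data_root="example_data/classification/train",   # assumed local path
)
loader = DataLoader(dataset, batch_size=16, shuffle=True, collate_fn=dataset.collator)

batch = next(iter(loader))
# e.g. torch.Size([16, 3, 224, 224]) and torch.Size([16])
print(batch["images"].shape, batch["labels"].shape)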
Define your model
In the models directory, implement your model by subclassing BaseModel; you can refer to the resnet_clip model provided in this repo.
Note: please implement the train_step and val_step methods for your model; they are called from task.train_step and task.val_step. A minimal sketch is given below.
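The post itself does not include a snippet for this step, so here is a hedged sketch. The class, its registered name, and the config field are made up; subclassing BaseModel, the train_step/val_step contract, and the from_config classmethod (used by get_model in train.py below) follow the repo's conventions, while the exact return types of train_step/val_step are assumptions:
python
import torch
import torch.nn as nn
import torch.nn.functional as F
from common.registry import registry
from .base_model import BaseModel  # module path assumed, mirroring the dataset example


@registry.register_model('toy_classifier')  # hypothetical registered name
class ToyClassifier(BaseModel):
    def __init__(self, num_classes):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1)
        )
        self.head = nn.Linear(16, num_classes)

    def forward(self, images):
        feats = self.backbone(images).flatten(1)
        return self.head(feats)

    def train_step(self, samples):
        # samples is the dict produced by the dataset's collator
        logits = self(samples["images"])
        return F.cross_entropy(logits, samples["labels"])

    @torch.no_grad()
    def val_step(self, samples):
        logits = self(samples["images"])
        return {"preds": logits.argmax(dim=-1), "labels": samples["labels"]}

    @classmethod
    def from_config(cls, cfg):
        # from_config is how get_model() builds the model from the yaml config
        return cls(num_classes=cfg.get("num_classes", 2))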
Define a new task
In the tasks directory, implement your task by subclassing BaseTask; you can refer to the ClassificationTask provided in this repo.
Note: in general, you only need to override val_step for your task and its corresponding metric, as in the sketch below.
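Again a hedged sketch rather than the repo's actual ClassificationTask: it assumes a register_task decorator analogous to register_model, and that BaseTask.evaluation collects the per-batch dicts returned by val_step and hands them to after_evaluation, whose agg_metrics key is what the Trainer's train() loop compares:
python
from common.registry import registry
from .base_task import BaseTask  # module path assumed


@registry.register_task('example_classification')  # hypothetical registered name
class MyClassificationTask(BaseTask):
    def val_step(self, model, samples):
        # assumption: called once per validation batch by BaseTask.evaluation()
        out = model.val_step(samples)
        correct = (out["preds"] == out["labels"]).sum().item()
        return {"correct": correct, "total": out["labels"].numel()}

    def after_evaluation(self, val_result, epoch, **kwargs):
        # aggregate per-batch results into the 'agg_metrics' the Trainer compares
        correct = sum(r["correct"] for r in val_result)
        total = sum(r["total"] for r in val_result)
        acc = correct / max(total, 1)
        return {"agg_metrics": acc, "acc": acc, "epoch": epoch}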
Why do we need a Trainer and a registry mechanism
To understand why we need a Trainer, let's first imagine an era without one and build a training pipeline with plain PyTorch. We would need to:
- Define the Dataset and DataLoader
- Define the model
- Define the loss function
- Define the optimizer
- Define how the learning rate changes during training (the scheduler)
- Loop and iteratively update the model
The PyTorch code looks roughly like this:
python
import torch
from torch.utils.data import DataLoader
from dataset import train_data, val_data
from network import Net

# dataloaders
train_loader = DataLoader(dataset=train_data, batch_size=16, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=16, shuffle=False)
# define the model
model = Net(...)
# define the loss function
criterion = torch.nn.CrossEntropyLoss()
# define the optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
# define the scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, ...)
# number of training epochs
epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# the training loop
for epoch in range(epochs):
    train_loss = 0
    train_acc = 0
    model.train()
    for i, (x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)
        out = model(x)
        loss = criterion(out, y)
        train_loss += loss.item()
        prediction = torch.max(out, 1)[1]
        pred_correct = (prediction == y).sum()
        train_acc += pred_correct.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()  # step the scheduler once per epoch
    print(...)

    # validation
    model.eval()
    with torch.no_grad():
        eval_loss = 0
        eval_acc = 0
        for i, (x, y) in enumerate(val_loader):
            x = x.to(device)
            y = y.to(device)
            out = model(x)
            loss = criterion(out, y)
            eval_loss += loss.item()
            prediction = torch.max(out, 1)[1]
            pred_correct = (prediction == y).sum()
            eval_acc += pred_correct.item()
    print(...)
This code looks fine: it is concise and easy to follow. The price you pay shows up in two places.
In terms of code encapsulation:
- The code is hardly encapsulated at all; every part of the pipeline lives in a single script, which is inelegant from a program-design point of view
In fact, all we should need to write is:
python
# definitions of the various components
......
trainer.train()
trainer.eval()
That is, the fixed iterate-and-update procedure goes into the trainer's train() function, and inside train we then have roughly:
python
class Trainer:
    ......
    def train(xxxx):
        for _ in range(epoch):
            self.train_step(xxxxx)
            self.val_step(xxxx)

    def train_step(xxx):
        for sample in dataloader:
            loss = model.train_step(sample)
            # backward pass
            ......

    def val_step(xx):
        for sample in dataloader:
            metric = model.val_step(sample)
        # logging, model saving, and other utilities
        ......
In terms of flexibility:
- "Same task, but I want to swap the model!"
  - Sure: change from network import Net to from network import NewNet
- "Another task, I need that task's model!"
  - Sure: change from network import Net to from network import NewTaskNet
- "I want a different dataset!"
  - Sure: change from dataset import train_data to from dataset import new_train_data
- "I want a different optimizer! And a different scheduler!"
  - Sure: either implement it in another file and change the import, as above,
  - or edit the variable definitions directly in this script
Both of these problems can be solved by implementing a Trainer and adding a register mechanism that stringifies the definition of every module in the pipeline, so that all components of a single run can be defined directly in one YAML file.
You can think of it this way: the Register mechanism defines the individual components, while the Trainer wraps up the required functionality and provides the sockets these components plug into!
Flexible definition of the basic components: the register mechanism, or stringifying class definitions
Many of today's high-impact open-source repositories rely heavily on a register mechanism, e.g., timm and the OpenMMLab family of repos. The code of a register mechanism looks roughly like this:
python
class Registry:
    mapping = {
        "dataset_name_mapping": {},
        "task_name_mapping": {},
        "processor_name_mapping": {},
        "model_name_mapping": {},
        "lr_scheduler_name_mapping": {},
    }

    @classmethod
    def register_model(cls, name):
        def wrap(model_cls):
            from models import BaseModel

            assert issubclass(
                model_cls, BaseModel
            ), "All models must inherit BaseModel class"
            if name in cls.mapping["model_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["model_name_mapping"][name]
                    )
                )
            cls.mapping["model_name_mapping"][name] = model_cls
            return model_cls

        return wrap

    @classmethod
    def get_model_class(cls, name):
        return cls.mapping["model_name_mapping"].get(name, None)


registry = Registry()
As you can see, the key line of code is:
python
cls.mapping["model_name_mapping"][name] = model_cls
This puts a defined model class into the model_name_mapping dict (commonly called the registry). When we later need this model, all we have to do is:
python
model_name = xxx  # the registered name string
model = registry.get_model_class(model_name)(...)  # pass the model's __init__ arguments
This establishes a mapping between strings and model classes, so we no longer have to import each newly defined model class as described earlier; changing model_name in the config file is enough.
The registration itself is done with the following code:
python
@registry.register_model('your_model_name')
class NewNet(BaseModel):  # subclass the predefined BaseModel to cut down on copy-pasted boilerplate
    ......
Earlier, when defining the registry, we had:
python
mapping = {
    "dataset_name_mapping": {},
    "task_name_mapping": {},
    "processor_name_mapping": {},
    "model_name_mapping": {},
    "lr_scheduler_name_mapping": {},
}
These registries cover: dataset, task, processor (input pre-processing, e.g., reading images, data augmentation, ToTensor, normalization), model, and lr_scheduler. If there are further components you would like to make swappable, such as the optimizer, you can build a registry for them in exactly the same way; a sketch follows below.
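For instance, if you wanted the optimizer to be selectable by name as well, you could add one more table and one more decorator following exactly the same pattern. This is not part of the repo, purely an illustration (shown here as a self-contained mini-Registry):
python
import torch


class Registry:
    mapping = {
        # ...the existing tables from above...
        "optimizer_name_mapping": {},
    }

    @classmethod
    def register_optimizer(cls, name):
        def wrap(optim_cls):
            if name in cls.mapping["optimizer_name_mapping"]:
                raise KeyError(f"Name '{name}' already registered.")
            cls.mapping["optimizer_name_mapping"][name] = optim_cls
            return optim_cls
        return wrap

    @classmethod
    def get_optimizer_class(cls, name):
        return cls.mapping["optimizer_name_mapping"].get(name, None)


registry = Registry()
# register an existing optimizer class, then look it up by the string used in the config
registry.register_optimizer("adamw")(torch.optim.AdamW)
optimizer_cls = registry.get_optimizer_class("adamw")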
With this in place, the definition of every component in the training pipeline is fully stringified. Below is a partial screenshot of a final config file with brief annotations, so you can get a direct feel for the benefits of stringification!
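For a concrete feel, here is also a rough, hand-written YAML sketch of what such a config could look like. The exact keys are illustrative and may differ from the repo; the model / dataset / run structure and fields such as arch, name, lr_sched, init_lr, min_lr, batch_size, and device mirror what the Trainer and train.py below read from the config:
yaml
model:
  arch: resnet_clip                 # registered model name
  num_classes: 2                    # model-specific fields (illustrative)

dataset:
  train_cfg:
    name: example_cls_dataset       # registered dataset name
    data_root: example_data/classification/train
    transform:
      name: cls_train_processor     # registered processor name (illustrative)
  val_cfg:
    name: example_cls_dataset
    data_root: example_data/classification/val
    transform:
      name: cls_val_processor

run:
  task: classification              # registered task name (illustrative)
  lr_sched: linear_warmup_cosine_lr # registered scheduler name (illustrative)
  init_lr: 1e-4
  min_lr: 1e-6
  warmup_steps: 500
  max_epoch: 20
  batch_size: 32
  batch_size_val: 32
  num_worker: 4
  device: cuda
  distributed: False
  amp: True
  output_dir: output/classification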

A feature-complete Trainer with "sockets" for every component
Once the components are defined, we need a feature-complete Trainer with sockets to make them work together. The Trainer defines the entire training and validation procedure and wraps the incoming components further so they can play their roles (e.g., turning a dataset into a dataloader).
Viewed from a bottom-up programming perspective, the Trainer should sit at the very top: it has to be large and global enough to accommodate changes in the components beneath it.
First, the process by which the Trainer plugs the components in (not every parameter is covered here, only the most common components and parameters; see the repo source for more detail):
python
class Trainer:
def __init__(self,config,model,datasets,task,job_id):
self.config = config
self.job_id = job_id
self._model = model
self.datasets = datasets
self.task = task
self._wrapped_model = None
self._device = None
self._optimizer = None
self._scaler = None
self._dataloaders = None
self._lr_sched = None
self.start_epoch = 0
self.setup_output_dir()
@property
def device(self):
if self._device is None:
self._device = torch.device(self.config.run.device)
return self._device
@property
def use_distributed(self):
return self.config.run.distributed
@property
def model(self):
"""
A property to get the DDP-wrapped model on the device.
"""
# move model to device
if self._model.device != self.device:
self._model = self._model.to(self.device)
# ddp training wrapper
if self.use_distributed:
if self._wrapped_model is None:
self._wrapped_model = DDP(
self._model, device_ids=[self.config.run.gpu]
)
else:
self._wrapped_model = self._model
return self._wrapped_model
@property
def dataloaders(self) -> dict:
run_cfg = self.config.run
if self._dataloaders is None:
self._dataloaders = get_dataloaders(
datasets = self.datasets,
batch_size = run_cfg.batch_size,
batch_size_val = run_cfg.batch_size_val,
num_worker = run_cfg.num_worker,
ddp = run_cfg.distributed,
)
return self._dataloaders
@property
def optimizer(self):
if self._optimizer is None:
# this can be used to implement layer-wise lr decay
# it requires overriding the model's get_optimizer_params; see lavis's vit for an example
lr_scale = self.config.run.get("lr_layer_decay", 1)
weight_decay = self.config.run.get("weight_decay", 0.05)
optim_params = self._model.get_optimizer_params(weight_decay,lr_scale)
num_parameters = 0
for p_group in optim_params:
for p in p_group["params"]:
num_parameters += p.data.nelement()
logging.info("number of trainable parameters: {}".format(num_parameters))
beta2 = self.config.run.get("beta2", 0.999)
self._optimizer = torch.optim.AdamW(
optim_params,
lr=float(self.config.run.init_lr),
betas=(0.9, beta2),
)
return self._optimizer
@property
def scaler(self):
amp = self.config.run.get("amp", False)
if amp:
if self._scaler is None:
self._scaler = torch.cuda.amp.GradScaler()
return self._scaler
@property
def lr_scheduler(self):
"""
A property to get and create learning rate scheduler by split just in need.
"""
if self._lr_sched is None:
lr_sched_cls = registry.get_lr_scheduler_class(self.config.run.lr_sched)
# max_epoch = self.config.run.max_epoch
max_epoch = self.max_epoch
# min_lr = self.config.run.min_lr
min_lr = self.min_lr
# init_lr = self.config.run.init_lr
init_lr = self.init_lr
# optional parameters
decay_rate = self.config.run.get("lr_decay_rate", None)
warmup_start_lr = self.config.run.get("warmup_lr", -1)
warmup_steps = self.config.run.get("warmup_steps", 0)
self._lr_sched = lr_sched_cls(
optimizer=self.optimizer,
max_epoch=max_epoch,
min_lr=min_lr,
init_lr=init_lr,
decay_rate=decay_rate,
warmup_start_lr=warmup_start_lr,
warmup_steps=warmup_steps,
)
return self._lr_sched
@property
def train_loader(self):
train_dataloader = self.dataloaders["train"]
return train_dataloader
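The dataloaders property above delegates to a get_dataloaders helper whose implementation is not shown in this post. The following is a minimal sketch of what it might do; the DistributedSampler branch for DDP and the collate_fn hook are assumptions consistent with the dataset example earlier, not the repo's exact code:
python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def get_dataloaders(datasets, batch_size, batch_size_val, num_worker, ddp=False):
    """Wrap the 'train'/'val' datasets into dataloaders (illustrative sketch)."""
    loaders = {}
    for split, dataset in datasets.items():
        is_train = split == "train"
        sampler = DistributedSampler(dataset, shuffle=is_train) if ddp else None
        loaders[split] = DataLoader(
            dataset,
            batch_size=batch_size if is_train else batch_size_val,
            shuffle=(sampler is None and is_train),
            sampler=sampler,
            num_workers=num_worker,
            pin_memory=True,
            drop_last=is_train,
            collate_fn=getattr(dataset, "collator", None),
        )
    return loaders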
Its training flow, in turn, roughly covers the following:
- Resume: checkpoints saved during training (model, optimizer state, AMP scaler state, epoch, etc.) can be reloaded after an interruption to continue training
- Checkpoint saving, including keeping the best checkpoint according to the eval results
- Logging
- The basic training and validation flow
python
@main_process
def _save_checkpoint(self, cur_epoch, is_best=False):
"""
Save the checkpoint at the current epoch.
"""
model_no_ddp = self.unwrap_dist_model(self.model)
param_grad_dic = {
k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
}
state_dict = model_no_ddp.state_dict()
for k in list(state_dict.keys()):
if k in param_grad_dic.keys() and not param_grad_dic[k]:
# delete parameters that do not require gradient
del state_dict[k]
save_obj = {
"model": state_dict,
"optimizer": self.optimizer.state_dict(),
"config": OmegaConf.to_container(self.config),
"scaler": self.scaler.state_dict() if self.scaler else None,
"epoch": cur_epoch,
}
save_to = os.path.join(
self.output_dir,
"checkpoint_{}.pth".format("best" if is_best else cur_epoch),
)
logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
torch.save(save_obj, save_to)
def _reload_best_model(self, model):
"""
Load the best checkpoint for evaluation.
"""
checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")
logging.info("Loading checkpoint from {}.".format(checkpoint_path))
checkpoint = torch.load(checkpoint_path, map_location="cpu")
try:
model.load_state_dict(checkpoint["model"])
except RuntimeError as e:
logging.warning(
"""
Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
Trying to load the model with strict=False.
"""
)
model.load_state_dict(checkpoint["model"], strict=False)
return model
def _load_checkpoint(self, filename):
"""
Resume from a checkpoint.
"""
if os.path.isfile(filename):
checkpoint = torch.load(filename, map_location=self.device)
else:
raise RuntimeError("checkpoint url or path is invalid")
state_dict = checkpoint["model"]
self.unwrap_dist_model(self.model).load_state_dict(state_dict)
self.optimizer.load_state_dict(checkpoint["optimizer"])
if self.scaler and "scaler" in checkpoint:
self.scaler.load_state_dict(checkpoint["scaler"])
self.start_epoch = checkpoint["epoch"] + 1
logging.info("Resume checkpoint from {}".format(filename))
@main_process
def log_stats(self, stats, split_name):
if isinstance(stats, dict):
log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
f.write(json.dumps(log_stats) + "\n")
elif isinstance(stats, list):
pass
@main_process
def log_config(self):
with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
f.write(json.dumps(OmegaConf.to_container(self.config), indent=4) + "\n")
@torch.no_grad()
def eval_epoch(self, cur_epoch, skip_reload=False):
"""
Evaluate the model on a given split.
Args:
split_name (str): name of the split to evaluate on.
cur_epoch (int): current epoch.
skip_reload_best (bool): whether to skip reloading the best checkpoint.
During training, we will reload the best checkpoint for validation.
During testing, we will use provided weights and skip reloading the best checkpoint .
"""
data_loader = self.dataloaders.get('val', None)
assert data_loader, "data_loader for split {} is None.".format("val")
# TODO In validation, you need to compute loss as well as metrics
# TODO consider moving to model.before_evaluation()
model = self.unwrap_dist_model(self.model)
if not skip_reload and cur_epoch == "best":
model = self._reload_best_model(model)
model.eval()
self.task.before_evaluation(
model=model,
dataset=self.datasets["val"],
)
results = self.task.evaluation(model, data_loader)
if results is not None:
return self.task.after_evaluation(
val_result=results,
epoch=cur_epoch,
)
def train(self):
start_time = time.time()
best_agg_metric = 0
best_epoch = 0
best_metrics = {}
self.log_config()
# resume from checkpoint if specified
if not self.evaluate_only and self.resume_ckpt_path is not None:
self._load_checkpoint(self.resume_ckpt_path)
for cur_epoch in range(self.start_epoch, self.max_epoch):
# training phase
if not self.evaluate_only:
logging.info("Start training")
# See https://github.com/salesforce/LAVIS/issues/449
# if cur_epoch == self.start_epoch:
# self.task.before_training(
# model=self.unwrap_dist_model(self.model),
# dataset=self.datasets["train"],
# )
train_stats = self.train_epoch(cur_epoch)
self.log_stats(split_name="train", stats=train_stats)
# evaluation phase
if cur_epoch % self.eval_freq == 0 or cur_epoch == self.max_epoch -1:
logging.info("Evaluating on {}.".format("val"))
val_log = self.eval_epoch(
cur_epoch=cur_epoch,
)
if val_log is not None:
if is_main_process():
assert (
"agg_metrics" in val_log
), "No agg_metrics found in validation log."
agg_metrics = val_log["agg_metrics"]
if agg_metrics > best_agg_metric:
best_epoch, best_agg_metric = cur_epoch, agg_metrics
best_metrics = deepcopy(val_log)
self._save_checkpoint(cur_epoch, is_best=True)
if cur_epoch % self.save_freq == 0 or cur_epoch == self.max_epoch -1:
self._save_checkpoint(cur_epoch, is_best=False)
val_log.update({"best_epoch": best_epoch})
self.log_stats(val_log, "val")
else: # the task does not define an evaluation
if cur_epoch % self.save_freq == 0 or cur_epoch == self.max_epoch -1:
self._save_checkpoint(cur_epoch, is_best=False)
else:
if not self.evaluate_only:
if cur_epoch % self.save_freq == 0:
self._save_checkpoint(cur_epoch, is_best=False)
if self.evaluate_only:
break
if is_dist_avail_and_initialized():
dist.barrier()
return best_metrics
def train_epoch(self, epoch):
# train
self.model.train()
return self.task.train_epoch(
epoch=epoch,
model=self.model,
data_loader=self.train_loader,
optimizer=self.optimizer,
scaler=self.scaler,
lr_scheduler=self.lr_scheduler,
cuda_enabled=self.cuda_enabled,
log_freq=self.log_freq,
accum_grad_iters=self.accum_grad_iters,
grad_norm_clip=self.grad_norm_clip,
)
The train.py file: define the components and plug them into the Trainer to run training
Commands for running train.py:
bash
# single GPU
python train.py --cfg-path projects/train_classification.yaml
# multiple GPUs
python -m torch.distributed.run --nproc_per_node=4 train.py --cfg-path projects/train_classification.yaml
# switch to another task
python train.py --cfg-path projects/train_image2prompt.yaml
It mainly does the following:
- Parse the YAML config file
- Based on the config, look up in the registries and instantiate the specified model, dataset, processor, task, lr_scheduler, and other basic components
- Plug the basic components into the Trainer and call trainer.train() to run training and validation
python
import os
from pathlib import Path
import warnings
import argparse
from omegaconf import OmegaConf
import random
import numpy as np
import torch
import torch.distributed as dist
from common.dist_utils import (
init_distributed_mode,
main_process,
)
from common.registry import registry
from common.logger import setup_logger
from tasks import setup_task
from trainer import Trainer
# imports modules for registration
from common.optims import (
LinearWarmupCosineLRScheduler,
LinearWarmupStepLRScheduler,
ConstantLRScheduler,
) # imported so they get added to the registry, not to be used directly (since this is a from-style import, every class in common/optims.py is registered when the module loads, so importing just one of them would actually be enough)
from processors import load_processor
from models import *
from datasets import load_dataset
warnings.filterwarnings('ignore')
def now():
from datetime import datetime
return datetime.now().strftime("%Y%m%d%H%M")[:-1]
def seed_everything(seed):
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic = True
def get_config(args):
cfg_path = Path(args.cfg_path)
assert cfg_path.suffix == '.yaml', 'config file must be .yaml file'
config = OmegaConf.load(cfg_path)
init_distributed_mode(config.run)
return config
def get_transforms(config) -> dict:
dataset_cfg = config.dataset
transforms = {}
transforms['train'] = load_processor(**dataset_cfg.train_cfg.transform)
transforms['val'] = load_processor(**dataset_cfg.val_cfg.transform)
return transforms
def get_datasets(config,transforms) -> dict:
dataset_cfg = config.dataset
datasets = {}
train_cfg = dict(dataset_cfg.pop('train_cfg'))
val_cfg = dict(dataset_cfg.pop('val_cfg'))
train_cfg['transform'], val_cfg['transform']= transforms['train'],transforms['val']
datasets["train"] = load_dataset(train_cfg.pop('name'),train_cfg)
datasets['val'] = load_dataset(val_cfg.pop('name'),val_cfg)
return datasets
def get_model(config):
model_cfg = config.model
model_cls = registry.get_model_class(model_cfg.arch)
return model_cls.from_config(model_cfg)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--cfg-path',type=str)
parser.add_argument('--seed',type=int,default=42)
args = parser.parse_args()
seed_everything(args.seed)
config = get_config(args)
setup_logger()
transforms = get_transforms(config)
datasets = get_datasets(config,transforms)
model = get_model(config)
task = setup_task(config)
job_id = now()
trainer = Trainer(config,model,datasets,task,job_id)
trainer.train()
if __name__ == "__main__":
main()