多模态大模型实战-MiniGPT4Qwen系列2:回到世界原点-基于lavis和registry机制搭建更加灵活的Trainer

训练模型不能没有一个灵活的Trainer,就像纪录片不能没有麦克阿瑟

说到Trainer,大多人会想到pytorch lightning和huggingface,也有相关问题去对比这二者,在使用过huggingface的Trainer后,我认为它有以下两个缺点:

  • 用多层封装换来了易用性,但如果要自定义模块(比如:想给cosine scheduler设置一个min_lr、想实现vit的学习率逐层decay)会比较麻烦
  • 参数和功能有点多了,这些功能耦合在一起,会有些混乱,对于自己做小项目或者做科研,似乎不需要这么多功能

在上个月,蹭着通义千问的热度,我写了这篇

juejin.cn/post/729832...

项目开源于:

github.com/Coobiw/Mini...

该项目主要是重构lavis之后搭建的,lavis(github.com/salesforce/... )是多模态领域很火的一个开源仓库,像BLIP2、InstructBLIP、MiniGPT4等许多多模态大模型都是基于lavis进行进一步开发的。在仔细阅读其源码后,我非常喜欢它的代码框架,所以我针对其Trainer进行重构,可以更加灵活地适配或迁移到用户的任务、模型、数据集。

这个干净、灵活又不太冗杂的Trainer开源在:

github.com/Coobiw/Mini...

欢迎大家在私信、知乎、github仓库issue中给这个项目提提建议,如果对你有帮助的话,请多多点star呀!这对我真很重要:)

实现的功能

  • Registry机制:为model、dataset、processor(预处理的transform)、lr_scheduler、task(现在进行的task,如:分类、分割、image2prompt等)构建注册表
  • 完整、灵活的配置文件:一个配置文件对应一次完整的运行(训练),有多而不冗余的参数可供设置
  • 去冗余性:
    • 对于上述的注册表中的每个组件,都提供有基类,减少代码重复
    • 去除一些重复、冗余的功能
  • 可扩展性/灵活性:自顶向下满足了
    • 任务可扩展(类似于OpenMMLab基于MMEngine和MMCV支持了那么多视觉、多模态任务):对于所有任务均可支持,本项目支持了图像分类(以猫狗分类为例)、Image2Prompt(为了适配本菜鸡第一次kaggle比赛(www.kaggle.com/competition...,最终获得银牌,虽菜但难忘就多实现了它,简易化pipeline如下图)
    • 模型可扩展
    • 数据集可扩展性(包含预处理的可扩展性)
    • scheduler的可扩展性

支持新功能的QuickTutorial

定义你的数据集

datasets目录下,继承BaseDataset类,实现你的dataset,如果需要自定义collator,请在这里完成

例如如下代码完成了一个新的分类数据集定义,其目录结构见example_data/classification

python 复制代码
from common.registry import registry
from .base_dataset import BaseDataset

from PIL import Image
from pathlib import Path
import os

import torch

@registry.register_dataset('example_cls_dataset')
class ExampleClsDataset(BaseDataset):
    def __init__(self, transform, data_root):
        super().__init__(transform=transform)
        self.data_root = Path(data_root)
        self.cls_dir = sorted(list(os.listdir(self.data_root)))

        self.data = []
        self.labels = []

        self.idx2cls,self.cls2idx = {}, {}
        for i,cls in enumerate(self.cls_dir):
            self.idx2cls[i] = cls
            self.cls2idx[cls] = i
            imgs = [str(self.data_root/cls/img) for img in os.listdir(self.data_root/cls) if self.is_img(img)]
            self.data.extend(imgs)
            self.labels.extend([i]*len(imgs))

        assert len(self.data) == len(self.labels)


    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        image_path = self.data[index]
        label = self.labels[index]

        image = Image.open(image_path).convert("RGB")

        return self.transform(image), torch.tensor(label,dtype=torch.long)

    @staticmethod
    def is_img(img_path):
        return Path(img_path).suffix.lower() in ['.jpg', '.jpeg', '.png']

    def collator(self,batch):
        images, labels = zip(*batch)
        images = torch.stack(images)
        labels = torch.stack(labels)

        return {"images": images, "labels": labels}

定义你的模型

models目录下,继承BaseModel类进行实现,可以参考本库给出的resnet_clip

注意:请为你的模型实现train_stepval_step两个方法,会在task.train_steptask.val_step时调用

定义新的task

tasks目录下,继承BaseTask类进行实现,可以参考本库给出的ClassificationTask任务

注:一般来说,只需要针对任务和任务对应的metric修改 val_step 即可

为什么需要Trainer和registry机制

要想知道为什么需要Trainer,首先我们创造一个没有Trainer的时代,只使用原生pytorch去构建一个训练流程,这时我们需要做:

  1. 定义Dataset、Dataloader
  2. 定义model
  3. 定义损失函数
  4. 定义损失函数
  5. 定义优化器
  6. 定义训练过程中的学习率变化策略(scheduler)
  7. 循环、迭代更新模型

大致pytorch代码如下:

ini 复制代码
import torch
from torch.utils.data import DataLoader

from dataset import train_data, val_data
from network import Net

# dataloader
train_loader = DataLoader(dataset=train_data, batch_size=16, shuffle=True)
val_loader = DataLoader(dataset=val_data, batch_size=16, shuffle=False)

# 定义模型
model = Net(...)

# 定义损失函数
criterion = torch.nn.CrossEntropyLoss()

# 定义优化器
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

# 定义scheduler
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,...)

# 定义训练epoch的次数
epochs = 50
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 循环迭代起来
for epoch in range(epochs):
    train_loss = 0
    train_acc = 0

    model.train()
    for i, (x, y) in enumerate(train_loader):
        x = x.to(device)
        y = y.to(device)
        model = model.to(device)
        out = model(x)
        loss = criterion(out, y)
        train_loss += loss.item()
        prediction = torch.max(out,1)[1]
        pred_correct = (prediction == y).sum()
        train_acc += pred_correct.item()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(...)
    # 验证
    model.eval()
    with torch.no_grad():
        eval_loss = 0
        eval_acc = 0
        for i, (x, y) in enmuerate(val_loader):
            out = model(x)
            loss = criterion(out, y)
            eval_loss += loss.item()
            prediction = torch.max(out, 1)[1]
            pred_correct = (prediction == y).sum()
            eval_acc += pred_correct.item()
        print(...)

这段代码看起来没什么问题,很简洁易懂,代价就是:

代码封装上:

  • 当前代码完全没有怎么进行封装,完成pipeline的每个部分都呈现在一个脚本文件中,从程序设计的角度很不优美

事实上,只需要:

php 复制代码
# 各种组件的定义
......

trainer.train()

trainer.eval()

将固定的迭代更新流程写在trainer的train()函数中,然后train里再

python 复制代码
class Trainer:
    ......
    
    def train(xxxx):
        for _ in range(epoch):
            self.train_step(xxxxx)
            
            self.val_step(xxxx)
      
    def train_step(xxx):
        for sample in dataloader:
            loss = model.train_step(sample)
            # 反向传播
            ......
    def val_step(xx):
        for sample in dataloader:
            metric = model.val_step(sample)
            # log、save模型等其他功能
            ......

灵活性上:

  • 当前任务,我要换个模型!

    • 好的,from network import Net 改成 from network import NewNet
  • 另一个任务,我要换另一个任务的模型!

    • 好的,from network import Net 改成 from network import NewTaskNet
  • 我要换个数据集!

    • 好的,from dataset import train_data 改成 from dataset import new_train_data
  • 我要换个优化器!我要换个scheduler!

    • 好的,要么也像上面一样,在别的文件实现了,然后改import

    • 要么,直接改这个脚本里的变量定义

这两个问题,就可以通过实现一个Trainer,并且加入register机制,将pipeline中各个模块的定义字符串化,可以通过一个yaml文件直接定义一次运行行为中所有组件。

可以认为,通过Register机制定义的是一个个组件,而Trainer就是一个封装所需功能,并给这些组件提供插槽!

灵活的基础组件定义:register机制------将类定义字符串化

现在的一些高影响力的开源仓库,经常离不开register机制,比如:timm、openmmlab系列仓库等,register机制的代码大致如下:

python 复制代码
class Registry:
    mapping = {
        "dataset_name_mapping": {},
        "task_name_mapping": {},
        "processor_name_mapping": {},
        "model_name_mapping": {},
        "lr_scheduler_name_mapping": {},
    }
    @classmethod
    def register_model(cls, name):
        def wrap(model_cls):
            from models import BaseModel

            assert issubclass(
                model_cls, BaseModel
            ), "All models must inherit BaseModel class"
            if name in cls.mapping["model_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["model_name_mapping"][name]
                    )
                )
            cls.mapping[  "model_name_mapping"  ][name] = model_cls
            return model_cls

        return wrap
        
    @classmethod
    def get_model_class(cls, name):
        return cls.mapping["model_name_mapping"].get(name, None)

registry = Registry()

可以看到,关键的一行代码就是:

css 复制代码
cls.mapping[  "model_name_mapping"  ][name] = model_cls

即将一个定义好的model class放到model_name_mapping这个字典(一般称为注册表)中,如果我们需要找到这个模型,只需要:

ini 复制代码
model_name = xxx # name
model = register.get_model_class(model_name)(模型的init参数们)

这样就实现了字符串与model class的映射,后续就不需要像之前说的每次import新定义的model class,直接通过修改配置文件里的model_name即可。

完成注册操作的代码如下:

python 复制代码
@registry.register_model(模型name)
class NewNet(BaseModel): # 继承预定义好的BaseModel类,减少重复代码的ctrl cv
    ......

前面在定义registry时,我们有:

makefile 复制代码
mapping = {
        "dataset_name_mapping": {},
        "task_name_mapping": {},
        "processor_name_mapping": {},
        "model_name_mapping": {},
        "lr_scheduler_name_mapping": {},
    }

这些注册表,包含:dataset、task、processor(输入的预处理,比如:读取图像、数据增强、ToTensor、归一化)、model、lr_scheduler,如果有进一步的包括优化器等你愿意去修改的组件,都可以为他构建一个注册表

这样一来,训练pipeline中各个组件的定义就完全字符串化了,这里放上一个最后的配置文件的部分截图和简单注释,让大家直观地感受字符串化后的好处!

给各个组件来一个功能齐全且带"插槽"的Trainer

当定义好各个组件后,就需要一个功能齐全,带插槽的Trainer,来让它们发挥作用了,在Trainer中,需要定义一整个训练、验证的流程,需要将输入的组件们进一步封装,发挥作用(如:将dataset变成dataloader)。

从自底向上的编程角度来看,Trainer就应该处于最上层,他需要足够大、足够global,可以适配底部组件们的变化。

首先,Trainer将组件插入进来的过程:(这里并没有把所有参数介绍全,仅介绍了最常见的组件和参数们,更加细节可以去仓库看源码)

python 复制代码
class Trainer:
    def __init__(self,config,model,datasets,task,job_id):
        self.config = config
        self.job_id = job_id
        self._model = model
        self.datasets = datasets
        self.task = task

        self._wrapped_model = None
        self._device = None
        self._optimizer = None
        self._scaler = None
        self._dataloaders = None
        self._lr_sched = None

        self.start_epoch = 0

        self.setup_output_dir()
 
    @property
    def device(self):
        if self._device is None:
            self._device = torch.device(self.config.run.device)

        return self._device

    @property
    def use_distributed(self):
        return self.config.run.distributed

    @property
    def model(self):
        """
        A property to get the DDP-wrapped model on the device.
        """
        # move model to device
        if self._model.device != self.device:
            self._model = self._model.to(self.device)

            # ddp training wrapper
            if self.use_distributed:
                if self._wrapped_model is None:
                    self._wrapped_model = DDP(
                        self._model, device_ids=[self.config.run.gpu]
                    )
            else:
                self._wrapped_model = self._model

        return self._wrapped_model

    @property
    def dataloaders(self) -> dict:
        run_cfg = self.config.run
        if self._dataloaders is None:
            self._dataloaders = get_dataloaders(
                datasets = self.datasets,
                batch_size = run_cfg.batch_size,
                batch_size_val = run_cfg.batch_size_val,
                num_worker = run_cfg.num_worker,
                ddp = run_cfg.distributed,
            )
        return self._dataloaders

    @property
    def optimizer(self):
        if self._optimizer is None:
            # 可以用这个实现逐层lr decay
            # 需要重写model的get_optimizer_params,可以参考lavis的vit
            lr_scale = self.config.run.get("lr_layer_decay", 1)
            weight_decay = self.config.run.get("weight_decay", 0.05)
            optim_params = self._model.get_optimizer_params(weight_decay,lr_scale)

            num_parameters = 0
            for p_group in optim_params:
                for p in p_group["params"]:
                    num_parameters += p.data.nelement()
            logging.info("number of trainable parameters: {}".format(num_parameters))

            beta2 = self.config.run.get("beta2", 0.999)

            self._optimizer = torch.optim.AdamW(
                optim_params,
                lr=float(self.config.run.init_lr),
                betas=(0.9, beta2),
            )
        return self._optimizer

    @property
    def scaler(self):
        amp = self.config.run.get("amp", False)

        if amp:
            if self._scaler is None:
                self._scaler = torch.cuda.amp.GradScaler()

        return self._scaler

    @property
    def lr_scheduler(self):
        """
        A property to get and create learning rate scheduler by split just in need.
        """
        if self._lr_sched is None:
            lr_sched_cls = registry.get_lr_scheduler_class(self.config.run.lr_sched)

            # max_epoch = self.config.run.max_epoch
            max_epoch = self.max_epoch
            # min_lr = self.config.run.min_lr
            min_lr = self.min_lr
            # init_lr = self.config.run.init_lr
            init_lr = self.init_lr

            # optional parameters
            decay_rate = self.config.run.get("lr_decay_rate", None)
            warmup_start_lr = self.config.run.get("warmup_lr", -1)
            warmup_steps = self.config.run.get("warmup_steps", 0)

            self._lr_sched = lr_sched_cls(
                optimizer=self.optimizer,
                max_epoch=max_epoch,
                min_lr=min_lr,
                init_lr=init_lr,
                decay_rate=decay_rate,
                warmup_start_lr=warmup_start_lr,
                warmup_steps=warmup_steps,
            )

        return self._lr_sched

    @property
    def train_loader(self):
        train_dataloader = self.dataloaders["train"]

        return train_dataloader

而其训练流程中,大致完成以下功能:

  • resume功能:保存训练中的checkpoint(包含模型、优化器状态、scheduler状态、epoch等),如果中断可以加载,继续训练
  • 保存checkpoint,也会根据eval结果保存最好的一个checkpoint
  • 日志功能
  • 训练、验证的基本流程
python 复制代码
    @main_process
    def _save_checkpoint(self, cur_epoch, is_best=False):
        """
        Save the checkpoint at the current epoch.
        """
        model_no_ddp = self.unwrap_dist_model(self.model)
        param_grad_dic = {
            k: v.requires_grad for (k, v) in model_no_ddp.named_parameters()
        }
        state_dict = model_no_ddp.state_dict()
        for k in list(state_dict.keys()):
            if k in param_grad_dic.keys() and not param_grad_dic[k]:
                # delete parameters that do not require gradient
                del state_dict[k]

        save_obj = {
            "model": state_dict,
            "optimizer": self.optimizer.state_dict(),
            "config": OmegaConf.to_container(self.config),
            "scaler": self.scaler.state_dict() if self.scaler else None,
            "epoch": cur_epoch,
        }
        save_to = os.path.join(
            self.output_dir,
            "checkpoint_{}.pth".format("best" if is_best else cur_epoch),
        )
        logging.info("Saving checkpoint at epoch {} to {}.".format(cur_epoch, save_to))
        torch.save(save_obj, save_to)

    def _reload_best_model(self, model):
        """
        Load the best checkpoint for evaluation.
        """
        checkpoint_path = os.path.join(self.output_dir, "checkpoint_best.pth")

        logging.info("Loading checkpoint from {}.".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path, map_location="cpu")
        try:
            model.load_state_dict(checkpoint["model"])
        except RuntimeError as e:
            logging.warning(
                """
                Key mismatch when loading checkpoint. This is expected if only part of the model is saved.
                Trying to load the model with strict=False.
                """
            )
            model.load_state_dict(checkpoint["model"], strict=False)
        return model

    def _load_checkpoint(self, filename):
        """
        Resume from a checkpoint.
        """
        if os.path.isfile(filename):
            checkpoint = torch.load(filename, map_location=self.device)
        else:
            raise RuntimeError("checkpoint url or path is invalid")

        state_dict = checkpoint["model"]
        self.unwrap_dist_model(self.model).load_state_dict(state_dict)

        self.optimizer.load_state_dict(checkpoint["optimizer"])
        if self.scaler and "scaler" in checkpoint:
            self.scaler.load_state_dict(checkpoint["scaler"])

        self.start_epoch = checkpoint["epoch"] + 1
        logging.info("Resume checkpoint from {}".format(filename))

    @main_process
    def log_stats(self, stats, split_name):
        if isinstance(stats, dict):
            log_stats = {**{f"{split_name}_{k}": v for k, v in stats.items()}}
            with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")
        elif isinstance(stats, list):
            pass

    @main_process
    def log_config(self):
        with open(os.path.join(self.output_dir, "log.txt"), "a") as f:
            f.write(json.dumps(OmegaConf.to_container(self.config), indent=4) + "\n")
  
    @torch.no_grad()
    def eval_epoch(self, cur_epoch, skip_reload=False):
        """
        Evaluate the model on a given split.

        Args:
            split_name (str): name of the split to evaluate on.
            cur_epoch (int): current epoch.
            skip_reload_best (bool): whether to skip reloading the best checkpoint.
                During training, we will reload the best checkpoint for validation.
                During testing, we will use provided weights and skip reloading the best checkpoint .
        """
        data_loader = self.dataloaders.get('val', None)
        assert data_loader, "data_loader for split {} is None.".format("val")

        # TODO In validation, you need to compute loss as well as metrics
        # TODO consider moving to model.before_evaluation()
        model = self.unwrap_dist_model(self.model)
        if not skip_reload and cur_epoch == "best":
            model = self._reload_best_model(model)
        model.eval()

        self.task.before_evaluation(
            model=model,
            dataset=self.datasets["val"],
        )
        results = self.task.evaluation(model, data_loader)

        if results is not None:
            return self.task.after_evaluation(
                val_result=results,
                epoch=cur_epoch,
            )

    def train(self):
        start_time = time.time()
        best_agg_metric = 0
        best_epoch = 0
        best_metrics = {}

        self.log_config()

        # resume from checkpoint if specified
        if not self.evaluate_only and self.resume_ckpt_path is not None:
            self._load_checkpoint(self.resume_ckpt_path)

        for cur_epoch in range(self.start_epoch, self.max_epoch):
            # training phase
            if not self.evaluate_only:
                logging.info("Start training")
                # See https://github.com/salesforce/LAVIS/issues/449
                # if cur_epoch == self.start_epoch:
                #     self.task.before_training(
                #         model=self.unwrap_dist_model(self.model),
                #         dataset=self.datasets["train"],
                #     )
                train_stats = self.train_epoch(cur_epoch)
                self.log_stats(split_name="train", stats=train_stats)

            # evaluation phase
            if cur_epoch % self.eval_freq == 0 or cur_epoch == self.max_epoch -1:
                logging.info("Evaluating on {}.".format("val"))

                val_log = self.eval_epoch(
                    cur_epoch=cur_epoch,
                )
                if val_log is not None:
                    if is_main_process():
                        assert (
                            "agg_metrics" in val_log
                        ), "No agg_metrics found in validation log."

                        agg_metrics = val_log["agg_metrics"]
                        if agg_metrics > best_agg_metric:
                            best_epoch, best_agg_metric = cur_epoch, agg_metrics
                            best_metrics = deepcopy(val_log)

                            self._save_checkpoint(cur_epoch, is_best=True)

                        if cur_epoch % self.save_freq == 0 or cur_epoch == self.max_epoch -1:
                            self._save_checkpoint(cur_epoch, is_best=False)
                        val_log.update({"best_epoch": best_epoch})
                        self.log_stats(val_log, "val")
                else:  # 没有定义task的evaluation
                    if cur_epoch % self.save_freq == 0 or cur_epoch == self.max_epoch -1:
                        self._save_checkpoint(cur_epoch, is_best=False)

            else:
                if not self.evaluate_only:
                    if cur_epoch % self.save_freq == 0:
                        self._save_checkpoint(cur_epoch, is_best=False)

            if self.evaluate_only:
                break

            if is_dist_avail_and_initialized():
                dist.barrier()

        return best_metrics

    def train_epoch(self, epoch):
        # train
        self.model.train()

        return self.task.train_epoch(
            epoch=epoch,
            model=self.model,
            data_loader=self.train_loader,
            optimizer=self.optimizer,
            scaler=self.scaler,
            lr_scheduler=self.lr_scheduler,
            cuda_enabled=self.cuda_enabled,
            log_freq=self.log_freq,
            accum_grad_iters=self.accum_grad_iters,
            grad_norm_clip=self.grad_norm_clip,
        )

train.py文件:定义各组件并插入进Trainer中完成训练

train.py文件的运行指令:

bash 复制代码
# 单卡
python train.py --cfg-path projects/train_classification.yaml

# 多卡
python -m torch.distributed.run --nproc_per_node=4 train.py --cfg-path projects/train_classification.yaml

# 换个任务
python train.py --cfg-path projects/train_image2prompt.yaml

其主要完成以下内容:

  • 解析配置的yaml文件
  • 根据配置的yaml文件在注册表中找到并定义指定的model、dataset、processor、task、lr_scheduler等基础组件
  • 将基础组件插入Trainer,调用trainer.train()进行训练和验证
scss 复制代码
import os
from pathlib import Path

import warnings

import argparse
from omegaconf import OmegaConf

import random
import numpy as np
import torch
import torch.distributed as dist

from common.dist_utils import (
    init_distributed_mode,
    main_process,
)

from common.registry import registry
from common.logger import setup_logger
from tasks import setup_task

from trainer import Trainer

# imports modules for registration
from common.optims import (
    LinearWarmupCosineLRScheduler,
    LinearWarmupStepLRScheduler,
    ConstantLRScheduler,
)  # 加入到注册表里,不用直接使用(由于是from的import形式,optim.py里的所有类都会加入注册表,所以实际上import一个也可以)

from processors import load_processor
from models import *
from datasets import load_dataset

warnings.filterwarnings('ignore')

def now():
    from datetime import datetime

    return datetime.now().strftime("%Y%m%d%H%M")[:-1]

def seed_everything(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.backends.cudnn.deterministic = True

def get_config(args):
    cfg_path = Path(args.cfg_path)
    assert cfg_path.suffix == '.yaml', 'config file must be .yaml file'
    config = OmegaConf.load(cfg_path)
    init_distributed_mode(config.run)
    return config

def get_transforms(config) -> dict:
    dataset_cfg = config.dataset

    transforms = {}
    transforms['train'] = load_processor(**dataset_cfg.train_cfg.transform)
    transforms['val'] = load_processor(**dataset_cfg.val_cfg.transform)

    return transforms

def get_datasets(config,transforms) -> dict:
    dataset_cfg = config.dataset

    datasets = {}
    train_cfg = dict(dataset_cfg.pop('train_cfg'))
    val_cfg = dict(dataset_cfg.pop('val_cfg'))
    train_cfg['transform'], val_cfg['transform']= transforms['train'],transforms['val']
    datasets["train"] = load_dataset(train_cfg.pop('name'),train_cfg)
    datasets['val'] = load_dataset(val_cfg.pop('name'),val_cfg)

    return datasets

def get_model(config):
    model_cfg = config.model
    model_cls = registry.get_model_class(model_cfg.arch)
    return model_cls.from_config(model_cfg)

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--cfg-path',type=str)
    parser.add_argument('--seed',type=int,default=42)
    args = parser.parse_args()

    seed_everything(args.seed)
    config = get_config(args)

    setup_logger()

    transforms = get_transforms(config)
    datasets = get_datasets(config,transforms)
    model = get_model(config)
    task = setup_task(config)
    job_id = now()

    trainer = Trainer(config,model,datasets,task,job_id)
    trainer.train()

if __name__ == "__main__":
    main()
相关推荐
hsling松子19 分钟前
使用PaddleHub智能生成,献上浓情国庆福
人工智能·算法·机器学习·语言模型·paddlepaddle
正在走向自律22 分钟前
机器学习框架
人工智能·机器学习
dengqingrui12344 分钟前
【树形DP】AT_dp_p Independent Set 题解
c++·学习·算法·深度优先·图论·dp
C++忠实粉丝1 小时前
前缀和(8)_矩阵区域和
数据结构·c++·线性代数·算法·矩阵
好吃番茄1 小时前
U mamba配置问题;‘KeyError: ‘file_ending‘
人工智能·机器学习
ZZZ_O^O1 小时前
二分查找算法——寻找旋转排序数组中的最小值&点名
数据结构·c++·学习·算法·二叉树
CV-King2 小时前
opencv实战项目(三十):使用傅里叶变换进行图像边缘检测
人工智能·opencv·算法·计算机视觉
禁默2 小时前
2024年计算机视觉与艺术研讨会(CVA 2024)
人工智能·计算机视觉
代码雕刻家2 小时前
数据结构-3.9.栈在递归中的应用
c语言·数据结构·算法
雨中rain2 小时前
算法 | 位运算(哈希思想)
算法