MMCV之Runner介绍

文章目录

前言

1、BaseRunner类

2、EpochBasedRunner

3、IterBasedRunner

总结

前言

mmcv/runner/base_runner.py文件中,定义了runner类。该类用于管理一个模型的训练和评估过程。这里放张官方示意图(runner简单来说就是实现了右边是个红色框的类):

1、BaseRunner类

该类是所有子runne的r基类,贴下最核心的代码(好多细节我给删除掉了,因为太多了):

cpp 复制代码
 class BaseRunner(metaclass=ABCMeta):
"""The base class of Runner, a training helper for PyTorch.

All subclasses should implement the following APIs:

- ``run()``
- ``train()``
- ``val()``
- ``save_checkpoint()``

"""

def __init__(self,
             model,
             batch_processor=None,
             optimizer=None,
             work_dir=None,
             logger=None,
             meta=None,
             max_iters=None,
             max_epochs=None):
self._hooks = []       # 用来存储hook列表

@abstractmethod    
def train(self):       # 定义了train方法
    pass

@abstractmethod
def val(self):
    pass

@abstractmethod
def run(self, data_loaders, workflow, **kwargs):
    pass

@abstractmethod
def save_checkpoint(self,
                    out_dir,
                    filename_tmpl,
                    save_optimizer=True,
                    meta=None,
                    create_symlink=True):
    pass
# 注册hook的三个函数
def register_hook(self, hook, priority='NORMAL'):
     self._hooks.insert(0, hook)
     
def register_lr_hook(self, lr_config):
   self.register_hook(hook, priority='VERY_HIGH')
   
def register_training_hooks(self,
                           lr_config,
                           optimizer_config=None,
                           checkpoint_config=None,
                           log_config=None,
                           momentum_config=None,
                           timer_config=dict(type='IterTimerHook'),
                           custom_hooks_config=None):
   self.register_lr_hook(lr_config)
   self.register_momentum_hook(momentum_config)
   self.register_optimizer_hook(optimizer_config)
   self.register_checkpoint_hook(checkpoint_config)
   self.register_timer_hook(timer_config)
   self.register_logger_hooks(log_config)
   self.register_custom_hooks(custom_hooks_config)

1)初始化部分:包括(模型、批次数据、优化器、工作目录、meta(seed)和epoch数和iter数)。另外,值得注意的是,初始化了一个self.hooks列表,里面存储元素为Hook类实例出来的对象。

2)@abstractmethod:装饰器修饰了四个抽象方法:train、val、run和save_checkpoint。只要继承该类的子类必须实现这四个方法。

3)注册hook函数:关于hook我会单独出一篇博文,只需知道大概路程即可。以lr_hook为例:首先输入参数lr_config传给register_training_hooks,之后函数内部调用register_lr_hook函数,将lr_config实例成对应的hook对象,最终调用register_hook函数:将lr_hook添加到self.hooks列表中。

接下来介绍两个子类的Runner。

2、EpochBasedRunner

废话不多说,贴核心代码:

cpp 复制代码
 @RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner train models epoch by epoch.
    """
	# 执行一次iter的训练或推理。
    def run_iter(self, data_batch, train_mode, **kwargs):
    	# 调用model的train_step或者val_step方法。
        if self.mode == 'train':
        	outputs = self.model.train_step(data_batch, self.optimizer,**kwargs)                        
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        self.outputs = outputs

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs) # 调用run_iter方法
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)        # 调用run_iter方法
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        self.call_hook('before_run')
        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                epoch_runner = getattr(self, mode)       # 根据mode字段决定调用train方法还是val方法
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)   # 去执行train OR val
                    
        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
		pass

EpochBasedRunner继承了BaseRunner类,故实现了四种方法。save_checkpoing不多说了,核心是run方法,内部思路就是开头的那张图,借助mode字段是'train'或者'val'去调用不同train方法或者val方法。而train和val内部调用run_iter方法执行一次迭代的前向传播计算。 该runner借助epoch来训练模型,是mmdet中最常用的runner。

3、IterBasedRunner

cpp 复制代码
class IterLoader:

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 0

    @property
    def epoch(self):
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader) # 若没到Max_iter,则就正常取一个batch数据
        except StopIteration:
            self._epoch += 1              # 若迭代完当前epoch,则重新转成迭代器,重新取一个batch数据
            if hasattr(self._dataloader.sampler, 'set_epoch'):
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self.iter_loader = iter(self._dataloader) # 重新转成迭代器
            data = next(self.iter_loader)             # 调用下一批次数据

        return data

    def __len__(self):
        return len(self._dataloader)


@RUNNERS.register_module()
class IterBasedRunner(BaseRunner):
    """Iteration-based Runner.

    This runner train models iteration by iteration.
    """

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        data_batch = next(data_loader)              # 将data_loader变成迭代器
        self.call_hook('before_train_iter')
        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs) # 执行前向传播
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._inner_iter += 1
        self._iter += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        data_batch = next(data_loader)            # 将data_loader变成迭代器
        self.call_hook('before_val_iter')
        outputs = self.model.val_step(data_batch, **kwargs)     # 执行前向传播
        self.outputs = outputs
        self.call_hook('after_val_iter')
        self._inner_iter += 1

    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
        self.call_hook('before_run')
        iter_loaders = [IterLoader(x) for x in data_loaders]  # 转成迭代器
        self.call_hook('before_epoch')
        while self.iter < self._max_iters:                    # 不到最大迭代轮数就不停止
            for i, flow in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = flow
                iter_runner = getattr(self, mode)             # 调用train OR val
                for _ in range(iters):
                    if mode == 'train' and self.iter >= self._max_iters: # 若超出max_iter则break
                        break
                    iter_runner(iter_loaders[i], **kwargs)    # 执行前向传播计算
        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_epoch')
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='iter_{}.pth',
                        meta=None,
                        save_optimizer=True,
                        create_symlink=True):
    	pass

IterBaseRunner大同小异,同样实现了四个方法。唯一和EpochBaseRunner区别是没有实现run_iter方法。由于该runner以最大迭代轮数训练,故分别在train和val方法中实现了run_iter的计算。另外,多了一个IterLoader类,作用是当迭代完一个epoch后,重新遍历数据,此时用该类就可以用try-except实现重新迭代,可以看我的注释。

总结

本文介绍了mmcv中runner介绍,基本所有mmdet模型都用到上述两个runner。

相关推荐
数据小爬虫@2 小时前
深入解析:使用 Python 爬虫获取苏宁商品详情
开发语言·爬虫·python
健胃消食片片片片2 小时前
Python爬虫技术:高效数据收集与深度挖掘
开发语言·爬虫·python
ℳ₯㎕ddzོꦿ࿐5 小时前
解决Python 在 Flask 开发模式下定时任务启动两次的问题
开发语言·python·flask
CodeClimb5 小时前
【华为OD-E卷 - 第k个排列 100分(python、java、c++、js、c)】
java·javascript·c++·python·华为od
一水鉴天5 小时前
为AI聊天工具添加一个知识系统 之63 详细设计 之4:AI操作系统 之2 智能合约
开发语言·人工智能·python
Channing Lewis5 小时前
什么是 Flask 的蓝图(Blueprint)
后端·python·flask
B站计算机毕业设计超人5 小时前
计算机毕业设计hadoop+spark股票基金推荐系统 股票基金预测系统 股票基金可视化系统 股票基金数据分析 股票基金大数据 股票基金爬虫
大数据·hadoop·python·spark·课程设计·数据可视化·推荐算法
觅远6 小时前
python+playwright自动化测试(四):元素操作(键盘鼠标事件)、文件上传
python·自动化
ghostwritten7 小时前
Python FastAPI 实战应用指南
开发语言·python·fastapi
CM莫问7 小时前
python实战(十五)——中文手写体数字图像CNN分类
人工智能·python·深度学习·算法·cnn·图像分类·手写体识别