Table of Contents
Preface
1. The BaseRunner Class
2. EpochBasedRunner
3. IterBasedRunner
Summary
Preface
The file mmcv/runner/base_runner.py defines the Runner class, which manages a model's training and evaluation process. Here is the official diagram (in short, a runner is a class that implements the part inside the red box on the right):
1. The BaseRunner Class
This class is the base class of all runners. Here is its core code (I stripped out many details, since there are a lot of them):
python
from abc import ABCMeta, abstractmethod

import mmcv
from mmcv.runner import HOOKS  # registry of hook classes


class BaseRunner(metaclass=ABCMeta):
    """The base class of Runner, a training helper for PyTorch.
    All subclasses should implement the following APIs:
    - ``run()``
    - ``train()``
    - ``val()``
    - ``save_checkpoint()``
    """

    def __init__(self,
                 model,
                 batch_processor=None,
                 optimizer=None,
                 work_dir=None,
                 logger=None,
                 meta=None,
                 max_iters=None,
                 max_epochs=None):
        self._hooks = []  # list that stores the registered hooks

    @abstractmethod
    def train(self):  # the train method, to be defined by subclasses
        pass

    @abstractmethod
    def val(self):
        pass

    @abstractmethod
    def run(self, data_loaders, workflow, **kwargs):
        pass

    @abstractmethod
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl,
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        pass

    # the three levels of hook-registration functions
    def register_hook(self, hook, priority='NORMAL'):
        # the real implementation inserts the hook into self._hooks
        # in priority order; simplified here
        self._hooks.insert(0, hook)

    def register_lr_hook(self, lr_config):
        if lr_config is None:
            return
        # build a hook object from the config dict, then register it
        # (the real code also converts a `policy` field into the hook type)
        hook = mmcv.build_from_cfg(lr_config, HOOKS)
        self.register_hook(hook, priority='VERY_HIGH')

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                timer_config=dict(type='IterTimerHook'),
                                custom_hooks_config=None):
        # the sibling register_*_hook helpers are omitted here
        self.register_lr_hook(lr_config)
        self.register_momentum_hook(momentum_config)
        self.register_optimizer_hook(optimizer_config)
        self.register_checkpoint_hook(checkpoint_config)
        self.register_timer_hook(timer_config)
        self.register_logger_hooks(log_config)
        self.register_custom_hooks(custom_hooks_config)
1) Initialization: the constructor receives the model, batch processor, optimizer, work directory, meta information (e.g. the random seed), and the maximum numbers of epochs and iterations. Note in particular that it initializes a self._hooks list, whose elements are objects instantiated from Hook classes.
2) @abstractmethod: this decorator marks four abstract methods: train, val, run, and save_checkpoint. Every subclass of BaseRunner must implement all four.
3) Hook registration: I will cover hooks in a dedicated post; for now, knowing the general flow is enough. Take the lr hook as an example: the lr_config argument is passed to register_training_hooks, which internally calls register_lr_hook to instantiate the corresponding hook object from lr_config, which in turn calls register_hook to add the lr hook to the self._hooks list. The registered hooks are fired later through call_hook, sketched below.
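For completeness, call_hook (not shown in the excerpt above) is what triggers the registered hooks at each stage; this short sketch matches how mmcv's BaseRunner actually dispatches them:
python
def call_hook(self, fn_name):
    """Call the method named fn_name (e.g. 'before_train_iter')
    on every registered hook, passing in the runner itself."""
    for hook in self._hooks:
        getattr(hook, fn_name)(self)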
Next, let's look at the two Runner subclasses.
2. EpochBasedRunner
Without further ado, here is the core code:
python
import time

import torch

from mmcv.runner import RUNNERS  # registry of runner classes


@RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.
    This runner trains models epoch by epoch.
    """

    # run a single iteration of training or inference
    def run_iter(self, data_batch, train_mode, **kwargs):
        # dispatch to the model's train_step or val_step
        if train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        self.outputs = outputs

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)  # run one iteration
            self.call_hook('after_train_iter')
            self._iter += 1
        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)  # run one iteration
            self.call_hook('after_val_iter')
        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        self.call_hook('before_run')
        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow  # e.g. ('train', 1) or ('val', 1)
                epoch_runner = getattr(self, mode)  # pick self.train or self.val by mode
                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)  # run train OR val
        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        pass
EpochBasedRunner inherits from BaseRunner, so it implements all four methods. save_checkpoint needs no further comment; the core is the run method, whose logic is exactly the diagram from the beginning: depending on whether the mode field of the workflow is 'train' or 'val', it dispatches to the train or the val method, and both of those call run_iter internally to perform the forward computation of a single iteration. This runner trains the model by epochs and is the most commonly used runner in mmdet.
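To make this concrete, here is a minimal usage sketch; model, optimizer, train_loader, and val_loader are assumed to already exist (in practice, mmdet's train_detector builds all of this from the config):
python
import logging

# a minimal sketch, assuming `model`, `optimizer`, `train_loader`
# and `val_loader` are already constructed
runner = EpochBasedRunner(
    model,
    optimizer=optimizer,
    work_dir='./work_dir',
    logger=logging.getLogger(),
    max_epochs=12)

# register the lr / optimizer / checkpoint / logging hooks from config dicts
runner.register_training_hooks(
    lr_config=dict(policy='step', step=[8, 11]),
    optimizer_config=dict(grad_clip=None),
    checkpoint_config=dict(interval=1),
    log_config=dict(interval=50, hooks=[dict(type='TextLoggerHook')]))

# workflow: one training epoch, then one validation epoch,
# repeated until max_epochs is reached
runner.run([train_loader, val_loader], workflow=[('train', 1), ('val', 1)])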
3. IterBasedRunner
Again, here is the core code (note the helper class IterLoader at the top):
python
import time

import torch

from mmcv.runner import RUNNERS  # registry of runner classes


class IterLoader:
    """Wraps a DataLoader so that next() never raises StopIteration:
    when one epoch is exhausted, it rebuilds the iterator and continues."""

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 0

    @property
    def epoch(self):
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)  # normally, just fetch the next batch
        except StopIteration:
            # the current epoch is exhausted: bump the epoch counter,
            # rebuild the iterator and fetch a batch from the new epoch
            self._epoch += 1
            if hasattr(self._dataloader.sampler, 'set_epoch'):
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # Prevent possible deadlock during epoch transition
            self.iter_loader = iter(self._dataloader)  # rebuild the iterator
            data = next(self.iter_loader)  # fetch the next batch
        return data

    def __len__(self):
        return len(self._dataloader)


@RUNNERS.register_module()
class IterBasedRunner(BaseRunner):
    """Iteration-based Runner.
    This runner trains models iteration by iteration.
    """

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        data_batch = next(data_loader)  # fetch one batch from the IterLoader
        self.call_hook('before_train_iter')
        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)  # forward pass
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._inner_iter += 1
        self._iter += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        data_batch = next(data_loader)  # fetch one batch from the IterLoader
        self.call_hook('before_val_iter')
        outputs = self.model.val_step(data_batch, **kwargs)  # forward pass
        self.outputs = outputs
        self.call_hook('after_val_iter')
        self._inner_iter += 1

    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
        self.call_hook('before_run')
        iter_loaders = [IterLoader(x) for x in data_loaders]  # wrap each loader
        self.call_hook('before_epoch')
        while self.iter < self._max_iters:  # keep going until max_iters is reached
            for i, flow in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = flow
                iter_runner = getattr(self, mode)  # pick self.train or self.val by mode
                for _ in range(iters):
                    if mode == 'train' and self.iter >= self._max_iters:  # break past max_iters
                        break
                    iter_runner(iter_loaders[i], **kwargs)  # run one iteration
        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_epoch')
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='iter_{}.pth',
                        meta=None,
                        save_optimizer=True,
                        create_symlink=True):
        pass
IterBasedRunner is much the same and likewise implements the four methods. The main difference from EpochBasedRunner is that it has no run_iter method: since this runner trains up to a maximum number of iterations, the single-iteration computation is written directly inside train and val. It also introduces an extra IterLoader class, whose job is to restart iteration over the data once an epoch has been exhausted; as the comments show, the try-except around next() is what implements the restart.
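A tiny, purely illustrative demo of that wrap-around behavior (the five-element toy dataset is made up for this example):
python
from torch.utils.data import DataLoader

# toy data: a 5-sample dataset with batch size 2 -> 3 batches per epoch
loader = IterLoader(DataLoader(list(range(5)), batch_size=2))

# pulling 7 batches crosses the epoch boundary twice; IterLoader
# silently rebuilds its iterator instead of raising StopIteration
for step in range(7):
    batch = next(loader)
    print(step, loader.epoch, batch)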
Summary
This post introduced the runners in mmcv; essentially all mmdet models are trained with one of the two runners above.