1. run
```python
runner.run(data_loaders, cfg.workflow)
```
The workflow only actually starts once the run method is called:

- workflow = [('train', 1)]: run only the training workflow.
- workflow = [('train', 2), ('val', 1)]: train for 2 epochs, then switch to the val workflow for 1 epoch, and repeat this cycle until the number of training epochs reaches the configured maximum.
- workflow = [('val', 1), ('train', 1)]: validate for 1 epoch first, then train for 1 epoch.
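As a hedged illustration of the second setting (the loader names here are hypothetical), note that the workflow tuples and the data_loaders list are paired by index, as the assertion in run below enforces:

```python
# Hypothetical example: data_loaders must be ordered like the workflow tuples.
workflow = [('train', 2), ('val', 1)]
runner.run([train_loader, val_loader], workflow)
# With max_epochs = 6 the resulting epoch sequence is:
#   train, train, val, train, train, val, train, train, val
# (only 'train' epochs count toward max_epochs)
```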
run itself only defines the generic workflow-switching logic; an actual epoch is executed by a workflow method. Two workflows, train and val, are currently supported, so epoch_runner(data_loaders[i], **kwargs) resolves (via getattr) to either the train or the val method:
```python
def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)
    if max_epochs is not None:
        warnings.warn(
            'setting max_epochs in run is deprecated, '
            'please set max_epochs in runner_config', DeprecationWarning)
        self._max_epochs = max_epochs

    assert self._max_epochs is not None, (
        'max_epochs must be specified during instantiation')

    for i, flow in enumerate(workflow):
        mode, epochs = flow
        if mode == 'train':
            self._max_iters = self._max_epochs * len(data_loaders[i])
            break

    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('Hooks will be executed in the following order:\n%s',
                     self.get_hook_info())
    self.logger.info('workflow: %s, max: %d epochs', workflow,
                     self._max_epochs)
    self.call_hook('before_run')

    while self.epoch < self._max_epochs:
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if isinstance(mode, str):  # self.train()
                if not hasattr(self, mode):
                    raise ValueError(
                        f'runner has no method named "{mode}" to run an '
                        'epoch')
                epoch_runner = getattr(self, mode)
            else:
                raise TypeError(
                    'mode in workflow must be a str, but got {}'.format(
                        type(mode)))

            for _ in range(epochs):
                if mode == 'train' and self.epoch >= self._max_epochs:
                    break
                epoch_runner(data_loaders[i], **kwargs)

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run')
```
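One detail worth pausing on: before the main loop, run precomputes _max_iters from the first ('train', ...) entry. A quick check with hypothetical numbers:

```python
# Hypothetical numbers, for illustration only.
max_epochs = 12          # set in runner_config at construction time
batches_per_epoch = 500  # == len(data_loaders[i]) for the train loader
max_iters = max_epochs * batches_per_epoch
assert max_iters == 6000  # exposed to hooks as runner.max_iters
```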
2. train
train iterates over data_loader and runs batch-level training. What actually trains a batch is the call self.run_iter(data_batch, train_mode=True, **kwargs); along the way, call_hook fires at four different hook points (a sketch of a custom hook using these points follows the code below):

- self.call_hook('before_train_epoch')
- self.call_hook('before_train_iter')
- self.call_hook('after_train_iter')
- self.call_hook('after_train_epoch')
```python
def train(self, data_loader, **kwargs):
    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True, **kwargs)
        self.call_hook('after_train_iter')
        self._iter += 1

    self.call_hook('after_train_epoch')
    self._epoch += 1
```
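As promised above, here is a minimal sketch of a custom hook that plugs into these call points. HOOKS, Hook, and the runner attributes used below are real mmcv APIs; the hook class itself is hypothetical:

```python
from mmcv.runner import HOOKS, Hook


@HOOKS.register_module()
class EpochLogHook(Hook):  # hypothetical hook, for illustration only
    """Print progress at the points where train()/val() call call_hook."""

    def before_train_epoch(self, runner):
        print(f'train epoch {runner.epoch + 1}/{runner.max_epochs}')

    def after_train_iter(self, runner):
        # runner.outputs was just set by run_iter (see section 4)
        if runner.inner_iter % 50 == 0:
            print(runner.outputs['log_vars'])

    def after_val_epoch(self, runner):
        print(f'finished val epoch at global iter {runner.iter}')
```

Such a hook is attached with runner.register_hook(EpochLogHook()) (or through the custom_hooks config entry) and then fires automatically at the points listed above.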
3. val
val iterates over data_loader and runs batch-level validation. What actually evaluates a batch is the call self.run_iter(data_batch, train_mode=False); again, call_hook fires at four hook points:

- self.call_hook('before_val_epoch')
- self.call_hook('before_val_iter')
- self.call_hook('after_val_iter')
- self.call_hook('after_val_epoch')

Note that, unlike train, val advances neither self._epoch nor self._iter, so a workflow must contain at least one ('train', n) entry for the while loop in run to terminate.
```python
@torch.no_grad()
def val(self, data_loader, **kwargs):
    self.model.eval()
    self.mode = 'val'
    self.data_loader = data_loader
    self.call_hook('before_val_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition
    for i, data_batch in enumerate(self.data_loader):
        self._inner_iter = i
        self.call_hook('before_val_iter')
        self.run_iter(data_batch, train_mode=False)
        self.call_hook('after_val_iter')

    self.call_hook('after_val_epoch')
```
4. run_iter
As the train and val methods above show, what actually trains or evaluates a batch is the call to self.run_iter(), which in turn invokes the model's own train_step or val_step method. Here self.model is an MMDataParallel instance, so model.train_step() first reaches MMDataParallel.train_step(); the MMDataParallel holds the actual detector, a FasterRCNN, in its module attribute (a simplified sketch of this delegation follows the code below).
```python
def run_iter(self, data_batch, train_mode, **kwargs):
    if self.batch_processor is not None:
        outputs = self.batch_processor(
            self.model, data_batch, train_mode=train_mode, **kwargs)
    elif train_mode:
        outputs = self.model.train_step(data_batch, self.optimizer,
                                        **kwargs)
    else:
        outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
    if not isinstance(outputs, dict):
        raise TypeError('"batch_processor()" or "model.train_step()"'
                        'and "model.val_step()" must return a dict')
    if 'log_vars' in outputs:
        self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
    self.outputs = outputs
```
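As mentioned above, self.model here is an MMDataParallel, so model.train_step first lands in the wrapper. Below is a simplified sketch of mmcv's MMDataParallel.train_step (single-GPU path, device checks elided; see mmcv/parallel/data_parallel.py for the full version):

```python
from torch.nn.parallel import DataParallel


class MMDataParallel(DataParallel):

    def train_step(self, *inputs, **kwargs):
        # scatter moves the DataContainer-wrapped batch onto the target GPU
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        # delegate to the wrapped detector, e.g. FasterRCNN.train_step
        return self.module.train_step(*inputs[0], **kwargs[0])
```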
The code below shows the inheritance relationships between the classes involved. run_iter() calls train_step, and train_step() calls self(**data); since self is a FasterRCNN instance, this goes through __call__. Searching the parent classes, only torch.nn.Module defines _call_impl, to which __call__ is bound, and which ultimately dispatches to self.forward:
```python
class FasterRCNN(TwoStageDetector):
    ...

class TwoStageDetector(BaseDetector):
    ...

class BaseDetector(BaseModule, metaclass=ABCMeta):

    def train_step(self, data, optimizer):
        losses = self(**data)
        loss, log_vars = self._parse_losses(losses)
        outputs = dict(
            loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
        return outputs

class BaseModule(nn.Module, metaclass=ABCMeta):
    ...

class Module:

    def _call_impl(self, *input, **kwargs):
        for hook in itertools.chain(
                _global_forward_pre_hooks.values(),
                self._forward_pre_hooks.values()):
            result = hook(self, input)
            if result is not None:
                if not isinstance(result, tuple):
                    result = (result,)
                input = result
        if torch._C._get_tracing_state():
            result = self._slow_forward(*input, **kwargs)
        else:
            result = self.forward(*input, **kwargs)
        for hook in itertools.chain(
                _global_forward_hooks.values(),
                self._forward_hooks.values()):
            hook_result = hook(self, input, result)
            if hook_result is not None:
                result = hook_result
        if (len(self._backward_hooks) > 0) or (len(_global_backward_hooks) > 0):
            var = result
            while not isinstance(var, torch.Tensor):
                if isinstance(var, dict):
                    var = next((v for v in var.values()
                                if isinstance(v, torch.Tensor)))
                else:
                    var = var[0]
            grad_fn = var.grad_fn
            if grad_fn is not None:
                for hook in itertools.chain(
                        _global_backward_hooks.values(),
                        self._backward_hooks.values()):
                    wrapper = functools.partial(hook, self)
                    functools.update_wrapper(wrapper, hook)
                    grad_fn.register_hook(wrapper)
        return result

    __call__: Callable[..., Any] = _call_impl
```
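train_step above also calls self._parse_losses(losses). Roughly, it averages each entry of the loss dict, sums every key containing 'loss' into the total loss, and returns scalar log_vars for the log buffer. A simplified, standalone rendering of mmdet's BaseDetector._parse_losses (the distributed reduction is omitted here):

```python
from collections import OrderedDict

import torch


def parse_losses(losses):
    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(f'{loss_name} is not a tensor or list of tensors')
    # the total loss is the sum of every entry whose key contains 'loss'
    loss = sum(_value for _key, _value in log_vars.items() if 'loss' in _key)
    log_vars['loss'] = loss
    return loss, log_vars
```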
5. save_checkpoint
Saves the model weights, together with the optimizer state and meta information, to disk:
```python
def save_checkpoint(self,
                    out_dir,
                    filename_tmpl='epoch_{}.pth',
                    save_optimizer=True,
                    meta=None,
                    create_symlink=True):
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError(
            f'meta should be a dict or None, but got {type(meta)}')
    if self.meta is not None:
        meta.update(self.meta)
        # Note: meta.update(self.meta) should be done before
        # meta.update(epoch=self.epoch + 1, iter=self.iter) otherwise
        # there will be problems with resumed checkpoints.
        # More details in https://github.com/open-mmlab/mmcv/pull/1108
    meta.update(epoch=self.epoch + 1, iter=self.iter)

    filename = filename_tmpl.format(self.epoch + 1)
    filepath = osp.join(out_dir, filename)
    optimizer = self.optimizer if save_optimizer else None
    save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
    # in some environments, `os.symlink` is not supported, you may need to
    # set `create_symlink` to False
    if create_symlink:
        dst_file = osp.join(out_dir, 'latest.pth')
        if platform.system() != 'Windows':
            mmcv.symlink(filename, dst_file)
        else:
            shutil.copy(filepath, dst_file)
```
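A hedged usage sketch (in practice, mmcv's CheckpointHook calls this method automatically at the end of each epoch):

```python
# Assumes `runner` has trained some epochs and ./work_dir exists.
runner.save_checkpoint('./work_dir')
# -> writes ./work_dir/epoch_{epoch+1}.pth, e.g. epoch_12.pth,
#    and points ./work_dir/latest.pth at it (symlink; a copy on Windows)
```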