目录
[🌟 🌟EpochBased 配置举例](#🌟 🌟EpochBased 配置举例)
[🌸🌸按照 iter 训练模型](#🌸🌸按照 iter 训练模型)
[🌸🌸按照 epoch 训练模型](#🌸🌸按照 epoch 训练模型)
🎵🎵背景
- 上一篇讲到如何基于mmengine设置train_cfg参数,感兴趣的童鞋可以移步。
- 上一篇主要讲解通过RUNNER.ITERBASEDTRAINLOOP与RUNNER.EPOCHBASEDTRAINLOOP 源码解析说明了基于迭代次数和轮数的config文件设置。本篇主要讲解除train_cfg之外,由EPOCHBASED 切换至 ITERBASED的其他参数的完整设置。
👉👉动机
基于MMEngine做模型训练,设置各种hook时,总是看不到源码,只能按照既定模式进行网络训练,要修改就得自己试参数,索性咱们就一次深挖到底,看看最底层的代码是如何写的,就不用每次猜参数了。
MMEngine 支持两种训练模式:
- 基于轮次的 EpochBased 方式
- 基于迭代次数的 IterBased 方式
这两种方式在下游算法库均有使用,例如MMDetection 默认使用 EpochBased 方式,MMSegmentation默认使用 IterBased 方式。本篇主要讲解除train_cfg参数之外,由EPOCHBASED 切换至 ITERBASED的其他参数的完整设置。
🌟 🌟EpochBased 配置举例
MMEngine 很多模块默认以 EpochBased 的模式执行,如 ParamScheduler
, LoggerHook
, CheckpointHook
等,常见的 EpochBased 配置写法如下:
python
param_scheduler = dict(
type='MultiStepLR',
milestones=[6, 8]
by_epoch=True # by_epoch 默认为 True,这边显式的写出来只是为了方便对比
)
default_hooks = dict(
logger=dict(type='LoggerHook'),
checkpoint=dict(type='CheckpointHook', interval=2),
)
train_cfg = dict(
by_epoch=True, # by_epoch 默认为 True,这边显式的写出来只是为了方便对比
max_epochs=10,
val_interval=2
)
log_processor = dict(
by_epoch=True
) # log_processor 的 by_epoch 默认为 True,这边显式的写出来只是为了方便对比, 实际上不需要设置
runner = Runner(
model=ResNet18(),
work_dir='./work_dir',
train_dataloader=train_dataloader_cfg,
optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
param_scheduler=param_scheduler
default_hooks=default_hooks,
log_processor=log_processor,
train_cfg=train_cfg,
resume=True,
)
🌸🌸按照 iter 训练模型
如果想按照 iter 训练模型,需要做以下改动:
🌷train_cfg
- 将
train_cfg
中的by_epoch
设置为False
,同时将max_iters
设置为训练的总 iter 数,val_iterval
设置为验证间隔的 iter 数。示例代码如下:
python
train_cfg = dict(
by_epoch=False,
max_iters=10000,
val_interval=2000
)
🌷default_hooks
- 将
default_hooks
中的logger
的log_metric_by_epoch
设置为 False,checkpoint
的by_epoch
设置为False
。示例代码如下:
python
default_hooks = dict(
logger=dict(type='LoggerHook', log_metric_by_epoch=False),
checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
)
🌷param_scheduler
- 将
param_scheduler
中的by_epoch
设置为False
,并將epoch
相关的参数换算成iter。
示例代码如下:
python
param_scheduler = dict(
type='MultiStepLR',
milestones=[6000, 8000],
by_epoch=False,
)
🌷log_processor
- 将
log_processor
的by_epoch
设置为False
。示例代码如下:
python
log_processor = dict(
by_epoch=False
)
📢📢注意:如果你能保证IterBasedTraining 和 EpochBasedTraining 总 iter 数 一致,直接设置 convert_to_iter_based 为 True 即可。
🌸🌸按照 epoch 训练模型
🌷train_cfg
如果想按照 epoch 训练模型,需要做以下改动:
- 将
train_cfg
中的by_epoch
设置为True
,同时将max_epochs
设置为训练的总epoch数,val_iterval
设置为验证间隔的 epoch数。示例代码如下:
python
train_cfg = dict(
by_epoch=True,
max_epochs=10,
val_interval=2
)
🌷default_hooks
- 将
default_hooks
中的logger
的log_metric_by_epoch
设置为True
,checkpoint
的by_epoch
设置为True
。示例代码如下:
python
default_hooks = dict(
logger=dict(type='LoggerHook', log_metric_by_epoch=True),
checkpoint=dict(type='CheckpointHook', interval=2, by_epoch=True),
)
🌷param_scheduler
- 将
param_scheduler
中的by_epoch
设置为True
,并將iter
相关的参数换算成epoch
。
示例代码如下:
python
param_scheduler = dict(
type='MultiStepLR',
milestones=[6, 8],
by_epoch=True,
)
🌷log_processor
- 将
log_processor
的by_epoch
设置为True
。示例代码如下:
python
log_processor = dict(
by_epoch=True
)
📢📢注意:如果你能保证IterBasedTraining 和 EpochBasedTraining 总 iter 数 一致,直接设置 convert_to_iter_based 为 True 即可。示例代码如下:
🐸🐸注意事项
如果基础配置文件为 train_dataloader 配置了基于iteration/epoch 采样的 sampler,则需要在当前配置文件中将其更改为指定类型的 sampler,或将其设置为 None 。当 dataloader 中的 sampler 为 None,MMEngine 或根据 train_cfg 中的 by_epoch参数选择 InfiniteSampler
(False) 或 DefaultSampler
(True)。
如果基础配置文件在 ++train_cfg 中指定了 type++,那么必须在当前配置文件中将 type 覆盖为(IterBasedTrainLoop 或 EpochBasedTrainLoop ),而++不能简单的指定 by_epoch++ 参数。
✌️完整mmengine源码:链接
整理不易,欢迎一键三连!!!
送你们一条美丽的--分割线--
🌷🌷🍀🍀🌾🌾🍓🍓🍂🍂🙋🙋🐸🐸🙋🙋💖💖🍌🍌🔔🔔🍉🍉🍭🍭🍋🍋🍇🍇🏆🏆📸📸⛵⛵⭐⭐🍎🍎👍👍🌷🌷