【MMEngine】由EpochBased 切换至 IterBased参数修改解析及源码讲解

目录

🎵🎵背景

👉👉动机

[🌟 🌟EpochBased 配置举例](#🌟 🌟EpochBased 配置举例)

[🌸🌸按照 iter 训练模型](#🌸🌸按照 iter 训练模型)

🌷train_cfg

🌷default_hooks

🌷param_scheduler

🌷log_processor

[🌸🌸按照 epoch 训练模型](#🌸🌸按照 epoch 训练模型)

🌷train_cfg

🌷default_hooks

🌷param_scheduler

🌷log_processor

🐸🐸注意事项

✌️完整mmengine源码:链接

整理不易,欢迎一键三连!!!送你们一条美丽的--分割线--


🎵🎵背景

  • 上一篇讲到如何基于mmengine设置train_cfg参数,感兴趣的童鞋可以移步。
  • 上一篇主要讲解通过RUNNER.ITERBASEDTRAINLOOP与RUNNER.EPOCHBASEDTRAINLOOP 源码解析说明了基于迭代次数和轮数的config文件设置。本篇主要讲解除train_cfg之外,由EPOCHBASED 切换至 ITERBASED的其他参数的完整设置。

👉👉动机

基于MMEngine做模型训练,设置各种hook时,总是看不到源码,只能按照既定模式进行网络训练,要修改就得自己试参数,索性咱们就一次深挖到底,看看最底层的代码是如何写的,就不用每次猜参数了。

MMEngine 支持两种训练模式:

  • 基于轮次的 EpochBased 方式
  • 基于迭代次数的 IterBased 方式

这两种方式在下游算法库均有使用,例如MMDetection 默认使用 EpochBased 方式,MMSegmentation默认使用 IterBased 方式。本篇主要讲解除train_cfg参数之外,由EPOCHBASED 切换至 ITERBASED的其他参数的完整设置。

🌟 🌟EpochBased 配置举例

MMEngine 很多模块默认以 EpochBased 的模式执行,如 ParamScheduler, LoggerHook, CheckpointHook 等,常见的 EpochBased 配置写法如下:

python 复制代码
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8]
    by_epoch=True  # by_epoch 默认为 True,这边显式的写出来只是为了方便对比
)

default_hooks = dict(
    logger=dict(type='LoggerHook'),
    checkpoint=dict(type='CheckpointHook', interval=2),
)

train_cfg = dict(
    by_epoch=True,  # by_epoch 默认为 True,这边显式的写出来只是为了方便对比
    max_epochs=10,
    val_interval=2
)

log_processor = dict(
    by_epoch=True
)  # log_processor 的 by_epoch 默认为 True,这边显式的写出来只是为了方便对比, 实际上不需要设置

runner = Runner(
    model=ResNet18(),
    work_dir='./work_dir',
    train_dataloader=train_dataloader_cfg,
    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.001, momentum=0.9)),
    param_scheduler=param_scheduler
    default_hooks=default_hooks,
    log_processor=log_processor,
    train_cfg=train_cfg,
    resume=True,
)

🌸🌸按照 iter 训练模型

如果想按照 iter 训练模型,需要做以下改动:

🌷train_cfg

  • train_cfg 中的 by_epoch 设置为 False,同时将 max_iters 设置为训练的总 iter 数,val_iterval 设置为验证间隔的 iter 数。示例代码如下:
python 复制代码
train_cfg = dict(
    by_epoch=False,
    max_iters=10000,
    val_interval=2000
)

🌷default_hooks

  • default_hooks 中的 loggerlog_metric_by_epoch 设置为 False, checkpointby_epoch 设置为 False。示例代码如下:
python 复制代码
default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=False),
    checkpoint=dict(type='CheckpointHook', by_epoch=False, interval=2000),
)

🌷param_scheduler

  • param_scheduler 中的 by_epoch 设置为 False,并將 epoch 相关的参数换算成 iter。示例代码如下:
python 复制代码
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6000, 8000],
    by_epoch=False,
)

🌷log_processor

  • log_processorby_epoch 设置为 False。示例代码如下:
python 复制代码
log_processor = dict(
    by_epoch=False
)

📢📢注意:如果你能保证IterBasedTrainingEpochBasedTraining 总 iter 数 一致,直接设置 convert_to_iter_based 为 True 即可。

🌸🌸按照 epoch 训练模型

🌷train_cfg

如果想按照 epoch 训练模型,需要做以下改动:

  • train_cfg 中的 by_epoch 设置为 True,同时将 max_epochs 设置为训练的总epoch数,val_iterval 设置为验证间隔的 epoch数。示例代码如下:
python 复制代码
train_cfg = dict(
    by_epoch=True,
    max_epochs=10,
    val_interval=2
)

🌷default_hooks

  • default_hooks 中的 loggerlog_metric_by_epoch 设置为 Truecheckpointby_epoch 设置为 True。示例代码如下:
python 复制代码
default_hooks = dict(
    logger=dict(type='LoggerHook', log_metric_by_epoch=True),
    checkpoint=dict(type='CheckpointHook', interval=2, by_epoch=True),
)

🌷param_scheduler

  • param_scheduler 中的 by_epoch 设置为 True,并將iter 相关的参数换算成 epoch 示例代码如下:
python 复制代码
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[6, 8],
    by_epoch=True,
)

🌷log_processor

  • log_processorby_epoch 设置为 True。示例代码如下:
python 复制代码
log_processor = dict(
    by_epoch=True
)

📢​​​​​​​📢注意:如果你能保证IterBasedTrainingEpochBasedTraining 总 iter 数 一致,直接设置 convert_to_iter_based 为 True 即可。示例代码如下:

🐸🐸注意事项

如果基础配置文件为 train_dataloader 配置了基于iteration/epoch 采样的 sampler,则需要在当前配置文件中将其更改为指定类型的 sampler,或将其设置为 None 。当 dataloader 中的 sampler 为 None,MMEngine 或根据 train_cfg 中的 by_epoch参数选择 InfiniteSampler(False) 或 DefaultSampler(True)

如果基础配置文件在 ++train_cfg 中指定了 type++,那么必须在当前配置文件中将 type 覆盖为(IterBasedTrainLoopEpochBasedTrainLoop ),而++不能简单的指定 by_epoch++ 参数。

✌️完整mmengine源码:链接

整理不易,欢迎一键三连!!!

送你们一条美丽的--分割线--

🌷🌷🍀🍀🌾🌾🍓🍓🍂🍂🙋🙋🐸🐸🙋🙋💖💖🍌🍌🔔🔔🍉🍉🍭🍭🍋🍋🍇🍇🏆🏆📸📸⛵⛵⭐⭐🍎🍎👍👍🌷🌷

相关推荐
童话名剑11 小时前
目标检测(吴恩达深度学习笔记)
人工智能·目标检测·滑动窗口·目标定位·yolo算法·特征点检测
木卫四科技12 小时前
【木卫四 CES 2026】观察:融合智能体与联邦数据湖的安全数据运营成为趋势
人工智能·安全·汽车
吃茄子的猫17 小时前
quecpython中&的具体含义和使用场景
开发语言·python
珠海西格电力17 小时前
零碳园区有哪些政策支持?
大数据·数据库·人工智能·物联网·能源
じ☆冷颜〃17 小时前
黎曼几何驱动的算法与系统设计:理论、实践与跨领域应用
笔记·python·深度学习·网络协议·算法·机器学习
数据大魔方17 小时前
【期货量化实战】日内动量策略:顺势而为的短线交易法(Python源码)
开发语言·数据库·python·mysql·算法·github·程序员创富
启途AI17 小时前
2026免费好用的AIPPT工具榜:智能演示文稿制作新纪元
人工智能·powerpoint·ppt
TH_117 小时前
35、AI自动化技术与职业变革探讨
运维·人工智能·自动化
APIshop17 小时前
Python 爬虫获取 item_get_web —— 淘宝商品 SKU、详情图、券后价全流程解析
前端·爬虫·python
楚来客18 小时前
AI基础概念之八:Transformer算法通俗解析
人工智能·算法·transformer