MMLab中自定义模块初始化方法

这方面好像介绍的少,看了下基本原理,下面介绍下如何在搭载权重的模型中加入自定义模块时方便的进行初始化。

MMLab的逻辑时对每个部分进行初始化,若此部分定义了初始化方法为Pretrained则加载权重,然后对内部其他模块不再进行其他的初始化操作。

但其存在着一个问题。例如,需要对backbone改进,加入自定义模块后同时需要原模型的预训练权重,此时无法方便的对新加入模块进行初始化操作(因为代码决定其会跳过了此部分初始化,直接进行下一部分的初始化操作)。源代码如下(在BaseModule)中:

        if not self._is_init:
            if self.init_cfg:
                print_log(
                    f'initialize {module_name} with init_cfg {self.init_cfg}',
                    logger=logger_name)
                initialize(self, self.init_cfg)
                if isinstance(self.init_cfg, dict):
                    # prevent the parameters of
                    # the pre-trained model
                    # from being overwritten by
                    # the `init_weights`
                    if self.init_cfg['type'] == 'Pretrained':
                        return

            for m in self.children():
                if hasattr(m, 'init_weights'):
                    m.init_weights()
                    # users may overload the `init_weights`
                    update_init_info(
                        m,
                        init_info=f'Initialized by '
                        f'user-defined `init_weights`'
                        f' in {m.__class__.__name__} ')

            self._is_init = True

那么如何对自定义模块方便地进行初始化呢,下面介绍三种方法:

(1)定义一个my_weight_init()对自定义模块中的所有module进行初始化操作,其优点是可操作性强,但设置复杂。代码如下:

def my_module_weights_init(target_module):
    for m in target_module.modules():
        if type(m) == nn.Conv2d:
            nn.init.xavier_normal_(m.weight.data)
            nn.init.constant_(m.bias.data, 0.0)

对自定义的模块的初始化直接调用apply即可。

(2)对于MMLab中定义好的模块,若其存在init_cfg则可直接输入相关设置参数进行初始化操作。

(3)最为方便的方法,在mmcv.cnn.utils.weight_init中存在initialize函数,可通过相关参数对函数内部所有相关层进行初始化操作,主要原理是建立初始化器的实例化对象,对模块参数进行处理。mmcv中目前可调用一下八种方法进行初始化,位于mmcv.cnn.utilsz中。

'ConstantInit', 'XavierInit', 'NormalInit', 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit', 'Caffe2XavierInit'

上述initialize方法的相关代码如下:

def my_module_weights_init(target_module, init_cfg):
    from mmcv.cnn.utils.weight_init import initialize
    initialize(target_module, init_cfg)

调用初始化方法的代码(可直接调用initialize方法,我为了方便好看改了个名):

if self.training:
    my_module_init_cfg = [dict(type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.), dict(type='Constant', layer=['LayerNorm'], val=1., bias=0.),]
    my_module_weights_init(self.gt_seg_downsample_layers, my_module_init_cfg)

欢迎补充其他方便的方法。

相关推荐
陌小呆^O^几秒前
Cmakelist.txt之win-c-udp-server
c语言·开发语言·udp
lindsayshuo2 分钟前
jetson orin系列开发版安装cuda的gpu版本的opencv
人工智能·opencv
向阳逐梦2 分钟前
ROS机器视觉入门:从基础到人脸识别与目标检测
人工智能·目标检测·计算机视觉
Gu Gu Study7 分钟前
枚举与lambda表达式,枚举实现单例模式为什么是安全的,lambda表达式与函数式接口的小九九~
java·开发语言
Mr.Q10 分钟前
OpenCV和Qt坐标系不一致问题
qt·opencv
时光の尘22 分钟前
C语言菜鸟入门·关键字·float以及double的用法
运维·服务器·c语言·开发语言·stm32·单片机·c
陈鋆27 分钟前
智慧城市初探与解决方案
人工智能·智慧城市
qdprobot28 分钟前
ESP32桌面天气摆件加文心一言AI大模型对话Mixly图形化编程STEAM创客教育
网络·人工智能·百度·文心一言·arduino
QQ395753323728 分钟前
金融量化交易模型的突破与前景分析
人工智能·金融
QQ395753323729 分钟前
金融量化交易:技术突破与模型优化
人工智能·金融