MMLab中自定义模块初始化方法

这方面好像介绍的少,看了下基本原理,下面介绍下如何在搭载权重的模型中加入自定义模块时方便的进行初始化。

MMLab的逻辑时对每个部分进行初始化,若此部分定义了初始化方法为Pretrained则加载权重,然后对内部其他模块不再进行其他的初始化操作。

但其存在着一个问题。例如,需要对backbone改进,加入自定义模块后同时需要原模型的预训练权重,此时无法方便的对新加入模块进行初始化操作(因为代码决定其会跳过了此部分初始化,直接进行下一部分的初始化操作)。源代码如下(在BaseModule)中:

复制代码
        if not self._is_init:
            if self.init_cfg:
                print_log(
                    f'initialize {module_name} with init_cfg {self.init_cfg}',
                    logger=logger_name)
                initialize(self, self.init_cfg)
                if isinstance(self.init_cfg, dict):
                    # prevent the parameters of
                    # the pre-trained model
                    # from being overwritten by
                    # the `init_weights`
                    if self.init_cfg['type'] == 'Pretrained':
                        return

            for m in self.children():
                if hasattr(m, 'init_weights'):
                    m.init_weights()
                    # users may overload the `init_weights`
                    update_init_info(
                        m,
                        init_info=f'Initialized by '
                        f'user-defined `init_weights`'
                        f' in {m.__class__.__name__} ')

            self._is_init = True

那么如何对自定义模块方便地进行初始化呢,下面介绍三种方法:

(1)定义一个my_weight_init()对自定义模块中的所有module进行初始化操作,其优点是可操作性强,但设置复杂。代码如下:

复制代码
def my_module_weights_init(target_module):
    for m in target_module.modules():
        if type(m) == nn.Conv2d:
            nn.init.xavier_normal_(m.weight.data)
            nn.init.constant_(m.bias.data, 0.0)

对自定义的模块的初始化直接调用apply即可。

(2)对于MMLab中定义好的模块,若其存在init_cfg则可直接输入相关设置参数进行初始化操作。

(3)最为方便的方法,在mmcv.cnn.utils.weight_init中存在initialize函数,可通过相关参数对函数内部所有相关层进行初始化操作,主要原理是建立初始化器的实例化对象,对模块参数进行处理。mmcv中目前可调用一下八种方法进行初始化,位于mmcv.cnn.utilsz中。

复制代码
'ConstantInit', 'XavierInit', 'NormalInit', 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit', 'Caffe2XavierInit'

上述initialize方法的相关代码如下:

复制代码
def my_module_weights_init(target_module, init_cfg):
    from mmcv.cnn.utils.weight_init import initialize
    initialize(target_module, init_cfg)

调用初始化方法的代码(可直接调用initialize方法,我为了方便好看改了个名):

复制代码
if self.training:
    my_module_init_cfg = [dict(type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.), dict(type='Constant', layer=['LayerNorm'], val=1., bias=0.),]
    my_module_weights_init(self.gt_seg_downsample_layers, my_module_init_cfg)

欢迎补充其他方便的方法。

相关推荐
IT考试认证9 分钟前
华为人工智能认证 HCIA-AI Solution H13-313 题库
人工智能·华为·题库·hcia-ai·h13-313
AI technophile25 分钟前
OpenCV计算机视觉实战(31)——人脸识别详解
人工智能·opencv·计算机视觉
S***H28327 分钟前
JavaScript原型链继承
开发语言·javascript·原型模式
kk”27 分钟前
C++ map
开发语言·c++
九河云28 分钟前
汽车轻量化部件智造:碳纤维成型 AI 调控与强度性能数字孪生验证实践
人工智能·汽车·数字化转型
3DVisionary31 分钟前
DIC技术如何重新定义汽车板料成形测试
人工智能·汽车·材料力学性能·dic技术·汽车板料·成形极限图·非接触式测量
5***o50032 分钟前
深度学习代码库
人工智能·深度学习
车端域控测试工程师32 分钟前
Autosar网络管理测试用例 - TC003
c语言·开发语言·学习·汽车·测试用例·capl·canoe
2501_9416649633 分钟前
AI在创意产业的应用:从艺术到娱乐的数字变革
人工智能
二川bro34 分钟前
Python模型优化实战:深度学习加速与压缩技巧
python