MMLab中自定义模块初始化方法

这方面好像介绍的少,看了下基本原理,下面介绍下如何在搭载权重的模型中加入自定义模块时方便的进行初始化。

MMLab的逻辑时对每个部分进行初始化,若此部分定义了初始化方法为Pretrained则加载权重,然后对内部其他模块不再进行其他的初始化操作。

但其存在着一个问题。例如,需要对backbone改进,加入自定义模块后同时需要原模型的预训练权重,此时无法方便的对新加入模块进行初始化操作(因为代码决定其会跳过了此部分初始化,直接进行下一部分的初始化操作)。源代码如下(在BaseModule)中:

        if not self._is_init:
            if self.init_cfg:
                print_log(
                    f'initialize {module_name} with init_cfg {self.init_cfg}',
                    logger=logger_name)
                initialize(self, self.init_cfg)
                if isinstance(self.init_cfg, dict):
                    # prevent the parameters of
                    # the pre-trained model
                    # from being overwritten by
                    # the `init_weights`
                    if self.init_cfg['type'] == 'Pretrained':
                        return

            for m in self.children():
                if hasattr(m, 'init_weights'):
                    m.init_weights()
                    # users may overload the `init_weights`
                    update_init_info(
                        m,
                        init_info=f'Initialized by '
                        f'user-defined `init_weights`'
                        f' in {m.__class__.__name__} ')

            self._is_init = True

那么如何对自定义模块方便地进行初始化呢,下面介绍三种方法:

(1)定义一个my_weight_init()对自定义模块中的所有module进行初始化操作,其优点是可操作性强,但设置复杂。代码如下:

def my_module_weights_init(target_module):
    for m in target_module.modules():
        if type(m) == nn.Conv2d:
            nn.init.xavier_normal_(m.weight.data)
            nn.init.constant_(m.bias.data, 0.0)

对自定义的模块的初始化直接调用apply即可。

(2)对于MMLab中定义好的模块,若其存在init_cfg则可直接输入相关设置参数进行初始化操作。

(3)最为方便的方法,在mmcv.cnn.utils.weight_init中存在initialize函数,可通过相关参数对函数内部所有相关层进行初始化操作,主要原理是建立初始化器的实例化对象,对模块参数进行处理。mmcv中目前可调用一下八种方法进行初始化,位于mmcv.cnn.utilsz中。

'ConstantInit', 'XavierInit', 'NormalInit', 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit', 'Caffe2XavierInit'

上述initialize方法的相关代码如下:

def my_module_weights_init(target_module, init_cfg):
    from mmcv.cnn.utils.weight_init import initialize
    initialize(target_module, init_cfg)

调用初始化方法的代码(可直接调用initialize方法,我为了方便好看改了个名):

if self.training:
    my_module_init_cfg = [dict(type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.), dict(type='Constant', layer=['LayerNorm'], val=1., bias=0.),]
    my_module_weights_init(self.gt_seg_downsample_layers, my_module_init_cfg)

欢迎补充其他方便的方法。

相关推荐
数据小爬虫@2 小时前
深入解析:使用 Python 爬虫获取苏宁商品详情
开发语言·爬虫·python
健胃消食片片片片2 小时前
Python爬虫技术:高效数据收集与深度挖掘
开发语言·爬虫·python
王老师青少年编程3 小时前
gesp(C++五级)(14)洛谷:B4071:[GESP202412 五级] 武器强化
开发语言·c++·算法·gesp·csp·信奥赛
井底哇哇4 小时前
ChatGPT是强人工智能吗?
人工智能·chatgpt
Coovally AI模型快速验证4 小时前
MMYOLO:打破单一模式限制,多模态目标检测的革命性突破!
人工智能·算法·yolo·目标检测·机器学习·计算机视觉·目标跟踪
一只小bit4 小时前
C++之初识模版
开发语言·c++
AI浩4 小时前
【面试总结】FFN(前馈神经网络)在Transformer模型中先升维再降维的原因
人工智能·深度学习·计算机视觉·transformer
王磊鑫4 小时前
C语言小项目——通讯录
c语言·开发语言
钢铁男儿4 小时前
C# 委托和事件(事件)
开发语言·c#
可为测控4 小时前
图像处理基础(4):高斯滤波器详解
人工智能·算法·计算机视觉