MMLab中自定义模块初始化方法

这方面好像介绍的少,看了下基本原理,下面介绍下如何在搭载权重的模型中加入自定义模块时方便的进行初始化。

MMLab的逻辑时对每个部分进行初始化,若此部分定义了初始化方法为Pretrained则加载权重,然后对内部其他模块不再进行其他的初始化操作。

但其存在着一个问题。例如,需要对backbone改进,加入自定义模块后同时需要原模型的预训练权重,此时无法方便的对新加入模块进行初始化操作(因为代码决定其会跳过了此部分初始化,直接进行下一部分的初始化操作)。源代码如下(在BaseModule)中:

复制代码
        if not self._is_init:
            if self.init_cfg:
                print_log(
                    f'initialize {module_name} with init_cfg {self.init_cfg}',
                    logger=logger_name)
                initialize(self, self.init_cfg)
                if isinstance(self.init_cfg, dict):
                    # prevent the parameters of
                    # the pre-trained model
                    # from being overwritten by
                    # the `init_weights`
                    if self.init_cfg['type'] == 'Pretrained':
                        return

            for m in self.children():
                if hasattr(m, 'init_weights'):
                    m.init_weights()
                    # users may overload the `init_weights`
                    update_init_info(
                        m,
                        init_info=f'Initialized by '
                        f'user-defined `init_weights`'
                        f' in {m.__class__.__name__} ')

            self._is_init = True

那么如何对自定义模块方便地进行初始化呢,下面介绍三种方法:

(1)定义一个my_weight_init()对自定义模块中的所有module进行初始化操作,其优点是可操作性强,但设置复杂。代码如下:

复制代码
def my_module_weights_init(target_module):
    for m in target_module.modules():
        if type(m) == nn.Conv2d:
            nn.init.xavier_normal_(m.weight.data)
            nn.init.constant_(m.bias.data, 0.0)

对自定义的模块的初始化直接调用apply即可。

(2)对于MMLab中定义好的模块,若其存在init_cfg则可直接输入相关设置参数进行初始化操作。

(3)最为方便的方法,在mmcv.cnn.utils.weight_init中存在initialize函数,可通过相关参数对函数内部所有相关层进行初始化操作,主要原理是建立初始化器的实例化对象,对模块参数进行处理。mmcv中目前可调用一下八种方法进行初始化,位于mmcv.cnn.utilsz中。

复制代码
'ConstantInit', 'XavierInit', 'NormalInit', 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit', 'Caffe2XavierInit'

上述initialize方法的相关代码如下:

复制代码
def my_module_weights_init(target_module, init_cfg):
    from mmcv.cnn.utils.weight_init import initialize
    initialize(target_module, init_cfg)

调用初始化方法的代码(可直接调用initialize方法,我为了方便好看改了个名):

复制代码
if self.training:
    my_module_init_cfg = [dict(type='TruncNormal', layer=['Conv2d', 'Linear'], std=.02, bias=0.), dict(type='Constant', layer=['LayerNorm'], val=1., bias=0.),]
    my_module_weights_init(self.gt_seg_downsample_layers, my_module_init_cfg)

欢迎补充其他方便的方法。

相关推荐
凡人叶枫2 分钟前
C++中智能指针详解(Linux实战版)| 彻底解决内存泄漏,新手也能吃透
java·linux·c语言·开发语言·c++·嵌入式开发
Tony Bai2 分钟前
再见,丑陋的 container/heap!Go 泛型堆 heap/v2 提案解析
开发语言·后端·golang
SEO_juper11 分钟前
2026内容营销破局指南:告别流量内卷,以价值赢信任
人工智能·ai·数字营销·2026
初恋叫萱萱14 分钟前
数据即燃料:用 `cann-data-augmentation` 实现高效训练预处理
人工智能
小糯米60123 分钟前
C++顺序表和vector
开发语言·c++·算法
一战成名99623 分钟前
CANN 仓库揭秘:昇腾 AI 算子开发的宝藏之地
人工智能
froginwe1129 分钟前
JavaScript 函数调用
开发语言
hnult29 分钟前
2026 在线培训考试系统选型指南:核心功能拆解与选型逻辑
人工智能·笔记·课程设计
A小码哥30 分钟前
AI 设计时代的到来:从 PS 到 Pencil,一个人如何顶替一个团队
人工智能
阔皮大师33 分钟前
INote轻量文本编辑器
java·javascript·python·c#