模型蒸馏学习

知识蒸馏:获取学生网络和教师网络指定蒸馏位点输出特征并计算蒸馏 loss 的过程

知乎-mmrazor-模型蒸馏

知识蒸馏算法往往分为 reponse-based基于响应、feature-based基于特征 和 relation-based基于关系三类。

也可为 data-free KD、online KD、self KD(可视为一种特殊的 online KD)和比较经典的 offline KD

  • feature-based 方法以教师模型特征提取器产生的中间层特征为学习对象
    而ChannelWiseDivergence(cwd算法)使用的是预测之前的logistic特征图(在channel维度上取最大值,即为最终的预测结果),就是feature-based的蒸馏方法

mmrazor可以使用不同架构的student和teacher模型

  • 可以使用connector对不同维度进行对齐,以计算语义分割的蒸馏loss
  • 对于 feature-base 的方法,当学生和教师网络输出特征维度不同时,往往会对学生网络对应特征进行后处理以保证蒸馏 loss 正确计算(connector实现)

mmseg的模型,在deconde_head.conv_seg后拿到的特征图为logist的特征图,其内的元素都为小数,而非预测的0/1

cwd算法是一种什么样的蒸馏算法呢是data-free的吗,还是online的呢,还是offline的呢

适用于mmseg的cwd模型蒸馏配置文件(default_hook中的Student_CheckpointHook是自定义的hook,继承自mmegin中的CheekpointHook)

cwd算法:首先使用softmax归一化方法将每个通道的feature map转换成一个分布,然后最小化两个网络对应通道的Kullback Leibler (KL)散度。通过这样做,我们的方法着重于模拟网络间通道的软分布。特别的是,KL的差异使学习能够更多地关注通道图中最突出的区域,大概对应于语义分割最有用的信号

  • 现象:
    由于model capacity gap的存在,student往往弱于teacher模型,但也并不绝对,如果model本身的gap不是很离谱,student还是有超越teacher的可能的,因为student模型一般可以学习teacher模型蒸馏位点的特征和ground truth多种知识,学习效率会更高,如果本身student没有太大的问题,还是有机会学的更好的。
python 复制代码
_base_ = [
    'mmseg::_base_/datasets/pascal_voc12.py',
    'mmseg::_base_/schedules/schedule_160k.py',
    'mmseg::_base_/default_runtime.py'
]

# 模型的optim_wrapper,学习率和学习策略将来自于继承的schedule_160k,如果不改的话
# wandb的可视化设置在mmseg的default_runtime,也继承自mmseg

# schedule_160k.py中的自动保存权重的部分
default_hooks = dict(_delete_=True,
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    # 使用了自定义的Student_CheckpointHook
    checkpoint=dict(type='Student_CheckpointHook', by_epoch=False, interval=-1, max_keep_ckpts=2, save_best=['mDice', 'mIoU']),
    # checkpoint中,interval=-1则不会保存least.pth
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

teacher_ckpt = '/root/autodl-tmp/all_workdir/mmseg_work_dir/baseline-convnext-tiny-upernet-rotate/best_mDice_iter_6800.pth'  # noqa: E501
teacher_cfg_path = 'mmseg::all_changed/baseline-convnext-tiny_upernet-rotate.py'  # noqa: E501

student_cfg_path = 'mmseg::all_changed/pspnet_r18-d8_b16-160k_voc-material-512x512.py'  # noqa: E501
model = dict(
    _scope_='mmrazor',
    type='SingleTeacherDistill',
    architecture=dict(cfg_path=student_cfg_path, pretrained=False),
    teacher=dict(cfg_path=teacher_cfg_path, pretrained=False),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        distill_losses=dict(
            loss_cwd=dict(type='ChannelWiseDivergence', tau=1, loss_weight=5)),
        student_recorders=dict(
            logits=dict(type='ModuleOutputs', source='decode_head.conv_seg')),
        teacher_recorders=dict(
            logits=dict(type='ModuleOutputs', source='decode_head.conv_seg')),
        connectors=dict(
            loss_conv_stu=dict(type='ConvModuleConncetor', in_channel=2, out_channel=2, kernel_size=1, stride=1, padding=0,
                              norm_cfg=dict(type='BN')),
            loss_conv_tea=dict(type='ConvModuleConncetor', in_channel=2, out_channel=2, kernel_size=3, stride=2, padding=1, padding_mode='circular',
                              norm_cfg=dict(type='BN'))),
        loss_forward_mappings=dict(
            loss_cwd=dict(
                preds_S=dict(from_student=True, recorder='logits', connector='loss_conv_stu'), # 含义:从student_recorders(from_student=True)中读取名为logits的数据
                # 加上connnecor字段后,表示从student_recorders中读取名为logists的数据,而后将数据通过名为loss_conv_stu的连接器
                preds_T=dict(from_student=False, recorder='logits', connector='loss_conv_tea')))))
                # 从teacher_recorders中读取名为logits的数据,而后将数据通过名为loss_conv_tea的连接器
                # 而无论是loss_cwd、logits、loss_conv_stu、loss_conv_tea都是自定的名称

find_unused_parameters = True


train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=160000, val_interval=200)
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')

train_dataloader = dict(batch_size=16)  # 更改batch_size,否则会继承到pascal_voc12.py中的设置
# 这个16会作为teacher模型的推理batch和student模型的训练batch

work_dir = '/root/autodl-tmp/all_workdir/mmrazor_wokdir/distill/convnext-tiny-upernet_to_pspnet-r18'

模型蒸馏一

    1. 基于响应的KD(DKD ,FitNets):
      基于响应的KD以teacher模型的分类预测结果为目标知识,具体指的是分类器最后一个全连接层的输出(成为logits)。与最终的输出相比,logits没有经过softmax进行归一化,非目标类对应的输出值尚未被抑制。
      教师模型和学生模型之间的损失差异一般用KD散度,一般会用temperature(tau)大于1的参数对logits进行软化,以减小目标类和非目标类的预测值差异。
      logits具备的含义为模型判断当前样本为各类别的信心为多少
      1) logits提供的软标签信息,比one-hot的真实标签有着更高的熵值,从而提供了更多的信息量和数据之间更小梯度差异
      2) 软标签有着与标签平滑类似的效果,提高了模型的泛化能力
      3)除了gt标签外,还学习了软标签,使得模型学到了更多的知识,更倾向于学到不同的知识,优化方向更稳定
    1. 基于特征的KD(AB,AT,ofd,Factor Transfer):
      蒸馏位点位于模型中途获得的特征
      通常,teacher模型的通道大于学生通道,二者无法完全对齐,一般在学生的特征图后面接卷积,将两者在维度和通道上对齐,从而实现特征点的一一对应
      1)特征维度对齐,特征加权,mmrazor的connector模块的抽象
      2)知识定位,设计规则选出更为重要的教师特征
    1. 基于关系的KD(FSP, RKD):也使用特征,但计算不是特征点的一对一差异,而是特征关系的差异
      1)样本间关系蒸馏:在分类和分割中应用广泛,因为构建高质量的关系矩阵需要大量样本

总结:在语义分割中的cwd算法,可以看作是基于响应的KD,也可以看作是基于特征的KD,因为在传统的cwd算法中,使用的是在通过softmax之前的位置作为蒸馏位点,输出对应的特征图,去计算损失。

相关推荐
知识分享小能手11 小时前
React学习教程,从入门到精通, React 属性(Props)语法知识点与案例详解(14)
前端·javascript·vue.js·学习·react.js·vue·react
茯苓gao14 小时前
STM32G4 速度环开环,电流环闭环 IF模式建模
笔记·stm32·单片机·嵌入式硬件·学习
是誰萆微了承諾14 小时前
【golang学习笔记 gin 】1.2 redis 的使用
笔记·学习·golang
DKPT15 小时前
Java内存区域与内存溢出
java·开发语言·jvm·笔记·学习
aaaweiaaaaaa15 小时前
HTML和CSS学习
前端·css·学习·html
看海天一色听风起雨落16 小时前
Python学习之装饰器
开发语言·python·学习
speop17 小时前
llm的一点学习笔记
笔记·学习
非凡ghost17 小时前
FxSound:提升音频体验,让音乐更动听
前端·学习·音视频·生活·软件需求
ue星空17 小时前
月2期学习笔记
学习·游戏·ue5
萧邀人18 小时前
第二课、熟悉Cocos Creator 编辑器界面
学习