模型蒸馏学习

知识蒸馏:获取学生网络和教师网络指定蒸馏位点输出特征并计算蒸馏 loss 的过程

知乎-mmrazor-模型蒸馏

知识蒸馏算法往往分为 reponse-based基于响应、feature-based基于特征 和 relation-based基于关系三类。

也可为 data-free KD、online KD、self KD(可视为一种特殊的 online KD)和比较经典的 offline KD

  • feature-based 方法以教师模型特征提取器产生的中间层特征为学习对象
    而ChannelWiseDivergence(cwd算法)使用的是预测之前的logistic特征图(在channel维度上取最大值,即为最终的预测结果),就是feature-based的蒸馏方法

mmrazor可以使用不同架构的student和teacher模型

  • 可以使用connector对不同维度进行对齐,以计算语义分割的蒸馏loss
  • 对于 feature-base 的方法,当学生和教师网络输出特征维度不同时,往往会对学生网络对应特征进行后处理以保证蒸馏 loss 正确计算(connector实现)

mmseg的模型,在deconde_head.conv_seg后拿到的特征图为logist的特征图,其内的元素都为小数,而非预测的0/1

cwd算法是一种什么样的蒸馏算法呢是data-free的吗,还是online的呢,还是offline的呢

适用于mmseg的cwd模型蒸馏配置文件(default_hook中的Student_CheckpointHook是自定义的hook,继承自mmegin中的CheekpointHook)

cwd算法:首先使用softmax归一化方法将每个通道的feature map转换成一个分布,然后最小化两个网络对应通道的Kullback Leibler (KL)散度。通过这样做,我们的方法着重于模拟网络间通道的软分布。特别的是,KL的差异使学习能够更多地关注通道图中最突出的区域,大概对应于语义分割最有用的信号

  • 现象:
    由于model capacity gap的存在,student往往弱于teacher模型,但也并不绝对,如果model本身的gap不是很离谱,student还是有超越teacher的可能的,因为student模型一般可以学习teacher模型蒸馏位点的特征和ground truth多种知识,学习效率会更高,如果本身student没有太大的问题,还是有机会学的更好的。
python 复制代码
_base_ = [
    'mmseg::_base_/datasets/pascal_voc12.py',
    'mmseg::_base_/schedules/schedule_160k.py',
    'mmseg::_base_/default_runtime.py'
]

# 模型的optim_wrapper,学习率和学习策略将来自于继承的schedule_160k,如果不改的话
# wandb的可视化设置在mmseg的default_runtime,也继承自mmseg

# schedule_160k.py中的自动保存权重的部分
default_hooks = dict(_delete_=True,
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=100, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    # 使用了自定义的Student_CheckpointHook
    checkpoint=dict(type='Student_CheckpointHook', by_epoch=False, interval=-1, max_keep_ckpts=2, save_best=['mDice', 'mIoU']),
    # checkpoint中,interval=-1则不会保存least.pth
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

teacher_ckpt = '/root/autodl-tmp/all_workdir/mmseg_work_dir/baseline-convnext-tiny-upernet-rotate/best_mDice_iter_6800.pth'  # noqa: E501
teacher_cfg_path = 'mmseg::all_changed/baseline-convnext-tiny_upernet-rotate.py'  # noqa: E501

student_cfg_path = 'mmseg::all_changed/pspnet_r18-d8_b16-160k_voc-material-512x512.py'  # noqa: E501
model = dict(
    _scope_='mmrazor',
    type='SingleTeacherDistill',
    architecture=dict(cfg_path=student_cfg_path, pretrained=False),
    teacher=dict(cfg_path=teacher_cfg_path, pretrained=False),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        distill_losses=dict(
            loss_cwd=dict(type='ChannelWiseDivergence', tau=1, loss_weight=5)),
        student_recorders=dict(
            logits=dict(type='ModuleOutputs', source='decode_head.conv_seg')),
        teacher_recorders=dict(
            logits=dict(type='ModuleOutputs', source='decode_head.conv_seg')),
        connectors=dict(
            loss_conv_stu=dict(type='ConvModuleConncetor', in_channel=2, out_channel=2, kernel_size=1, stride=1, padding=0,
                              norm_cfg=dict(type='BN')),
            loss_conv_tea=dict(type='ConvModuleConncetor', in_channel=2, out_channel=2, kernel_size=3, stride=2, padding=1, padding_mode='circular',
                              norm_cfg=dict(type='BN'))),
        loss_forward_mappings=dict(
            loss_cwd=dict(
                preds_S=dict(from_student=True, recorder='logits', connector='loss_conv_stu'), # 含义:从student_recorders(from_student=True)中读取名为logits的数据
                # 加上connnecor字段后,表示从student_recorders中读取名为logists的数据,而后将数据通过名为loss_conv_stu的连接器
                preds_T=dict(from_student=False, recorder='logits', connector='loss_conv_tea')))))
                # 从teacher_recorders中读取名为logits的数据,而后将数据通过名为loss_conv_tea的连接器
                # 而无论是loss_cwd、logits、loss_conv_stu、loss_conv_tea都是自定的名称

find_unused_parameters = True


train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=160000, val_interval=200)
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')

train_dataloader = dict(batch_size=16)  # 更改batch_size,否则会继承到pascal_voc12.py中的设置
# 这个16会作为teacher模型的推理batch和student模型的训练batch

work_dir = '/root/autodl-tmp/all_workdir/mmrazor_wokdir/distill/convnext-tiny-upernet_to_pspnet-r18'

模型蒸馏一

    1. 基于响应的KD(DKD ,FitNets):
      基于响应的KD以teacher模型的分类预测结果为目标知识,具体指的是分类器最后一个全连接层的输出(成为logits)。与最终的输出相比,logits没有经过softmax进行归一化,非目标类对应的输出值尚未被抑制。
      教师模型和学生模型之间的损失差异一般用KD散度,一般会用temperature(tau)大于1的参数对logits进行软化,以减小目标类和非目标类的预测值差异。
      logits具备的含义为模型判断当前样本为各类别的信心为多少
      1) logits提供的软标签信息,比one-hot的真实标签有着更高的熵值,从而提供了更多的信息量和数据之间更小梯度差异
      2) 软标签有着与标签平滑类似的效果,提高了模型的泛化能力
      3)除了gt标签外,还学习了软标签,使得模型学到了更多的知识,更倾向于学到不同的知识,优化方向更稳定
    1. 基于特征的KD(AB,AT,ofd,Factor Transfer):
      蒸馏位点位于模型中途获得的特征
      通常,teacher模型的通道大于学生通道,二者无法完全对齐,一般在学生的特征图后面接卷积,将两者在维度和通道上对齐,从而实现特征点的一一对应
      1)特征维度对齐,特征加权,mmrazor的connector模块的抽象
      2)知识定位,设计规则选出更为重要的教师特征
    1. 基于关系的KD(FSP, RKD):也使用特征,但计算不是特征点的一对一差异,而是特征关系的差异
      1)样本间关系蒸馏:在分类和分割中应用广泛,因为构建高质量的关系矩阵需要大量样本

总结:在语义分割中的cwd算法,可以看作是基于响应的KD,也可以看作是基于特征的KD,因为在传统的cwd算法中,使用的是在通过softmax之前的位置作为蒸馏位点,输出对应的特征图,去计算损失。

相关推荐
茶馆大橘43 分钟前
(黑马点评)八、实现签到统计和uv统计
数据库·redis·学习·阿里云·黑马点评
铁匠匠匠1 小时前
【C总集篇】第八章 数组和指针
c语言·开发语言·数据结构·经验分享·笔记·学习·算法
因为奋斗超太帅啦2 小时前
React学习笔记(三)——React 组件通讯
笔记·学习·react.js
exploration-earth3 小时前
缓存技术的核心价值与应用
学习
IM_DALLA3 小时前
【Verilog学习日常】—牛客网刷题—Verilog快速入门—VL16
学习·fpga开发
IM_DALLA3 小时前
【Verilog学习日常】—牛客网刷题—Verilog快速入门—VL18
学习·fpga开发
完球了3 小时前
【Day02-JS+Vue+Ajax】
javascript·vue.js·笔记·学习·ajax
怀九日4 小时前
C++(学习)2024.9.19
开发语言·c++·学习·重构·对象·
白帽黑客cst4 小时前
网络安全(黑客技术) 最新三个月学习计划
网络·数据结构·windows·学习·安全·web安全·网络安全
RS&5 小时前
python学习笔记
笔记·python·学习