知识蒸馏:获取学生网络和教师网络指定蒸馏位点 的输出特征并计算蒸馏 loss 的过程
知识蒸馏算法往往分为 reponse-based基于响应、feature-based基于特征 和 relation-based基于关系三类。
也可为 data-free KD、online KD、self KD(可视为一种特殊的 online KD)和比较经典的 offline KD
- feature-based 方法以教师模型特征提取器产生的中间层特征为学习对象
而ChannelWiseDivergence(cwd算法)使用的是预测之前的logistic特征图(在channel维度上取最大值,即为最终的预测结果),就是feature-based的蒸馏方法
mmrazor可以使用不同架构的student和teacher模型
- 可以使用connector对不同维度进行对齐,以计算语义分割的蒸馏loss
- 对于 feature-base 的方法,当学生和教师网络输出特征维度不同时,往往会对学生网络对应特征进行后处理以保证蒸馏 loss 正确计算(connector实现)
mmseg的模型,在deconde_head.conv_seg后拿到的特征图为logist的特征图,其内的元素都为小数,而非预测的0/1
cwd算法是一种什么样的蒸馏算法呢是data-free的吗,还是online的呢,还是offline的呢
适用于mmseg的cwd模型蒸馏配置文件(default_hook中的Student_CheckpointHook是自定义的hook,继承自mmegin中的CheekpointHook)
cwd算法:首先使用softmax归一化方法将每个通道的feature map转换成一个分布,然后最小化两个网络对应通道的Kullback Leibler (KL)散度。通过这样做,我们的方法着重于模拟网络间通道的软分布。特别的是,KL的差异使学习能够更多地关注通道图中最突出的区域,大概对应于语义分割最有用的信号
- 现象:
由于model capacity gap的存在,student往往弱于teacher模型,但也并不绝对,如果model本身的gap不是很离谱,student还是有超越teacher的可能的,因为student模型一般可以学习teacher模型蒸馏位点的特征和ground truth多种知识,学习效率会更高,如果本身student没有太大的问题,还是有机会学的更好的。
python
_base_ = [
'mmseg::_base_/datasets/pascal_voc12.py',
'mmseg::_base_/schedules/schedule_160k.py',
'mmseg::_base_/default_runtime.py'
]
# 模型的optim_wrapper,学习率和学习策略将来自于继承的schedule_160k,如果不改的话
# wandb的可视化设置在mmseg的default_runtime,也继承自mmseg
# schedule_160k.py中的自动保存权重的部分
default_hooks = dict(_delete_=True,
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=100, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
# 使用了自定义的Student_CheckpointHook
checkpoint=dict(type='Student_CheckpointHook', by_epoch=False, interval=-1, max_keep_ckpts=2, save_best=['mDice', 'mIoU']),
# checkpoint中,interval=-1则不会保存least.pth
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))
teacher_ckpt = '/root/autodl-tmp/all_workdir/mmseg_work_dir/baseline-convnext-tiny-upernet-rotate/best_mDice_iter_6800.pth' # noqa: E501
teacher_cfg_path = 'mmseg::all_changed/baseline-convnext-tiny_upernet-rotate.py' # noqa: E501
student_cfg_path = 'mmseg::all_changed/pspnet_r18-d8_b16-160k_voc-material-512x512.py' # noqa: E501
model = dict(
_scope_='mmrazor',
type='SingleTeacherDistill',
architecture=dict(cfg_path=student_cfg_path, pretrained=False),
teacher=dict(cfg_path=teacher_cfg_path, pretrained=False),
teacher_ckpt=teacher_ckpt,
distiller=dict(
type='ConfigurableDistiller',
distill_losses=dict(
loss_cwd=dict(type='ChannelWiseDivergence', tau=1, loss_weight=5)),
student_recorders=dict(
logits=dict(type='ModuleOutputs', source='decode_head.conv_seg')),
teacher_recorders=dict(
logits=dict(type='ModuleOutputs', source='decode_head.conv_seg')),
connectors=dict(
loss_conv_stu=dict(type='ConvModuleConncetor', in_channel=2, out_channel=2, kernel_size=1, stride=1, padding=0,
norm_cfg=dict(type='BN')),
loss_conv_tea=dict(type='ConvModuleConncetor', in_channel=2, out_channel=2, kernel_size=3, stride=2, padding=1, padding_mode='circular',
norm_cfg=dict(type='BN'))),
loss_forward_mappings=dict(
loss_cwd=dict(
preds_S=dict(from_student=True, recorder='logits', connector='loss_conv_stu'), # 含义:从student_recorders(from_student=True)中读取名为logits的数据
# 加上connnecor字段后,表示从student_recorders中读取名为logists的数据,而后将数据通过名为loss_conv_stu的连接器
preds_T=dict(from_student=False, recorder='logits', connector='loss_conv_tea')))))
# 从teacher_recorders中读取名为logits的数据,而后将数据通过名为loss_conv_tea的连接器
# 而无论是loss_cwd、logits、loss_conv_stu、loss_conv_tea都是自定的名称
find_unused_parameters = True
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=160000, val_interval=200)
val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
train_dataloader = dict(batch_size=16) # 更改batch_size,否则会继承到pascal_voc12.py中的设置
# 这个16会作为teacher模型的推理batch和student模型的训练batch
work_dir = '/root/autodl-tmp/all_workdir/mmrazor_wokdir/distill/convnext-tiny-upernet_to_pspnet-r18'
-
- 基于响应的KD(DKD ,FitNets):
基于响应的KD以teacher模型的分类预测结果为目标知识,具体指的是分类器最后一个全连接层的输出(成为logits)。与最终的输出相比,logits没有经过softmax进行归一化,非目标类对应的输出值尚未被抑制。
教师模型和学生模型之间的损失差异一般用KD散度,一般会用temperature(tau)大于1的参数对logits进行软化,以减小目标类和非目标类的预测值差异。
logits具备的含义为模型判断当前样本为各类别的信心为多少
1) logits提供的软标签信息,比one-hot的真实标签有着更高的熵值,从而提供了更多的信息量和数据之间更小梯度差异
2) 软标签有着与标签平滑类似的效果,提高了模型的泛化能力
3)除了gt标签外,还学习了软标签,使得模型学到了更多的知识,更倾向于学到不同的知识,优化方向更稳定
- 基于响应的KD(DKD ,FitNets):
-
- 基于特征的KD(AB,AT,ofd,Factor Transfer):
蒸馏位点位于模型中途获得的特征
通常,teacher模型的通道大于学生通道,二者无法完全对齐,一般在学生的特征图后面接卷积,将两者在维度和通道上对齐,从而实现特征点的一一对应
1)特征维度对齐,特征加权,mmrazor的connector模块的抽象
2)知识定位,设计规则选出更为重要的教师特征
- 基于特征的KD(AB,AT,ofd,Factor Transfer):
-
- 基于关系的KD(FSP, RKD):也使用特征,但计算不是特征点的一对一差异,而是特征关系的差异
1)样本间关系蒸馏:在分类和分割中应用广泛,因为构建高质量的关系矩阵需要大量样本
- 基于关系的KD(FSP, RKD):也使用特征,但计算不是特征点的一对一差异,而是特征关系的差异
总结:在语义分割中的cwd算法,可以看作是基于响应的KD,也可以看作是基于特征的KD,因为在传统的cwd算法中,使用的是在通过softmax之前的位置作为蒸馏位点,输出对应的特征图,去计算损失。