公开时间:2022年3月9号
项目地址:https://github.com/yzd-v/FGD
论文地址:https://arxiv.org/pdf/2111.11837
知识蒸馏已成功地应用于图像分类。然而,目标检测要复杂得多,大多数知识蒸馏方法都失败了
。本文指出,在目标检测中,教师和学生的特征在不同的领域有很大的差异,特别是在前景和背景上。如果我们平均地提取它们,特征图之间的不均匀差异将会对蒸馏产生负面影响
。因此,我们提出了聚焦蒸馏和全局蒸馏(FGD)。聚焦蒸馏将前景和背景分开,迫使学生专注于教师的临界像素和通道
。全局蒸馏重建了不同像素之间的关系,并将其从教师转移到学生身上,补偿了聚焦蒸馏中全局信息的缺失
。
由于我们的方法只需要计算特征图上的损失,因此FGD可以应用于各种检测器
。 我们在不同骨架的各种检测器上进行了实验,结果表明,该学生检测器取得了良好的mAP改进,为2~3个点。
1、核心观点
1.1 区分FG与BG的蒸馏差异
教师和学生的特征在不同的领域有很大的差异,特别是在前景和背景上。
作者通过实验表明,对fg与bg不做取得的蒸馏,还不如单独对fg或bg进行蒸馏。这里fb是是bbox对应的特征图区域,bg是背景对应的特征图区域。
1.2 具体实现
1、对backbone的输出进行Global Distillation操作,使教师模型与学生模型的输出解决
2、在neck的输出上,根据bbox区分前景与背景,分别进行蒸馏,然后loss加权
总体loss实现:
Focal Distillation
对前景与背景分别设定loss权重进行蒸馏,同时附加spatial和chanel的attention蒸馏结构,使学生模型模拟教师模型
Global Distillation
1.3 有益效果
基于表3可以发现FGD的蒸馏方式,对于各类任务(目标检测、实力分割、关键点检测)均有提升效果,基本能提升3个点左右。
与其他目标检测蒸馏策略相比,FGD方法能提升02~0.7个点的精度,同时蒸馏后的S模型精度比T模型要略高。
蒸馏后的特征图变化
2、消融实验
2.1 focal and global distillation
基于这里的对比可以发现,仅蒸馏backbone或对neck进行有区别蒸馏,均能取得良好效果。但
两个一起蒸馏能额外取得0.2个点的提升。
2.2 Spatial attention 与 Channel attention
这里的蒸馏效果差异如下,同样是结合2个维度蒸馏,能提升0.1~0.2个点。同时表明spatial蒸馏更有效
2.3 GcBlock作用
通常蒸馏是直接对比教师模型与学生模型的差异,而本文中提到基于GcBlock对二者进行高维度映射后在计算loss。这里可以发现GcBlock是蒸馏有效的基本条件,否则涨点幅度较小。
2.4 蒸馏温度
在neck中进行蒸馏时,考虑了教师输出的spatial与chanel的分布特征,具体如下所示
这里通过消融实验,表明蒸馏温度对效果的影响。0.5或0.8为最佳值,这表明需要对教师的输出进行加热,体现出显著的分布特征,学生模型才能学习好。
3、实现代码
基于mmdet进行实现
3.1 配置文件
基于对配置文件的分析,博主认为只有一个针对neck层的FeatureLoss
py
_base_ = [
'../../_base_/datasets/coco_detection.py',
'../../_base_/schedules/schedule_2x.py', '../../_base_/default_runtime.py'
]
# model settings
find_unused_parameters=True
temp=0.5
alpha_fgd=0.00005
beta_fgd=0.000025
gamma_fgd=0.00005
lambda_fgd=0.0000005
distiller = dict(
type='DetectionDistiller',
teacher_pretrained = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r101_fpn_2x_coco/faster_rcnn_r101_fpn_2x_coco_bbox_mAP-0.398_20200504_210455-1d2dac9c.pth',
init_student = True,
distill_cfg = [ dict(student_module = 'neck.fpn_convs.3.conv',
teacher_module = 'neck.fpn_convs.3.conv',
output_hook = True,
methods=[dict(type='FeatureLoss',
name='loss_fgd_fpn_3',
student_channels = 256,
teacher_channels = 256,
temp = temp,
alpha_fgd=alpha_fgd,
beta_fgd=beta_fgd,
gamma_fgd=gamma_fgd,
lambda_fgd=lambda_fgd,
)
]
),
dict(student_module = 'neck.fpn_convs.2.conv',
teacher_module = 'neck.fpn_convs.2.conv',
output_hook = True,
methods=[dict(type='FeatureLoss',
name='loss_fgd_fpn_2',
student_channels = 256,
teacher_channels = 256,
temp = temp,
alpha_fgd=alpha_fgd,
beta_fgd=beta_fgd,
gamma_fgd=gamma_fgd,
lambda_fgd=lambda_fgd,
)
]
),
dict(student_module = 'neck.fpn_convs.1.conv',
teacher_module = 'neck.fpn_convs.1.conv',
output_hook = True,
methods=[dict(type='FeatureLoss',
name='loss_fgd_fpn_1',
student_channels = 256,
teacher_channels = 256,
temp = temp,
alpha_fgd=alpha_fgd,
beta_fgd=beta_fgd,
gamma_fgd=gamma_fgd,
lambda_fgd=lambda_fgd,
)
]
),
dict(student_module = 'neck.fpn_convs.0.conv',
teacher_module = 'neck.fpn_convs.0.conv',
output_hook = True,
methods=[dict(type='FeatureLoss',
name='loss_fgd_fpn_0',
student_channels = 256,
teacher_channels = 256,
temp = temp,
alpha_fgd=alpha_fgd,
beta_fgd=beta_fgd,
gamma_fgd=gamma_fgd,
lambda_fgd=lambda_fgd,
)
]
),
]
)
student_cfg = 'configs/faster_rcnn/faster_rcnn_r50_fpn_2x_coco.py'
teacher_cfg = 'configs/faster_rcnn/faster_rcnn_r101_fpn_2x_coco.py'
optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(_delete_=True, grad_clip=dict(max_norm=35, norm_type=2))
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,)
3.2 forward_train函数
detection_distiller.py 中的forward_train函数定义了模型蒸馏的前向推理流程,可以发现就是针对配置文件中的layer计算FeatureLoss
py
def forward_train(self, img, img_metas, **kwargs):
"""
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
Returns:
dict[str, Tensor]: A dictionary of loss components(student's losses and distiller's losses).
"""
with torch.no_grad():
self.teacher.eval()
feat = self.teacher.extract_feat(img)
student_loss = self.student.forward_train(img, img_metas, **kwargs)
buffer_dict = dict(self.named_buffers())
for item_loc in self.distill_cfg:
student_module = 'student_' + item_loc.student_module.replace('.','_')
teacher_module = 'teacher_' + item_loc.teacher_module.replace('.','_')
student_feat = buffer_dict[student_module]
teacher_feat = buffer_dict[teacher_module]
for item_loss in item_loc.methods:
loss_name = item_loss.name
student_loss[loss_name] = self.distill_losses[loss_name](student_feat,teacher_feat,kwargs['gt_bboxes'], img_metas)
return student_loss
3.3 Focal Global Distillation 代码
代码地址:
https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py
这里的代码实现比较复杂,博主认为是将Focal Distillation部分+Global 部分的GcBlock针对同一layer对象进行实现,并没有像论文示意图中作用于不同的layer
py
import torch.nn as nn
import torch.nn.functional as F
import torch
from mmcv.cnn import constant_init, kaiming_init
from ..builder import DISTILL_LOSSES
@DISTILL_LOSSES.register_module()
class FeatureLoss(nn.Module):
"""PyTorch version of `Focal and Global Knowledge Distillation for Detectors`
Args:
student_channels(int): Number of channels in the student's feature map.
teacher_channels(int): Number of channels in the teacher's feature map.
temp (float, optional): Temperature coefficient. Defaults to 0.5.
name (str): the loss name of the layer
alpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001
beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005
gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.001
lambda_fgd (float, optional): Weight of relation_loss. Defaults to 0.000005
"""
def __init__(self,
student_channels,
teacher_channels,
name,
temp=0.5,
alpha_fgd=0.001,
beta_fgd=0.0005,
gamma_fgd=0.001,
lambda_fgd=0.000005,
):
super(FeatureLoss, self).__init__()
self.temp = temp
self.alpha_fgd = alpha_fgd
self.beta_fgd = beta_fgd
self.gamma_fgd = gamma_fgd
self.lambda_fgd = lambda_fgd
if student_channels != teacher_channels:
self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)
else:
self.align = None
self.conv_mask_s = nn.Conv2d(teacher_channels, 1, kernel_size=1)
self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1)
self.channel_add_conv_s = nn.Sequential(
nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
nn.LayerNorm([teacher_channels//2, 1, 1]),
nn.ReLU(inplace=True), # yapf: disable
nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1))
self.channel_add_conv_t = nn.Sequential(
nn.Conv2d(teacher_channels, teacher_channels//2, kernel_size=1),
nn.LayerNorm([teacher_channels//2, 1, 1]),
nn.ReLU(inplace=True), # yapf: disable
nn.Conv2d(teacher_channels//2, teacher_channels, kernel_size=1)) #FcBlock
self.reset_parameters()
def forward(self,
preds_S,
preds_T,
gt_bboxes,
img_metas):
"""Forward function.
Args:
preds_S(Tensor): Bs*C*H*W, student's feature map
preds_T(Tensor): Bs*C*H*W, teacher's feature map
gt_bboxes(tuple): Bs*[nt*4], pixel decimal: (tl_x, tl_y, br_x, br_y)
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
"""
assert preds_S.shape[-2:] == preds_T.shape[-2:],'the output dim of teacher and student differ'
if self.align is not None:
preds_S = self.align(preds_S)
N,C,H,W = preds_S.shape
S_attention_t, C_attention_t = self.get_attention(preds_T, self.temp)
S_attention_s, C_attention_s = self.get_attention(preds_S, self.temp)
Mask_fg = torch.zeros_like(S_attention_t)
Mask_bg = torch.ones_like(S_attention_t)
wmin,wmax,hmin,hmax = [],[],[],[]
for i in range(N):
new_boxxes = torch.ones_like(gt_bboxes[i])
new_boxxes[:, 0] = gt_bboxes[i][:, 0]/img_metas[i]['img_shape'][1]*W
new_boxxes[:, 2] = gt_bboxes[i][:, 2]/img_metas[i]['img_shape'][1]*W
new_boxxes[:, 1] = gt_bboxes[i][:, 1]/img_metas[i]['img_shape'][0]*H
new_boxxes[:, 3] = gt_bboxes[i][:, 3]/img_metas[i]['img_shape'][0]*H
wmin.append(torch.floor(new_boxxes[:, 0]).int())
wmax.append(torch.ceil(new_boxxes[:, 2]).int())
hmin.append(torch.floor(new_boxxes[:, 1]).int())
hmax.append(torch.ceil(new_boxxes[:, 3]).int())
area = 1.0/(hmax[i].view(1,-1)+1-hmin[i].view(1,-1))/(wmax[i].view(1,-1)+1-wmin[i].view(1,-1))
for j in range(len(gt_bboxes[i])):
Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1] = \
torch.maximum(Mask_fg[i][hmin[i][j]:hmax[i][j]+1, wmin[i][j]:wmax[i][j]+1], area[0][j])
Mask_bg[i] = torch.where(Mask_fg[i]>0, 0, 1)
if torch.sum(Mask_bg[i]):
Mask_bg[i] /= torch.sum(Mask_bg[i])
fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg,
C_attention_s, C_attention_t, S_attention_s, S_attention_t)
mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, S_attention_s, S_attention_t)
rela_loss = self.get_rela_loss(preds_S, preds_T)
loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \
+ self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss
return loss
def get_attention(self, preds, temp):
""" preds: Bs*C*W*H """
N, C, H, W= preds.shape
value = torch.abs(preds)
# Bs*W*H
fea_map = value.mean(axis=1, keepdim=True)
S_attention = (H * W * F.softmax((fea_map/temp).view(N,-1), dim=1)).view(N, H, W)
# Bs*C
channel_map = value.mean(axis=2,keepdim=False).mean(axis=2,keepdim=False)
C_attention = C * F.softmax(channel_map/temp, dim=1)
return S_attention, C_attention
def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t):
loss_mse = nn.MSELoss(reduction='sum')
Mask_fg = Mask_fg.unsqueeze(dim=1)
Mask_bg = Mask_bg.unsqueeze(dim=1)
C_t = C_t.unsqueeze(dim=-1)
C_t = C_t.unsqueeze(dim=-1)
S_t = S_t.unsqueeze(dim=1)
fea_t= torch.mul(preds_T, torch.sqrt(S_t))
fea_t = torch.mul(fea_t, torch.sqrt(C_t))
fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))
bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))
fea_s = torch.mul(preds_S, torch.sqrt(S_t))
fea_s = torch.mul(fea_s, torch.sqrt(C_t))
fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))
bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))
fg_loss = loss_mse(fg_fea_s, fg_fea_t)/len(Mask_fg)
bg_loss = loss_mse(bg_fea_s, bg_fea_t)/len(Mask_bg)
return fg_loss, bg_loss
def get_mask_loss(self, C_s, C_t, S_s, S_t):
mask_loss = torch.sum(torch.abs((C_s-C_t)))/len(C_s) + torch.sum(torch.abs((S_s-S_t)))/len(S_s)
return mask_loss
def spatial_pool(self, x, in_type):
batch, channel, width, height = x.size()
input_x = x
# [N, C, H * W]
input_x = input_x.view(batch, channel, height * width)
# [N, 1, C, H * W]
input_x = input_x.unsqueeze(1)
# [N, 1, H, W]
if in_type == 0:
context_mask = self.conv_mask_s(x)
else:
context_mask = self.conv_mask_t(x)
# [N, 1, H * W]
context_mask = context_mask.view(batch, 1, height * width)
# [N, 1, H * W]
context_mask = F.softmax(context_mask, dim=2)
# [N, 1, H * W, 1]
context_mask = context_mask.unsqueeze(-1)
# [N, 1, C, 1]
context = torch.matmul(input_x, context_mask)
# [N, C, 1, 1]
context = context.view(batch, channel, 1, 1)
return context
def get_rela_loss(self, preds_S, preds_T):
loss_mse = nn.MSELoss(reduction='sum')
context_s = self.spatial_pool(preds_S, 0)
context_t = self.spatial_pool(preds_T, 1)
out_s = preds_S
out_t = preds_T
channel_add_s = self.channel_add_conv_s(context_s)
out_s = out_s + channel_add_s
channel_add_t = self.channel_add_conv_t(context_t)
out_t = out_t + channel_add_t
rela_loss = loss_mse(out_s, out_t)/len(out_s)
return rela_loss
def last_zero_init(self, m):
if isinstance(m, nn.Sequential):
constant_init(m[-1], val=0)
else:
constant_init(m, val=0)
def reset_parameters(self):
kaiming_init(self.conv_mask_s, mode='fan_in')
kaiming_init(self.conv_mask_t, mode='fan_in')
self.conv_mask_s.inited = True
self.conv_mask_t.inited = True
self.last_zero_init(self.channel_add_conv_s)
self.last_zero_init(self.channel_add_conv_t)