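I've been busy lately and haven't posted in a while. I've been using yolov7 a lot recently, so here is a write-up. The previous post, "yolov5及yolov7实战之剪枝" (pruning in practice for yolov5 and yolov7, on the CodingInCV CSDN blog), used pruning to trim the model. The goal there was a faster model with little loss of accuracy. That post was about speed. This one improves yolov7's detection performance, and yolov5 works the same way.
To improve a detector, you can work on the model itself: add data, modify the architecture, or modify the loss. Deep learning offers another route as well: distillation. Put simply, distillation uses a stronger model (the teacher, with more parameters) to guide a weaker student model and so improve the student's performance.
There are many ways to distill. A simple, brute-force one is to have the student directly fit the teacher's output feature maps. Distillation is not a cure-all, though. Because of the gap in parameter count, the student may not absorb the teacher's knowledge well, and whether it helps on your own task has to be tested.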
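For reference, here is a minimal sketch of that brute-force variant (plain feature imitation, not FGD). It assumes one student and one teacher feature map at the same stride. `FeatureMimicLoss` and its `weight` argument are hypothetical names, and the 1x1 `align` conv is there because a smaller student usually has fewer channels.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FeatureMimicLoss(nn.Module):
    """Plain feature imitation: the student regresses the teacher's feature map with MSE."""

    def __init__(self, student_channels: int, teacher_channels: int, weight: float = 1.0):
        super().__init__()
        # Map student channels onto teacher channels before comparing, if they differ.
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1)
        else:
            self.align = nn.Identity()
        self.weight = weight

    def forward(self, feat_s: torch.Tensor, feat_t: torch.Tensor) -> torch.Tensor:
        assert feat_s.shape[-2:] == feat_t.shape[-2:], "feature maps must share H and W"
        # The teacher is only a target: detach it so no gradient flows into it.
        return self.weight * F.mse_loss(self.align(feat_s), feat_t.detach())


# Usage: a 128-channel student map imitating a 192-channel teacher map at the same stride.
kd = FeatureMimicLoss(student_channels=128, teacher_channels=192)
feat_s = torch.randn(2, 128, 20, 20, requires_grad=True)
feat_t = torch.randn(2, 192, 20, 20)
loss = kd(feat_s, feat_t)
loss.backward()  # gradients reach the student features and the align conv
```

FGD, used in the rest of this post, refines this idea in two ways. It weights the squared error with foreground/background masks and attention maps, and it adds mask and relation terms.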
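The method chosen here is a distillation algorithm designed specifically for object detection, published at last year's CVPR (2022):
yzd-v/FGD: Focal and Global Knowledge Distillation for Detectors (CVPR 2022), on GitHub
For a walkthrough of the method, see the Zhihu article "FGD-CVPR2022: Focal and Global Distillation for Object Detection".
This post skips the theory for now and focuses on integrating the method into yolov7 training. The steps are as follows.
Loading the teacher model
Distillation first needs a teacher model. The teacher usually has the same architecture as the student, only wider and deeper (more parameters, more layers). For example, with yolov5 you could try distilling yolov5s from yolov5m.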
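"Same architecture, more parameters" usually means the student and teacher feature maps differ in channel count. The sketch below shows how yolov5 scales the widths written in its yaml by `width_multiple`: 0.50 for yolov5s and 0.75 for yolov5m. It rounds up to a multiple of 8 the way yolov5's `make_divisible` does. The base widths 256/512/1024 are the P3/P4/P5 entries of the yolov5 (v6) head, and `detect_input_channels` is a hypothetical helper. This mismatch is what the 1x1 `align` conv in FeatureLoss takes care of.

```python
import math


def make_divisible(x: float, divisor: int = 8) -> int:
    """Round x up to the nearest multiple of divisor (yolov5's rounding for layer widths)."""
    return math.ceil(x / divisor) * divisor


def detect_input_channels(base_channels, width_multiple):
    """Channels of the feature maps fed to Detect after applying a width multiple."""
    return [make_divisible(c * width_multiple) for c in base_channels]


BASE = [256, 512, 1024]  # P3/P4/P5 widths in the yolov5 yaml head, before scaling
print("yolov5s:", detect_input_channels(BASE, 0.50))
print("yolov5m:", detect_input_channels(BASE, 0.75))
```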
Add a command-line argument to train.py:
```python
parser.add_argument("--teacher-weights", type=str, default="", help="teacher model weights path (.pt)")
```
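A quick check of how this flag behaves. argparse converts the dashes to underscores, which is why the training code reads it as `opt.teacher_weights`. The empty-string default is falsy, so omitting the flag turns distillation off. The weights path below is made up.

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--teacher-weights", type=str, default="", help="teacher model weights path (.pt)")

# Dashes become underscores: the value is available as opt.teacher_weights.
opt = parser.parse_args(["--teacher-weights", "runs/train/yolov7_teacher/weights/best.pt"])
print(opt.teacher_weights)

# Flag omitted: "" is falsy, so `if opt.teacher_weights:` skips all teacher code.
opt_off = parser.parse_args([])
print(bool(opt_off.teacher_weights))
```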
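Load the teacher weights in the train function. The process mirrors the existing loading code. Note that in DP or DDP mode the teacher model needs the corresponding treatment as well.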
```python
# teacher model
if opt.teacher_weights:
    teacher_weights = opt.teacher_weights
    # with torch_distributed_zero_first(rank):
    #     teacher_weights = attempt_download(teacher_weights)  # download if not found locally
    # separate names so the student's ckpt/state_dict (reused later by train.py) are not overwritten
    ckpt_t = torch.load(teacher_weights, map_location=device)  # load checkpoint
    # Model() takes a cfg (yaml path or dict), not a .pt path: use the cfg stored in the checkpoint
    teacher_model = Model(ckpt_t["model"].yaml, ch=3, nc=nc).to(device)  # create
    # load state_dict
    state_dict_t = ckpt_t["model"].float().state_dict()  # to FP32
    teacher_model.load_state_dict(state_dict_t, strict=True)  # load
    del ckpt_t, state_dict_t
    # set to eval
    teacher_model.eval()
    # set IDetect to train mode
    # teacher_model.model[-1].train()
    logger.info(f"Load teacher model from {teacher_weights}")  # report

# DP mode
if cuda and rank == -1 and torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model)
    if opt.teacher_weights:
        teacher_model = torch.nn.DataParallel(teacher_model)

# SyncBatchNorm
if opt.sync_bn and cuda and rank != -1:
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
    logger.info("Using SyncBatchNorm()")
    if opt.teacher_weights:
        teacher_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(teacher_model).to(device)
```
The teacher model does not compute gradients, so:
```python
if opt.teacher_weights:
    for param in teacher_model.parameters():
        param.requires_grad = False
```
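Here is the same freezing pattern in isolation, on a hypothetical toy teacher. `eval()` makes BatchNorm use its running statistics, and `requires_grad = False` keeps the teacher out of backpropagation. Wrapping its forward in `torch.no_grad()` is an extra guarantee that no graph is recorded. In yolov7's train.py, `model.train()` is called on the student at the start of every epoch. The teacher is a separate module, so it stays in eval mode.

```python
import torch
import torch.nn as nn


def freeze_teacher(teacher: nn.Module) -> nn.Module:
    """Inference-only teacher: BN uses running statistics and no parameter gets gradients."""
    teacher.eval()
    for param in teacher.parameters():
        param.requires_grad = False
    return teacher


# Hypothetical toy teacher standing in for the yolov7 Model.
teacher = freeze_teacher(nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU()))
imgs = torch.randn(2, 3, 16, 16)
with torch.no_grad():  # no autograd graph is recorded for the teacher's forward
    feat_t = teacher(imgs)
print(feat_t.shape, feat_t.requires_grad)
```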
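Distillation loss
The distillation loss measures how similar one or more teacher layers are to the corresponding student layers. Minimizing it pushes the student toward the teacher. For yolov7, we can supervise three feature layers.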
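For reference, here is what the FeatureLoss code below computes, read off from the code itself (the notation is mine, not the post's):

- $F^S$ is the student feature map after the 1x1 align conv, and $F^T$ is the teacher's. Both have shape $N\times C\times H\times W$.
- $T$ is `temp`.
- $M^{fg}$ gives each grid cell inside a ground-truth box the value $1/(\text{box area in cells})$, taking the maximum over overlapping boxes, and is 0 elsewhere.
- $M^{bg}$ gives background cells $1/(\text{number of background cells})$ and foreground cells 0.
- $\mathcal{G}_s, \mathcal{G}_t$ are the global-context branches (`conv_mask_*` pooling followed by `channel_add_conv_*`).
- The weights default to $\alpha=0.001$, $\beta=0.0005$, $\gamma=0.001$, $\lambda=0.000005$.

Attention maps are computed per image.

$$
\begin{aligned}
A^{S}_{i,j}(F) &= HW\,\operatorname{softmax}_{(i,j)}\Big(\frac{1}{T}\cdot\frac{1}{C}\sum_{c}\lvert F_{c,i,j}\rvert\Big),\qquad
A^{C}_{c}(F) = C\,\operatorname{softmax}_{c}\Big(\frac{1}{T}\cdot\frac{1}{HW}\sum_{i,j}\lvert F_{c,i,j}\rvert\Big)\\
L_{fg} &= \frac{1}{N}\sum_{n,c,i,j} M^{fg}_{n,i,j}\,A^{S}_{n,i,j}(F^{T})\,A^{C}_{n,c}(F^{T})\,\big(F^{S}_{n,c,i,j}-F^{T}_{n,c,i,j}\big)^{2}\\
L_{bg} &= \frac{1}{N}\sum_{n,c,i,j} M^{bg}_{n,i,j}\,A^{S}_{n,i,j}(F^{T})\,A^{C}_{n,c}(F^{T})\,\big(F^{S}_{n,c,i,j}-F^{T}_{n,c,i,j}\big)^{2}\\
L_{mask} &= \frac{1}{N}\Big(\sum_{n,c}\big\lvert A^{C}_{n,c}(F^{S})-A^{C}_{n,c}(F^{T})\big\rvert+\sum_{n,i,j}\big\lvert A^{S}_{n,i,j}(F^{S})-A^{S}_{n,i,j}(F^{T})\big\rvert\Big)\\
L_{rela} &= \frac{1}{N}\sum_{n,c,i,j}\Big(\big(F^{S}+\mathcal{G}_{s}(F^{S})\big)_{n,c,i,j}-\big(F^{T}+\mathcal{G}_{t}(F^{T})\big)_{n,c,i,j}\Big)^{2}\\
L_{kd} &= \alpha\,L_{fg}+\beta\,L_{bg}+\gamma\,L_{mask}+\lambda\,L_{rela}
\end{aligned}
$$

Both feature terms are weighted by the teacher's attention, so the student is pushed hardest where the teacher itself responds most strongly.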
Following FGD's open-source code, add a FeatureLoss class to loss.py, keeping the default parameters for now:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.general import xywh2xyxy  # yolov7 helper (loss.py may already import it)


def constant_init(module, val, bias=0):
    # Equivalent of mmcv.cnn.constant_init (used by the FGD repo), so loss.py needs no mmcv
    if hasattr(module, "weight") and module.weight is not None:
        nn.init.constant_(module.weight, val)
    if hasattr(module, "bias") and module.bias is not None:
        nn.init.constant_(module.bias, bias)


def kaiming_init(module, a=0, mode="fan_out", nonlinearity="relu", bias=0):
    # Equivalent of mmcv.cnn.kaiming_init with its default normal distribution
    if hasattr(module, "weight") and module.weight is not None:
        nn.init.kaiming_normal_(module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
    if hasattr(module, "bias") and module.bias is not None:
        nn.init.constant_(module.bias, bias)


class FeatureLoss(nn.Module):
    """PyTorch version of `Feature Distillation for General Detectors`
    Args:
        student_channels (int): Number of channels in the student's feature map.
        teacher_channels (int): Number of channels in the teacher's feature map.
        temp (float, optional): Temperature coefficient. Defaults to 0.5.
        alpha_fgd (float, optional): Weight of fg_loss. Defaults to 0.001
        beta_fgd (float, optional): Weight of bg_loss. Defaults to 0.0005
        gamma_fgd (float, optional): Weight of mask_loss. Defaults to 0.001
        lambda_fgd (float, optional): Weight of relation_loss. Defaults to 0.000005
    """

    def __init__(self,
                 student_channels,
                 teacher_channels,
                 temp=0.5,
                 alpha_fgd=0.001,
                 beta_fgd=0.0005,
                 gamma_fgd=0.001,
                 lambda_fgd=0.000005,
                 ):
        super(FeatureLoss, self).__init__()
        self.temp = temp
        self.alpha_fgd = alpha_fgd
        self.beta_fgd = beta_fgd
        self.gamma_fgd = gamma_fgd
        self.lambda_fgd = lambda_fgd

        # 1x1 conv mapping student channels to teacher channels when they differ
        if student_channels != teacher_channels:
            self.align = nn.Conv2d(student_channels, teacher_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.align = None

        # global-context (GcBlock-style) branches used by the relation loss
        self.conv_mask_s = nn.Conv2d(teacher_channels, 1, kernel_size=1)
        self.conv_mask_t = nn.Conv2d(teacher_channels, 1, kernel_size=1)
        self.channel_add_conv_s = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels // 2, kernel_size=1),
            nn.LayerNorm([teacher_channels // 2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels // 2, teacher_channels, kernel_size=1))
        self.channel_add_conv_t = nn.Sequential(
            nn.Conv2d(teacher_channels, teacher_channels // 2, kernel_size=1),
            nn.LayerNorm([teacher_channels // 2, 1, 1]),
            nn.ReLU(inplace=True),  # yapf: disable
            nn.Conv2d(teacher_channels // 2, teacher_channels, kernel_size=1))

        self.reset_parameters()

    def forward(self, preds_S, preds_T, gt_bboxes, img_metas):
        """Forward function.
        Args:
            preds_S (Tensor): Bs*C*H*W, student's feature map
            preds_T (Tensor): Bs*C*H*W, teacher's feature map
            gt_bboxes (Tensor): nt*6 yolov7 targets (image_index, class, x, y, w, h),
                with xywh normalized to [0, 1]
            img_metas (list): [img_h, img_w] of the network input (unpacked, not otherwise used)
        """
        assert preds_S.shape[-2:] == preds_T.shape[-2:], 'the output dim of teacher and student differ'
        device = gt_bboxes.device
        self.to(device)  # keep this loss's own convs on the same device as the features
        if self.align is not None:
            preds_S = self.align(preds_S)
        N, C, H, W = preds_S.shape

        S_attention_t, C_attention_t = self.get_attention(preds_T, self.temp)
        S_attention_s, C_attention_s = self.get_attention(preds_S, self.temp)

        Mask_fg = torch.zeros_like(S_attention_t)
        wmin, wmax, hmin, hmax = [], [], [], []
        img_h, img_w = img_metas
        bboxes = gt_bboxes[:, 2:6]
        # normalized xywh -> normalized xyxy
        bboxes = xywh2xyxy(bboxes)
        # scale to the feature-map grid; floor/ceil so every touched cell is covered
        new_boxxes = torch.ones_like(bboxes)
        new_boxxes[:, 0] = torch.floor(bboxes[:, 0] * W)
        new_boxxes[:, 2] = torch.ceil(bboxes[:, 2] * W)
        new_boxxes[:, 1] = torch.floor(bboxes[:, 1] * H)
        new_boxxes[:, 3] = torch.ceil(bboxes[:, 3] * H)
        # to int
        new_boxxes = new_boxxes.int()

        for i in range(N):
            # boxes that belong to image i of the batch
            new_boxxes_i = new_boxxes[torch.where(gt_bboxes[:, 0] == i)]
            wmin.append(new_boxxes_i[:, 0])
            wmax.append(new_boxxes_i[:, 2])
            hmin.append(new_boxxes_i[:, 1])
            hmax.append(new_boxxes_i[:, 3])
            # 1 / (box area in cells): small objects get larger weights
            area = 1.0 / (hmax[i].view(1, -1) + 1 - hmin[i].view(1, -1)) / (wmax[i].view(1, -1) + 1 - wmin[i].view(1, -1))
            for j in range(len(new_boxxes_i)):
                Mask_fg[i][hmin[i][j]:hmax[i][j] + 1, wmin[i][j]:wmax[i][j] + 1] = \
                    torch.maximum(Mask_fg[i][hmin[i][j]:hmax[i][j] + 1, wmin[i][j]:wmax[i][j] + 1], area[0][j])

        # background = cells outside every box, normalized to sum to 1 per image
        Mask_bg = torch.where(Mask_fg > 0, 0., 1.)
        Mask_bg_sum = torch.sum(Mask_bg, dim=(1, 2))
        Mask_bg[Mask_bg_sum > 0] /= Mask_bg_sum[Mask_bg_sum > 0].unsqueeze(1).unsqueeze(2)

        fg_loss, bg_loss = self.get_fea_loss(preds_S, preds_T, Mask_fg, Mask_bg,
                                             C_attention_s, C_attention_t, S_attention_s, S_attention_t)
        mask_loss = self.get_mask_loss(C_attention_s, C_attention_t, S_attention_s, S_attention_t)
        rela_loss = self.get_rela_loss(preds_S, preds_T)

        loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \
            + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss
        return loss, loss.detach()

    def get_attention(self, preds, temp):
        """ preds: Bs*C*H*W """
        N, C, H, W = preds.shape
        value = torch.abs(preds)
        # Bs*H*W spatial attention
        fea_map = value.mean(dim=1, keepdim=True)
        S_attention = (H * W * F.softmax((fea_map / temp).view(N, -1), dim=1)).view(N, H, W)
        # Bs*C channel attention
        channel_map = value.mean(dim=2, keepdim=False).mean(dim=2, keepdim=False)
        C_attention = C * F.softmax(channel_map / temp, dim=1)
        return S_attention, C_attention

    def get_fea_loss(self, preds_S, preds_T, Mask_fg, Mask_bg, C_s, C_t, S_s, S_t):
        loss_mse = nn.MSELoss(reduction='sum')
        Mask_fg = Mask_fg.unsqueeze(dim=1)
        Mask_bg = Mask_bg.unsqueeze(dim=1)
        C_t = C_t.unsqueeze(dim=-1)
        C_t = C_t.unsqueeze(dim=-1)
        S_t = S_t.unsqueeze(dim=1)

        # both feature maps are weighted by the TEACHER's attention
        fea_t = torch.mul(preds_T, torch.sqrt(S_t))
        fea_t = torch.mul(fea_t, torch.sqrt(C_t))
        fg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_fg))
        bg_fea_t = torch.mul(fea_t, torch.sqrt(Mask_bg))

        fea_s = torch.mul(preds_S, torch.sqrt(S_t))
        fea_s = torch.mul(fea_s, torch.sqrt(C_t))
        fg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_fg))
        bg_fea_s = torch.mul(fea_s, torch.sqrt(Mask_bg))

        fg_loss = loss_mse(fg_fea_s, fg_fea_t) / len(Mask_fg)
        bg_loss = loss_mse(bg_fea_s, bg_fea_t) / len(Mask_bg)
        return fg_loss, bg_loss

    def get_mask_loss(self, C_s, C_t, S_s, S_t):
        mask_loss = torch.sum(torch.abs((C_s - C_t))) / len(C_s) + torch.sum(torch.abs((S_s - S_t))) / len(S_s)
        return mask_loss

    def spatial_pool(self, x, in_type):
        batch, channel, height, width = x.size()
        input_x = x
        # [N, C, H * W]
        input_x = input_x.view(batch, channel, height * width)
        # [N, 1, C, H * W]
        input_x = input_x.unsqueeze(1)
        # [N, 1, H, W]
        if in_type == 0:
            context_mask = self.conv_mask_s(x)
        else:
            context_mask = self.conv_mask_t(x)
        # [N, 1, H * W]
        context_mask = context_mask.view(batch, 1, height * width)
        # [N, 1, H * W]
        context_mask = F.softmax(context_mask, dim=2)
        # [N, 1, H * W, 1]
        context_mask = context_mask.unsqueeze(-1)
        # [N, 1, C, 1]
        context = torch.matmul(input_x, context_mask)
        # [N, C, 1, 1]
        context = context.view(batch, channel, 1, 1)
        return context

    def get_rela_loss(self, preds_S, preds_T):
        loss_mse = nn.MSELoss(reduction='sum')
        context_s = self.spatial_pool(preds_S, 0)
        context_t = self.spatial_pool(preds_T, 1)
        out_s = preds_S
        out_t = preds_T
        channel_add_s = self.channel_add_conv_s(context_s)
        out_s = out_s + channel_add_s
        channel_add_t = self.channel_add_conv_t(context_t)
        out_t = out_t + channel_add_t
        rela_loss = loss_mse(out_s, out_t) / len(out_s)
        return rela_loss

    def last_zero_init(self, m):
        if isinstance(m, nn.Sequential):
            constant_init(m[-1], val=0)
        else:
            constant_init(m, val=0)

    def reset_parameters(self):
        kaiming_init(self.conv_mask_s, mode='fan_in')
        kaiming_init(self.conv_mask_t, mode='fan_in')
        self.conv_mask_s.inited = True
        self.conv_mask_t.inited = True
        self.last_zero_init(self.channel_add_conv_s)
        self.last_zero_init(self.channel_add_conv_t)
```
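The mask logic in `forward` is the densest part, so here is a framework-free sketch of it for a single image. It mirrors the floor/ceil scaling to the grid, the inclusive cell bounds, the 1/area foreground weighting with a maximum over overlapping boxes, and the background normalization. Unlike the tensor code, it clamps indices to the grid. `xywh_to_grid_box` and `build_masks` are hypothetical names.

```python
import math


def xywh_to_grid_box(x, y, w, h, W, H):
    """Normalized center-xywh box -> inclusive grid-cell bounds (x0, y0, x1, y1)."""
    x0 = math.floor((x - w / 2) * W)
    x1 = math.ceil((x + w / 2) * W)
    y0 = math.floor((y - h / 2) * H)
    y1 = math.ceil((y + h / 2) * H)
    return x0, y0, x1, y1


def build_masks(boxes, H, W):
    """Foreground/background masks for ONE image, following FeatureLoss.forward.

    boxes: list of (x, y, w, h), normalized to [0, 1].
    fg: cells covered by a box get 1 / (cells in that box); overlaps keep the larger value.
    bg: uncovered cells share a total weight of 1.
    """
    fg = [[0.0] * W for _ in range(H)]
    for x, y, w, h in boxes:
        x0, y0, x1, y1 = xywh_to_grid_box(x, y, w, h, W, H)
        area = 1.0 / ((y1 + 1 - y0) * (x1 + 1 - x0))
        for r in range(max(y0, 0), min(y1, H - 1) + 1):
            for c in range(max(x0, 0), min(x1, W - 1) + 1):
                fg[r][c] = max(fg[r][c], area)
    bg = [[0.0 if v > 0 else 1.0 for v in row] for row in fg]
    n_bg = sum(map(sum, bg))
    if n_bg > 0:
        bg = [[v / n_bg for v in row] for row in bg]
    return fg, bg


fg, bg = build_masks([(0.5, 0.5, 0.5, 0.5)], H=4, W=4)
for row in fg:
    print(["%.3f" % v for v in row])
```

With one centered box on a 4x4 grid, the box spans cells 1 to 3 in each direction, so each of its 9 cells gets 1/9. The remaining 7 background cells each get 1/7. A box's foreground weights therefore always sum to about 1, whatever its size, which keeps large objects from dominating the loss.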
Instantiating FeatureLoss
In train.py, instantiate the FeatureLoss defined above. Since we distill three layers, we need a list of distillation losses:
```python
if opt.teacher_weights:
    student_kd_layers = hyp["student_kd_layers"]
    teacher_kd_layers = hyp["teacher_kd_layers"]
    # one dummy forward pass, only to read the channel count of each distilled feature map
    dummy_image = torch.zeros((1, 3, imgsz, imgsz), device=device)
    with torch.no_grad():
        _, features = model(dummy_image, extra_features=student_kd_layers)  # forward
        _, teacher_features = teacher_model(dummy_image, extra_features=teacher_kd_layers)
    kd_losses = []
    for i in range(len(features)):
        feature = features[i]
        teacher_feature = teacher_features[i]
        _, student_channels, _, _ = feature.shape
        _, teacher_channels, _, _ = teacher_feature.shape
        kd_losses.append(FeatureLoss(student_channels, teacher_channels).to(device))
```
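Here hyp["student_kd_layers"] and hyp["teacher_kd_layers"] specify the indices of the layers to distill in the student and teacher models.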
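These two keys are not in the stock yolov7 hyp files; you add them yourself. One natural choice is the three layers that feed the detection head, though this is an assumption and the post does not say which layers it uses. Those indices can be read from the last row of the model yaml's `head`. The sketch below uses a hypothetical `detect_input_layers` helper and trimmed cfg dicts. The indices mirror the IDetect rows of yolov7.yaml and yolov7-tiny.yaml as I recall them, so check them against your own cfg. Paired layers must also have the same stride, since FeatureLoss asserts equal H and W.

```python
def detect_input_layers(cfg: dict) -> list:
    """Indices of the layers feeding the last (Detect-style) row of a YOLO cfg's head.

    Assumes the yolov5/yolov7 yaml layout: every row is [from, number, module, args].
    """
    from_, _, module, _ = cfg["head"][-1]
    if not module.endswith("Detect"):
        raise ValueError(f"last head row is {module!r}, expected a Detect-style layer")
    return [from_] if isinstance(from_, int) else list(from_)


# Trimmed, hypothetical cfg dicts (what yaml.safe_load would return, minus most rows).
teacher_cfg = {"head": [[-1, 1, "Conv", [256, 1, 1]],
                        [[102, 103, 104], 1, "IDetect", ["nc", "anchors"]]]}
student_cfg = {"head": [[[74, 75, 76], 1, "IDetect", ["nc", "anchors"]]]}

hyp = {}
hyp["teacher_kd_layers"] = detect_input_layers(teacher_cfg)
hyp["student_kd_layers"] = detect_input_layers(student_cfg)
print(hyp)
```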
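To pull out the feature maps of these layers, the model's inference code also has to be modified. That is left for the next post; this one just walks through the main flow.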
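Purely to illustrate the idea, here is a framework-free sketch of a yolov5/yolov7-style layer loop that also returns the outputs of the requested layers. As in yolov5/yolov7, `m.f` is the input index or indices, with -1 meaning the previous layer. `Layer`, `ToyYolo` and this `forward` are hypothetical stand-ins, not the author's implementation. In the real `Model.forward_once`, only layers listed in `self.save` are cached in `y`, so the requested layers would also need to be kept.

```python
class Layer:
    """Stand-in for one YOLO layer: `i` is its index, `f` its input index(es), -1 = previous layer."""

    def __init__(self, i, f, fn):
        self.i, self.f, self.fn = i, f, fn

    def __call__(self, x):
        return self.fn(x)


class ToyYolo:
    """Hypothetical model that runs its layers like yolov5/yolov7's forward_once."""

    def __init__(self, layers):
        self.model = layers

    def forward(self, x, extra_features=None):
        y = []  # output of every layer, indexed by layer number
        for m in self.model:
            if m.f != -1:  # route inputs from earlier layers
                x = y[m.f] if isinstance(m.f, int) else [x if j == -1 else y[j] for j in m.f]
            x = m(x)
            y.append(x)
        if extra_features is None:
            return x  # unchanged behavior when no features are requested
        return x, [y[i] for i in extra_features]


# Layers 0 and 1 run in sequence; layer 2 sums the outputs of layers 0 and 1.
net = ToyYolo([Layer(0, -1, lambda x: x + 1),
               Layer(1, -1, lambda x: x * 2),
               Layer(2, [0, 1], sum)])
pred, feats = net.forward(1, extra_features=[0, 1])
print(pred, feats)  # 6 [2, 4]
```

Keeping `extra_features=None` as the default leaves every existing call site (validation, export) unaffected.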
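Distillation training
As with any loss, during training we compute the distillation loss and then backpropagate. The only difference is that computing it requires running the teacher model on the same batch.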
```python
if opt.teacher_weights:
    pred, features = model(imgs, extra_features=student_kd_layers)  # forward
    with torch.no_grad():  # the teacher only provides targets
        _, teacher_features = teacher_model(imgs, extra_features=teacher_kd_layers)
    # parentheses so ota_start also applies when "loss_ota" is missing from hyp
    if ("loss_ota" not in hyp or hyp["loss_ota"] == 1) and epoch >= ota_start:
        loss, loss_items = compute_loss_ota(pred, targets.to(device), imgs)
    else:
        loss, loss_items = compute_loss(pred, targets.to(device))  # loss scaled by batch_size
    # kd loss
    # loss_items goes from (box, obj, cls, total) to (box, obj, cls, kd, total);
    # mloss in train.py must then be created with 5 entries instead of 4
    loss_items = torch.cat((loss_items[:3], torch.zeros(1, device=device), loss_items[3:]))
    loss_items[-1] *= imgs.shape[0]  # total on the same batch-size scale as `loss`
    for i in range(len(features)):
        feature = features[i]
        teacher_feature = teacher_features[i]
        kd_loss, kd_loss_item = kd_losses[i](feature, teacher_feature, targets.to(device), [imgsz, imgsz])
        loss += kd_loss
        loss_items[3] += kd_loss_item
        loss_items[4] += kd_loss_item
```
Here kd_loss is added to loss. Once the total loss is computed, everything else is the same as ordinary training.
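One detail the snippets above leave open is that FeatureLoss contains trainable modules. These are the `align` conv (when channel counts differ), `conv_mask_*` and `channel_add_conv_*`. They only learn if their parameters are handed to an optimizer; otherwise they stay at their initialization. Below is a minimal sketch of one training step under that assumption. It uses toy stand-ins for the student, the teacher and a single align conv, and `optimizer.add_param_group` as one way to register the extra parameters. In train.py, the equivalent would be done after the optimizer is built, with the parameters of every module in `kd_losses`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins: student and teacher each produce one feature map at the same stride.
student = nn.Conv2d(3, 4, 3, padding=1)
teacher = nn.Conv2d(3, 8, 3, padding=1).eval()
for p in teacher.parameters():
    p.requires_grad = False

# Stand-in for the trainable part of one FeatureLoss: just a 1x1 align conv.
kd_losses = nn.ModuleList([nn.Conv2d(4, 8, kernel_size=1)])

optimizer = torch.optim.SGD(student.parameters(), lr=0.01, momentum=0.9)
optimizer.add_param_group({"params": kd_losses.parameters()})  # KD modules learn too

imgs = torch.randn(2, 3, 16, 16)
align_before = kd_losses[0].weight.detach().clone()

feat_s = student(imgs)
with torch.no_grad():  # the teacher only provides targets
    feat_t = teacher(imgs)
det_loss = feat_s.pow(2).mean()  # placeholder for compute_loss / compute_loss_ota
kd_loss = F.mse_loss(kd_losses[0](feat_s), feat_t)
loss = det_loss + kd_loss  # same accumulation as `loss += kd_loss`

optimizer.zero_grad()
loss.backward()
optimizer.step()
```

After the step, the student and the align conv have both been updated, while the teacher has received no gradients at all.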
Conclusion
This post outlined the distillation process for yolov7. More details will follow in the next post.