在上一篇博文中,我们讲解了YOLOv8
实例分割的训练过程,已将前向传播过程分析完毕,那么,接下来便是损失计算过程了。
文章目录
训练整体流程
获得预测结果与真值后,即可计算损失。
整体流程如下:
预测结果preds
如下:
真值batch
如下:
分割整体损失函数
v8SegmentationLoss
的计算过程如下,从最终的结果来看,其计算了四个损失,分别是目标预测框损失、mask
损失、类别损失以及DEL
损失,博主已将每段代码的结果标注在对应的代码位置。同时,在损失计算过程中不可避免的需要使用其他方法,博主将一些较为重要的方法也罗列出来了。
python
def __call__(self, preds, batch):
"""Calculate and return the loss for the YOLO model."""
loss = torch.zeros(4, device=self.device) # box, cls, dfl,mask,共 4 个
feats, pred_masks, proto = preds if len(preds) == 3 else preds[1] #feature是目标检测头输出的三个特征图list类型,pred_masks为torch.Size([4, 32, 8400]) ,oroto为torch.Size([4, 32, 160, 160])
batch_size, _, mask_h, mask_w = proto.shape # batch size, number of masks, mask height, mask width
pred_distri, pred_scores = torch.cat([xi.view(feats[0].shape[0], self.no, -1) for xi in feats], 2).split(
(self.reg_max * 4, self.nc), 1
)#拆分目标检测特征图结果,并融合在一起,得到(4,64,8400)与(4,80,8400)
# B, grids, .. 维度转换
pred_scores = pred_scores.permute(0, 2, 1).contiguous()#(4,8400,80)
pred_distri = pred_distri.permute(0, 2, 1).contiguous()#(4,8400,64)
pred_masks = pred_masks.permute(0, 2, 1).contiguous()#(4,8400,32)
dtype = pred_scores.dtype #torch.float16
imgsz = torch.tensor(feats[0].shape[2:], device=self.device, dtype=dtype) * self.stride[0] # image size (h,w)
anchor_points, stride_tensor = make_anchors(feats, self.stride, 0.5)
# Targets
try:
batch_idx = batch["batch_idx"].view(-1, 1)
targets = torch.cat((batch_idx, batch["cls"].view(-1, 1), batch["bboxes"]), 1)#torch.Size([22, 6]) 即batchid 类别id x y w h 共6个数
targets = self.preprocess(targets.to(self.device), batch_size, scale_tensor=imgsz[[1, 0, 1, 0]])#(batch,最大目标数量,5) 5=1+4
gt_labels, gt_bboxes = targets.split((1, 4), 2) # cls, xyxy (4,7,1)(4,7,4)
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0.0)
# 这行代码的作用是创建一个掩码张量 mask_gt,用于标识哪些目标是有效的(即有实际的边界框),哪些目标是无效的(即填充的零)。
# sum(2, keepdim=True):沿着第 2 维(即每个边界框的坐标维度)进行求和,保持维度不变。
# 这样会将每个边界框的 4 个坐标值相加,如果该目标是有效的边界框,和将大于 0;如果该目标是填充的零,和将等于 0。
except RuntimeError as e
# Pboxes
pred_bboxes = self.bbox_decode(anchor_points, pred_distri) # xyxy, (b, h*w, 4)
_, target_bboxes, target_scores, fg_mask, target_gt_idx = self.assigner(
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,#(4,7,1)7指这个四个batch中
gt_bboxes,#(4,7,4)
mask_gt,
)
"""
self.assigner即根据分类与回归的分数加权的分数选择正样本 输入(共6个输入值)
1.pred_scores:表示模型预测的每个锚点位置的分类分数 形状为:(batch_size, 8400, 80)
2.(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype)
pred_bboxes:解码后的边界框坐标,形状为 [batch_size, 8400, 4]。
detach():从计算图中分离出预测的边界框,避免反向传播时更新它们。
* stride_tensor:将边界框坐标乘以步幅张量 stride_tensor,恢复到原图尺度。
type(gt_bboxes.dtype):将边界框的类型转换为与 gt_bboxes 相同的类型。
3.anchor_points * stride_tensor
anchor_points:锚点位置,形状为 [num_anchors, 2]
将锚点位置乘以步幅张量 stride_tensor,恢复到原图尺度
4.gt_labels:真实目标的类别标签
5.gt_bboxes:真实目标的边界框坐标
6.mask_gt:掩码张量,标识哪些目标是有效的
"""
"""
输出(共5个返回值)
target_labels, 这里用_代替了:形状为 [batch_size, num_anchors],包含每个锚点的目标标签
target_bboxes:形状为 [batch_size, num_anchors, 4],包含每个锚点的目标边界框
target_scores:形状为 [batch_size, num_anchors, num_classes],包含每个锚点的目标得分。
fg_mask:形状为 [batch_size, num_anchors],标识哪些锚点是前景(即有效的目标, 正样本)。
fg_mask作用:标识哪些锚点是前景(正样本),哪些是背景(负样本)。
正样本:锚点被分配给一个真实的目标,表示这个锚点负责检测这个目标。
负样本:锚点未被分配给任何目标,表示这个锚点不负责检测任何目标
target_gt_idx, 这里用_代替了:形状为 [batch_size, num_anchors],包含每个锚点对应的真实目标索引。
"""
target_scores_sum = max(target_scores.sum(), 1)
# Cls loss
# loss[1] = self.varifocal_loss(pred_scores, target_scores, target_labels) / target_scores_sum # VFL way
loss[2] = self.bce(pred_scores, target_scores.to(dtype)).sum() / target_scores_sum # BCE 即求分类损失
if fg_mask.sum():
# Bbox loss
loss[0], loss[3] = self.bbox_loss( #求Box损失和DEL损失
pred_distri,
pred_bboxes,
anchor_points,
target_bboxes / stride_tensor,
target_scores,
target_scores_sum,
fg_mask,
)
# Masks loss
masks = batch["masks"].to(self.device).float()
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample
masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
loss[1] = self.calculate_segmentation_loss(
fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
)
# WARNING: lines below prevent Multi-GPU DDP 'unused gradient' PyTorch errors, do not remove
else:
loss[1] += (proto * 0).sum() + (pred_masks * 0).sum() # inf sums may lead to nan loss
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.box # seg gain
loss[2] *= self.hyp.cls # cls gain
loss[3] *= self.hyp.dfl # dfl gain
return loss.sum() * batch_size, loss.detach() # loss(box, cls, dfl)
生成网格
make_anchor
方法,其作用是利用特征图生成网格的形式来创建预测框,这种产方式自YOLOv3
起便一直沿用。
这里生成网格使用的方法为:x,y=torch.meshgrid(a,b)
torch.meshgrid(a,b)的功能是生成网格,可以用于生成坐标。函数输入两个数据类型相同的一维张量,两个输出张量的行数为第一个输入张量的元素个数,列数为第二个输入张量的元素个数,当两个输入张量数据类型不同或维度不是一维时会报错。
其中第一个输出张量填充第一个输入张量中的元素,各行元素相同;第二个输出张量填充第二个输入张量中的元素,各列元素相同。
a
tensor([1, 2, 3, 4, 5, 6])
b
tensor([ 7, 8, 9, 10])
x
tensor([[1, 1, 1, 1],[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4],
[5, 5, 5, 5],
[6, 6, 6, 6]])
y
tensor([[4, 5, 6],[4, 5, 6],
[4, 5, 6],
[4, 5, 6]])
python
def make_anchors(feats, strides, grid_cell_offset=0.5):
"""Generate anchors from features."""
anchor_points, stride_tensor = [], []
assert feats is not None
dtype, device = feats[0].dtype, feats[0].device
for i, stride in enumerate(strides):
_, _, h, w = feats[i].shape
sx = torch.arange(end=w, device=device, dtype=dtype) + grid_cell_offset # shift x
sy = torch.arange(end=h, device=device, dtype=dtype) + grid_cell_offset # shift y
sy, sx = torch.meshgrid(sy, sx, indexing="ij") if TORCH_1_10 else torch.meshgrid(sy, sx)
anchor_points.append(torch.stack((sx, sy), -1).view(-1, 2))
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
return torch.cat(anchor_points), torch.cat(stride_tensor)
我们以第一层特征图生成的网格为例,得到的sy
,sx
如下:
python
sy: tensor([[ 0.5000, 0.5000, 0.5000, ..., 0.5000, 0.5000, 0.5000],
[ 1.5000, 1.5000, 1.5000, ..., 1.5000, 1.5000, 1.5000],
[ 2.5000, 2.5000, 2.5000, ..., 2.5000, 2.5000, 2.5000],
...,
[77.5000, 77.5000, 77.5000, ..., 77.5000, 77.5000, 77.5000],
[78.5000, 78.5000, 78.5000, ..., 78.5000, 78.5000, 78.5000],
[79.5000, 79.5000, 79.5000, ..., 79.5000, 79.5000, 79.5000]], device='cuda:0', dtype=torch.float16)
sx: tensor([[ 0.5000, 1.5000, 2.5000, ..., 77.5000, 78.5000, 79.5000],
[ 0.5000, 1.5000, 2.5000, ..., 77.5000, 78.5000, 79.5000],
[ 0.5000, 1.5000, 2.5000, ..., 77.5000, 78.5000, 79.5000],
...,
[ 0.5000, 1.5000, 2.5000, ..., 77.5000, 78.5000, 79.5000],
[ 0.5000, 1.5000, 2.5000, ..., 77.5000, 78.5000, 79.5000],
[ 0.5000, 1.5000, 2.5000, ..., 77.5000, 78.5000, 79.5000]], device='cuda:0', dtype=torch.float16)
生成的anchor-point
即为两者合并的结果,维度为(6400,2)
最终,将三个尺度的特征图产生的anchor
合并,得到:
python
anchor_point : torch.Size([8400, 2])
stride_tensor torch.Size([8400, 1]),其结果如下:
tensor([[ 8.], 8,16,32代表放大比例
[ 8.],
[ 8.],
...,
[32.],
[32.],
[32.]], device='cuda:0', dtype=torch.float16)
前处理过程(真值转换)
这个方法是将target
进行处理,初始的target
为(22,6)
转换为(batch,最大目标数量,5)
5即类别id+xywh
,当然最后还会将xywh
转换为坐标形式
python
def preprocess(self, targets, batch_size, scale_tensor):
"""Preprocesses the target counts and matches with the input batch size to output a tensor."""
nl, ne = targets.shape
if nl == 0:
out = torch.zeros(batch_size, 0, ne - 1, device=self.device)
else:
i = targets[:, 0] # image index
_, counts = i.unique(return_counts=True)#统计每张图像中目标的个数
counts = counts.to(dtype=torch.int32)
out = torch.zeros(batch_size, counts.max(), ne - 1, device=self.device)#根据最大数量最大的图像设置维度,即(batch,最大目标数量,5)5即类别id+xywh
for j in range(batch_size): # 遍历每个批次
matches = i == j # 匹配当前批次的目标
n = matches.sum() # 目标数量
if n:
out[j, :n] = targets[matches, 1:] # 填充目标
out[..., 1:5] = xywh2xyxy(out[..., 1:5].mul_(scale_tensor)) # 转换坐标
return out
随后拆分得到cls
与box
:
python
cls tensor([[[22.],
[22.],
[22.],
[22.],
[ 0.],
[ 0.],
[ 0.]],
[[45.],
[50.],
[45.],
[45.],
[49.],
[49.],
[49.]],
[[22.],
[58.],
[75.],
[58.],
[58.],
[75.],
[75.]],
[[58.],
[75.],
[23.],
[23.],
[ 0.],
[ 0.],
[ 0.]]], device='cuda:0')
box tensor([[[2.7307e+02, 2.7907e+02, 5.4232e+02, 5.1043e+02],
[2.7307e+02, 1.9674e+01, 5.4232e+02, 2.5103e+02],
[6.8230e-01, 1.9674e+01, 1.5169e+02, 2.4743e+02],
[6.8230e-01, 2.7907e+02, 1.5169e+02, 5.0683e+02],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]],
[[1.0442e+00, 1.5163e+02, 6.3999e+02, 5.4066e+02],
[1.8152e-01, 2.0824e+02, 3.4580e+02, 5.4177e+02],
[9.4192e+01, 8.0139e-02, 6.3958e+02, 3.8079e+02],
[3.3554e-01, 4.0424e-01, 1.9339e+02, 2.1323e+02],
[4.6191e+01, 2.7344e-02, 1.6055e+02, 9.2354e+01],
[1.2538e+02, 2.2156e-02, 1.6390e+02, 1.4374e+01],
[1.2584e+01, 2.2949e-02, 5.1544e+01, 1.2688e+01]],
[[1.5334e+02, 1.5542e+02, 4.2644e+02, 3.9008e+02],
[2.7954e+02, 4.2535e+02, 4.3731e+02, 6.2589e+02],
[2.9995e+02, 5.0253e+02, 4.0279e+02, 6.2340e+02],
[2.1976e-01, 4.3853e+02, 4.1101e+01, 5.4790e+02],
[2.1976e-01, 1.7356e+02, 4.1101e+01, 2.8293e+02],
[4.0710e-02, 2.3757e+02, 6.5855e+00, 2.8233e+02],
[4.0710e-02, 5.0253e+02, 6.5855e+00, 5.4730e+02]],
[[2.2732e+02, 1.2739e+02, 5.0236e+02, 4.7698e+02],
[2.8749e+02, 2.6194e+02, 4.6677e+02, 4.7264e+02],
[5.2375e+02, 1.8790e+01, 6.3991e+02, 7.6174e+01],
[1.2816e+02, 4.9530e-01, 2.5127e+02, 1.9529e+01],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]], device='cuda:0')
预测框解码
bbox_decode
函数的目的是从预测的物体边界框坐标分布(pred_dist)和参考点(anchor_points)解码出实际的边界框坐标xyxy。(此时pred_dist
为4,8400,64
)进行解码,得到(4,8400,4)
matmul方法为矩阵乘法,pred_dist在matmul之前, shape为[b, a, 4, 16],
self.proj的shape为[16], 最终的pred_dist的shape为[b, a, 4] 如果不理解可以直接使用 a =
torch.ones((1, 3, 4, 16)), 与b=torch.rand(16)进行matmul
python
def bbox_decode(self, anchor_points, pred_dist):
"""从锚点和分布预测中解码出预测的目标边界框坐标。
参数:
anchor_points (torch.Tensor): 锚点坐标,形状为 [num_anchors, 2]。
pred_dist (torch.Tensor): 预测的边界框分布,形状为 [batch_size, num_anchors, num_channels]。
返回:
torch.Tensor: 解码后的边界框坐标,形状为 [batch_size, num_anchors, 4]。
"""
if self.use_dfl:
b, a, c = pred_dist.shape # 获取 batch 大小、锚点数量和通道数
# 将预测分布变形为 [batch_size, num_anchors, 4, num_channels // 4],并在通道维度上应用 softmax
pred_dist = pred_dist.view(b, a, 4, c // 4).softmax(3).matmul(self.proj.type(pred_dist.dtype))#self.proj的值为tensor([ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.], device='cuda:0')
return dist2bbox(pred_dist, anchor_points, xywh=False)
TaskAlignedAssigner方法
该方法用于分配正负样本,即哪些锚点负责预测那个目标,没有分配到目标的锚点则认为预测的是背景。
TaskAlignedAssigner 的匹配策略简单总结为:根据分类与回归的分数加权的分数选择正样本。
- 计算真实框和预测框的匹配程度(分类与回归的分数加权的分数)。
其中,s
是预测类别分值,u
是预测框和真实框的ciou
值,α
和 β
为权重超参数,两者相乘就可以衡量匹配程度,当分类的分值越高且ciou越高时,align_metric
的值就越接近于1,此时预测框就与真实框越匹配,就越符合正样本的标准。
- 对于每个真实框,直接对
align_metric
匹配程度排序,选取topK
个预测框作为正样本。 - 对一个预测框与多个真实框匹配测情况进行处理,保留
ciou
值最大的真实框。
得到的结果如下:
-
target_labels, 这里用_代替了:形状为 [batch_size, num_anchors],包含每个锚点的目标标签
-
target_bboxes:形状为 [batch_size, num_anchors, 4],包含每个锚点的目标边界框
-
target_scores:形状为 [batch_size, num_anchors, num_classes],包含每个锚点的目标得分。
-
fg_mask:形状为 [batch_size, num_anchors],标识哪些锚点是前景(即有效的目标, 正样本)。
fg_mask作用:标识哪些锚点是前景(正样本),哪些是背景(负样本)。
正样本:锚点被分配给一个真实的目标,表示这个锚点负责检测这个目标。
负样本:锚点未被分配给任何目标,表示这个锚点不负责检测任何目标
-
target_gt_idx, 这里用_代替了:形状为 [batch_size, num_anchors],包含每个锚点对应的真实目标索引。
python
target_gt_idx的值如下:其行为8400个
tensor([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 6, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]], device='cuda:0')
Box损失
Box
损失中包含IOU
损失和DEL
损失
传入的参数如下:
这里的pred_bboxes
是预测的检测框,pred_dist
则是用于DEL
损失计算,
python
def BoxLoss(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
"""IoU loss."""
weight = target_scores.sum(-1)[fg_mask].unsqueeze(-1)#(209,1)这个209代表匹配上的锚点数量
iou = bbox_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask], xywh=False, CIoU=True)#也是(209,1)只不过该值求了
loss_iou = ((1.0 - iou) * weight).sum() / target_scores_sum
# DFL loss
if self.dfl_loss:
target_ltrb = bbox2dist(anchor_points, target_bboxes, self.dfl_loss.reg_max - 1)
loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
loss_dfl = loss_dfl.sum() / target_scores_sum
else:
loss_dfl = torch.tensor(0.0).to(pred_dist.device)
return loss_iou, loss_dfl
bbox2dist方法
在求dfl
损失时,有个bbox2dist
方法,传入的参数如下:
python
def bbox2dist(anchor_points, bbox, reg_max):
"""Transform bbox(xyxy) to dist(ltrb)."""
x1y1, x2y2 = bbox.chunk(2, -1) #将box的坐标拆分,x1y1为左上角坐标,x2y2为右上角坐标维度均为(4,8400,2)
return torch.cat((anchor_points - x1y1, x2y2 - anchor_points), -1).clamp_(0, reg_max - 0.01)
其将中心点坐标anchor_points
与左上、右下坐标做差并重写拼接,同时将值限制在0到15(准确的说是14.99)以内
python
torch.clamp(input, min, max, out=None)
作用:限幅。将input
的值限制在[min, max]
之间,并返回结果。
举例如下:
python
import torch
a = torch.arange(9).reshape(3, 3) # 创建3*3的tensor
b = torch.clamp(a, 3, 6) # 对a的值进行限幅,限制在[3, 6]
print('a:', a)
print('shape of a:', a.shape)
print('b:', b)
print('shape of b:', b.shape)
''' 输出结果 '''
a: tensor([[0, 1, 2],
[3, 4, 5],
[6, 7, 8]])
shape of a: torch.Size([3, 3])
b: tensor([[3, 3, 3],
[3, 4, 5],
[6, 6, 6]])
shape of b: torch.Size([3, 3])
那么,为何要这样做呢,我们知道将中心点坐标与左上右下的坐标(真值坐标)做差后得到的是宽高,将其限制在15以内就是说明其只负责30x30
以内的目标,而我们的特征图此时最大的为80x80
,这其实已将不小了,即认为图像中的目标应当都在这个范围内(大多数)毕竟其作用在DEL
损失中。
经过该方法处理后得到target_ltrb
,其维度仍为(4,8400,4)这个是真值的数据
随后便是计算DEL
损失了,其定义如下:
DFL损失
DFL
,全称Distribution Focal Loss
(分布焦点损失),很多人一听到Focal Loss
就立马想到分类,这没错,但DFL
却是用在边框回归中。这个损失用于求中心点坐标到上下左右四条边的距离。
其中,y为真值坐标,yi与yi+1是预测出的距离值,S是其对应的概率
由于DEL模块输出的数据为(4,8400,64)这个64=4x16,4即4条边,16则可认为是距离,这是一个分布,通过softmax函数求出的是概率,即距离为1到16的概率。
具体过程如下:
1、 模型先生成一个reg_max(默认为16)的概率统计分布,其对应得是{ 0 , 1 , . . . , 7 , 8 , . . . , 14 , 15 } ,该值就是模型预测出来的anchor points到bbox边的距离。假设模型预测出的结果pred_dist
{ 0.01 , 0.05 , . . . , 0.12 , 0.23 , . . . , 0.01 , 0.34 } 也就是anchor points
到bbox
边的距离为0的概率是0.01,距离为15的概率为0.34。
2、 然后使用上面介绍的检测头代码中的self.dfl
求出anchor points
到bbox
的距离的期望y
就是模型预测的最终的anchor points
到bbox
边的距离,这个期望最大是15
,也就是说模型预测出的anchor points
到bbox
边的距离最大是15
该距离的求法如下:
那么在预测过程中,这里是直接预测了,不需要再计算多个分布值了
而在训练时,可以按照如下理解:
如下图所示:
调用DELoss
方法,传入的参数如下:
python
loss_dfl = self.dfl_loss(pred_dist[fg_mask].view(-1, self.dfl_loss.reg_max), target_ltrb[fg_mask]) * weight
fg_mask
代表在TaskAlignedAssigne
r方法中匹配上的正样本,其维度为(4,8400),其值为False
或True
因此其会提取pred_dit(4,8400,64)
与target_ltrb(4,8400,4)
对应坐标内的数据,即在(4,8400)
这个维度提取,由于fg_mask
共有209
个,因此取出的值有209
个,同时通过view
对pred_dist
进行维度转换,即209x4=836
随后开始DFLoss
计算:
python
class DFLoss(nn.Module):
"""Criterion class for computing DFL losses during training."""
def __init__(self, reg_max=16) -> None:
"""Initialize the DFL module."""
super().__init__()
self.reg_max = reg_max
def __call__(self, pred_dist, target):
target = target.clamp_(0, self.reg_max - 1 - 0.01)#将target限制在0到14.99之间(这个已经做过了,只是为保险起见)
tl = target.long() # target left转换为整数
tr = tl + 1 # target right 就相当于y-yi了,因为已经将其整数化了
wl = tr - target # weight left
wr = 1 - wl # weight right
return (
F.cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) * wl
+ F.cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) * wr
).mean(-1, keepdim=True)
掩码损失(分割损失)
将masks从batch中取出,并
python
masks = batch["masks"].to(self.device).float()#(4,160,160)
if tuple(masks.shape[-2:]) != (mask_h, mask_w): # downsample 但不执行
masks = F.interpolate(masks[None], (mask_h, mask_w), mode="nearest")[0]
loss[1] = self.calculate_segmentation_loss(
fg_mask, masks, target_gt_idx, target_bboxes, batch_idx, proto, pred_masks, imgsz, self.overlap
)
匹配上的锚点(正样本)
掩膜(真值)这个值是经过处理下采样后
每个batch
中真值对应的类别
目标所属batch
真值预测框,这个是处理后的,为方便计算(转为了4,8400维)
预测的mask
single_mask_loss方法
该法用于求单张图片的分割损失。传入的参数如下:
这里的40
指的是匹配上的锚点数量(四个batch的锚点数量为40,69,100,40
)。可以看到,此时的pred_mask
(预测的mask)与真值mask都已经转换为相同的维度,即(40,160,160)
这里求两者损失使用的是binary_cross_entropy_with_logits
方法,即交叉熵损失函数,这里带着_with_logits
的原因是不需要将数据传入前使用sigmoid/softmax
映射到(0,1)
之间了。
python
def single_mask_loss(
gt_mask: torch.Tensor, pred: torch.Tensor, proto: torch.Tensor, xyxy: torch.Tensor, area: torch.Tensor
) -> torch.Tensor:
pred_mask = torch.einsum("in,nhw->ihw", pred, proto) # (n, 32) @ (32, 160, 160) -> (n, 160, 160) n即数量
loss = F.binary_cross_entropy_with_logits(pred_mask, gt_mask, reduction="none")
return (crop_mask(loss, xyxy).mean(dim=(1, 2)) / area).sum()
至此,完成了实例分割的损失计算
crop_mask
方法,这个方法是为了让mask
不要超界(不要超出box
)
python
def crop_mask(masks, boxes):
"""
It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box.
Args:
masks (torch.Tensor): [n, h, w] tensor of masks
boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form
Returns:
(torch.Tensor): The masks are being cropped to the bounding box.
"""
_, h, w = masks.shape
x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1) # x1 shape(n,1,1)
r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :] # rows shape(1,1,w)
c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None] # cols shape(1,h,1)
return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))
最终的总损失如下:
python
loss[0] *= self.hyp.box # box gain
loss[1] *= self.hyp.box # seg gain
loss[2] *= self.hyp.cls # cls gain
loss[3] *= self.hyp.dfl # dfl gain
模型参数更新
至此完成了损失计算过程。
随后,便是开始反向传播(这是个黑盒)
python
self.scaler.scale(self.loss).backward()
完成后,边进行模型参数的更新即可:
python
self.trainer.train()#这里是训练完成
# Update model and cfg after training
if RANK in {-1, 0}:
ckpt = self.trainer.best if self.trainer.best.exists() else self.trainer.last
self.model, _ = attempt_load_one_weight(ckpt)
self.overrides = self.model.args
self.metrics = getattr(self.trainer.validator, "metrics", None) # TODO: no metrics returned by DDP
return self.metrics
总结
本章梳理了预测结果与真值的损失计算过程,可以加深我们对模型训练的理解。