AOT源码解析4.4 -decoder生成预测mask并计算loss

3、生成ref_imgs的预测mask和loss

这一步在训练阶段调用

3.1 数据处理

图1,如图1所示,将enc_embs的最后一个比例的特征图和有ref_imgs相关的特征图得到的LSTT特征图相拼接作为输入

python 复制代码
        curr_enc_embs = self.curr_enc_embs
        curr_lstt_embs = self.curr_lstt_output[0]

        pred_id_logits = self.AOT.decode_id_logits(curr_lstt_embs,
                                                   curr_enc_embs)

3.2 Decoder结构

图2, decoder的操作步骤如图,该解码器将enc_embs各个比例的特征图结合到一起

  • Decoder结构
python 复制代码
class FPNSegmentationHead(nn.Module):
    def __init__(self,
                 in_dim,
                 out_dim,
                 decode_intermediate_input=True,
                 hidden_dim=256,
                 shortcut_dims=[24, 32, 96, 1280],
                 align_corners=True):
        super().__init__()
        self.align_corners = align_corners

        self.decode_intermediate_input = decode_intermediate_input

        self.conv_in = ConvGN(in_dim, hidden_dim, 1)

        self.conv_16x = ConvGN(hidden_dim, hidden_dim, 3)
        self.conv_8x = ConvGN(hidden_dim, hidden_dim // 2, 3)
        self.conv_4x = ConvGN(hidden_dim // 2, hidden_dim // 2, 3)

        self.adapter_16x = nn.Conv2d(shortcut_dims[-2], hidden_dim, 1)
        self.adapter_8x = nn.Conv2d(shortcut_dims[-3], hidden_dim, 1)
        self.adapter_4x = nn.Conv2d(shortcut_dims[-4], hidden_dim // 2, 1)

        self.conv_out = nn.Conv2d(hidden_dim // 2, out_dim, 1)

        self._init_weight()

    def forward(self, inputs, shortcuts):

        if self.decode_intermediate_input:
            x = torch.cat(inputs, dim=1)
        else:
            x = inputs[-1]

        x = F.relu_(self.conv_in(x))
        s1 = self.adapter_16x(shortcuts[-2])
        x = F.relu_(self.conv_16x(self.adapter_16x(shortcuts[-2]) + x))

        x = F.interpolate(x,
                          size=shortcuts[-3].size()[-2:],
                          mode="bilinear",
                          align_corners=self.align_corners)
        x = F.relu_(self.conv_8x(self.adapter_8x(shortcuts[-3]) + x))

        x = F.interpolate(x,
                          size=shortcuts[-4].size()[-2:],
                          mode="bilinear",
                          align_corners=self.align_corners)
        x = F.relu_(self.conv_4x(self.adapter_4x(shortcuts[-4]) + x))

        x = self.conv_out(x)

        return x

3.3 计算loss

  • 对Decoder输出的结果按照对象数量进行分隔
python 复制代码
        pred_id_logits = self.pred_id_logits

        pred_id_logits = F.interpolate(pred_id_logits,
                                       size=gt_mask.size()[-2:],
                                       mode="bilinear",
                                       align_corners=self.align_corners)

        label_list = []
        logit_list = []
        for batch_idx, obj_num in enumerate(self.obj_nums):
            now_label = gt_mask[batch_idx].long()
            now_logit = pred_id_logits[batch_idx, :(obj_num + 1)].unsqueeze(0)
            label_list.append(now_label.long())
            logit_list.append(now_logit)
  • 计算loss

在深度学习中,尤其是在图像相关的任务(如图像分割)中,我们通常有大量的像素需要预测。在这种情况下,可能并不是所有的像素对最终的任务都同样重要。

例如,模型可能已经能够很好地预测图像的大部分区域,但是对于一些难以区分的区域(如物体边缘或小物体)预测得不够好。这些难以预测的区域可能正是模型需要关注的重点。

为了使模型更加关注这些难以预测的区域,可以采用一种称为"硬例挖掘"(hard example mining)的技术。这种方法的基本思想是,不是对所有的像素平均地计算损失,而是只关注那些损失最大的像素。

通过这种方式,模型的训练可以更加集中在那些难以正确预测的像素上,从而提高模型的整体性能。具体来说,"top k percent pixels" 指的是按照损失值从高到低排序后,选取前 k 百分比的像素。例如,如果 k 设置为 50%,那么在损失计算中,只会考虑损失最大的前 50% 的像素。

在代码中,这通常是通过以下步骤实现的:

  • 计算所有像素的损失。
  • 根据损失值对像素进行排序。
  • 选择损失值最高的前 k 百分比的像素。
  • 只计算这些选定像素的损失,并将它们加起来作为最终的损失。
python 复制代码
class CrossEntropyLoss(nn.Module):
    def __init__(self,
                 top_k_percent_pixels=None,
                 hard_example_mining_step=100000):
        super(CrossEntropyLoss, self).__init__()
        self.top_k_percent_pixels = top_k_percent_pixels
        if top_k_percent_pixels is not None:
            assert (top_k_percent_pixels > 0 and top_k_percent_pixels < 1)
        self.hard_example_mining_step = hard_example_mining_step + 1e-5
        if self.top_k_percent_pixels is None:
            self.celoss = nn.CrossEntropyLoss(ignore_index=255,
                                              reduction='mean')
        else:
            self.celoss = nn.CrossEntropyLoss(ignore_index=255,
                                              reduction='none')


    def forward(self, dic_tmp, y, step):
        total_loss = []
        for i in range(len(dic_tmp)):
            pred_logits = dic_tmp[i]
            gts = y[i]
            if self.top_k_percent_pixels is None:
                final_loss = self.celoss(pred_logits, gts)
            else:
                # Only compute the loss for top k percent pixels.
                # First, compute the loss for all pixels. Note we do not put the loss
                # to loss_collection and set reduction = None to keep the shape.
                num_pixels = float(pred_logits.size(2) * pred_logits.size(3))
                pred_logits = pred_logits.view(
                    -1, pred_logits.size(1),
                    pred_logits.size(2) * pred_logits.size(3))
                gts = gts.view(-1, gts.size(1) * gts.size(2))
                pixel_losses = self.celoss(pred_logits, gts)
                if self.hard_example_mining_step == 0:
                    top_k_pixels = int(self.top_k_percent_pixels * num_pixels)
                else:
                    ratio = min(1.0,
                                step / float(self.hard_example_mining_step))
                    top_k_pixels = int((ratio * self.top_k_percent_pixels +
                                        (1.0 - ratio)) * num_pixels)
                top_k_loss, top_k_indices = torch.topk(pixel_losses,
                                                       k=top_k_pixels,
                                                       dim=1)

                final_loss = torch.mean(top_k_loss)
            final_loss = final_loss.unsqueeze(0)
            total_loss.append(final_loss)
        total_loss = torch.cat(total_loss, dim=0)
        return total_loss
相关推荐
Jason-河山1 分钟前
Java爬虫抓取数据的艺术
java·爬虫·python
桥田智能5 分钟前
工博会动态 | 来8.1馆 看桥田如何玩转全场
人工智能·机器人·自动化
深度学习的奋斗者5 分钟前
YOLOv8+注意力机制+PyQt5玉米病害检测系统完整资源集合
python·深度学习·yolo
—你的鼬先生6 分钟前
从零开始使用树莓派debian系统使用opencv4.10.0进行人脸识别(保姆级教程)
python·opencv·debian·人脸识别·二维码识别·opencv安装
Eric.Lee202121 分钟前
数据集-目标检测系列-口罩检测数据集 mask>> DataBall
人工智能·目标检测·计算机视觉·数据集·口罩检测
界面开发小八哥29 分钟前
如何用LightningChart Python实现地震强度数据可视化应用程序?
开发语言·python·信息可视化
ChinaZ.AI32 分钟前
ComfyUI 速度更快,显存占用更低的图像反推模型Florence2PromptGen,效果媲美JoyCaption,还支持Flux训练打标
人工智能·stable diffusion·aigc·flux·comfyui·florence
小强在此1 小时前
机器学习【教育领域及其平台搭建】
人工智能·学习·机器学习·团队开发·教育领域·机器
憨憨憨憨憨到不行的程序员1 小时前
【OpenCV】场景中人的识别与前端计数
人工智能·opencv·计算机视觉
剑指~巅峰1 小时前
亲身体验Llama 3.1:开源模型的部署与应用之旅
人工智能·深度学习·opencv·机器学习·计算机视觉·数据挖掘·llama