Tracking the Loss Function in CO-DETR

Preface

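This post follows up on my earlier posts, which recorded the CO-DETR training and inference workflow and converted a custom dataset to COCO format for CO-DETR training. Because the AP after training is low, I need to locate where the loss is recorded, save the best .pth checkpoint, and plot the loss against the epoch to see whether the loss decreases steadily or fluctuates sharply.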


Locating the loss-related code

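In train.py, step into train_detector, then into the run method called by runner.run(data_loaders, cfg.workflow). This leads to class EpochBasedRunner(BaseRunner). Add the loss-recording code to its train method and the plotting code at the very end of its run method. The modified code is as follows: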

    def train(self, data_loader, **kwargs):
        # Create the bookkeeping attributes once. They must persist across
        # epochs, so they cannot be local variables of train().
        if not hasattr(self, 'epoch_losses'):
            self.epoch_losses = []  # average training loss of each epoch
            self.best_loss = float('inf')  # lowest average loss seen so far

        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition

        epoch_loss = 0.0  # running sum of the loss over this epoch

        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            del self.data_batch
            self._iter += 1

            # run_iter stores the model outputs in self.outputs; the total
            # loss of this iteration is expected under log_vars['loss'].
            log_vars = self.outputs.get('log_vars', {})
            if 'loss' in log_vars:
                epoch_loss += log_vars['loss']

        # Record the average loss of this epoch
        avg_loss = epoch_loss / len(self.data_loader)
        self.epoch_losses.append(avg_loss)
        # Save a checkpoint whenever the average training loss improves
        if avg_loss < self.best_loss:
            self.best_loss = avg_loss
            self.save_checkpoint(self.work_dir, 'best_model.pth')

        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')
            del self.data_batch
        self.call_hook('after_val_epoch')

    def run(self,
            data_loaders: List[DataLoader],
            workflow: List[Tuple[str, int]],
            max_epochs: Optional[int] = None,
            **kwargs) -> None:
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_epochs is not None:
            warnings.warn(
                'setting max_epochs in run is deprecated, '
                'please set max_epochs in runner_config', DeprecationWarning)
            self._max_epochs = max_epochs

        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('Hooks will be executed in the following order:\n%s',
                         self.get_hook_info())
        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')

        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
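        # Plot the per-epoch training loss.
        # Requires `import matplotlib.pyplot as plt` at the top of this file.
        epoch_losses = getattr(self, 'epoch_losses', [])
        plt.figure(figsize=(10, 5))
        plt.plot(range(1, len(epoch_losses) + 1), epoch_losses, marker='o', label='Training Loss')
        plt.title('Training Loss per Epoch')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid()
        plt.savefig(osp.join(self.work_dir, 'training_loss.png'))  # save the figure
        plt.show()  # display the figure; remove this line on a headless server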
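
The bookkeeping added to train can be tried out without mmcv. The following is a minimal sketch, not part of CO-DETR or mmcv: EpochLossTracker and its methods are hypothetical names used only for illustration. It shows why the best loss has to live outside the per-epoch call: if it were reset to infinity at the start of every epoch, every epoch would count as the best one and the checkpoint would be overwritten each time.

```python
from typing import List, Tuple


class EpochLossTracker:
    """Hypothetical helper mirroring the bookkeeping added to train()."""

    def __init__(self) -> None:
        self.epoch_losses: List[float] = []  # one average per finished epoch
        self.best_loss = float('inf')        # persists across epochs
        self._running_sum = 0.0
        self._num_iters = 0

    def update(self, iter_loss: float) -> None:
        """Accumulate the loss of one training iteration."""
        self._running_sum += float(iter_loss)
        self._num_iters += 1

    def end_epoch(self) -> Tuple[float, bool]:
        """Close the epoch and return (average loss, whether it is a new best)."""
        if self._num_iters == 0:
            raise ValueError('no iterations recorded for this epoch')
        avg = self._running_sum / self._num_iters
        self.epoch_losses.append(avg)
        self._running_sum, self._num_iters = 0.0, 0
        is_best = avg < self.best_loss
        if is_best:
            self.best_loss = avg
        return avg, is_best


# Made-up per-iteration losses for three epochs, for illustration only.
tracker = EpochLossTracker()
for epoch_iter_losses in ([4.0, 2.0], [1.0, 1.0], [2.0, 2.0]):
    for loss in epoch_iter_losses:
        tracker.update(loss)
    avg, is_best = tracker.end_epoch()
    print(f'avg_loss={avg:.2f} is_best={is_best}')
```

In the runner, update corresponds to adding log_vars['loss'] after each iteration, and a True result from end_epoch corresponds to the point where save_checkpoint is called. Note that this tracks the training loss only; a lower training loss does not guarantee a higher AP on the validation set.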