CO-DETR追踪损失函数情况

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录


前言

提示:这里可以添加本文要记录的大概内容:

参考之前的博客,记录co-detr训练推理过程将自己的数据集转化为coco格式便于CO-DETR训练,现在因为训练后AP不高,需要找到loss函数,保存最好的pth文件,还可以做出loss函数和epoch的图像,观察loss函数是不是在稳定下降还是说波动比较剧烈。


提示:以下是本篇文章正文内容,下面案例可供参考

寻找损失函数相关位置

在train.py里点进去train_detector,然后再点进去 runner.run(data_loaders, cfg.workflow)中的run,到的class EpochBasedRunner(BaseRunner):函数,在train函数里添加记录loss的代码,在run函数最后面添加画图的代码,具体代码如下:

python 复制代码
  def train(self, data_loader, **kwargs):
        best_loss = float('inf')

        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition

        epoch_loss = 0.0  # 每个 epoch 的损失初始化

        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)

            
            self.call_hook('after_train_iter')
            del self.data_batch
            self._iter += 1

            # 假设损失在 log_buffer 中被记录
            if 'loss' in self.outputs['log_vars']:
                epoch_loss += self.outputs['log_vars']['loss']

                # 记录当前 epoch 的平均损失
        avg_loss = epoch_loss / len(data_loader)
        self.epoch_losses.append(avg_loss)  # 添加到损失列表
        if avg_loss is not None and avg_loss < best_loss:
            best_loss = avg_loss
            self.save_checkpoint(self.work_dir, 'best_model.pth')

        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')
            del self.data_batch
        self.call_hook('after_val_epoch')

    def run(self,
            data_loaders: List[DataLoader],
            workflow: List[Tuple[str, int]],
            max_epochs: Optional[int] = None,
            **kwargs) -> None:
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_epochs is not None:
            warnings.warn(
                'setting max_epochs in run is deprecated, '
                'please set max_epochs in runner_config', DeprecationWarning)
            self._max_epochs = max_epochs

        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('Hooks will be executed in the following order:\n%s',
                         self.get_hook_info())
        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')

        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
        # 绘制损失图表
        plt.figure(figsize=(10, 5))
        plt.plot(range(1, len(self.epoch_losses) + 1), self.epoch_losses, marker='o', label='Training Loss')
        plt.title('Training Loss per Epoch')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid()
        plt.savefig(osp.join(self.work_dir, 'training_loss.png'))  # 保存图像
        plt.show()  # 显示图像
相关推荐
明明真系叻14 分钟前
2025.4.20机器学习笔记:文献阅读
人工智能·笔记·机器学习
学术小八35 分钟前
2025年机电一体化、机器人与人工智能国际学术会议(MRAI 2025)
人工智能·机器人·机电
爱的叹息39 分钟前
关于 雷达(Radar) 的详细解析,涵盖其定义、工作原理、分类、关键技术、应用场景、挑战及未来趋势,结合实例帮助理解其核心概念
人工智能·分类·数据挖掘
许泽宇的技术分享41 分钟前
.NET MCP 文档
人工智能·.net
_x_w1 小时前
【17】数据结构之图及图的存储篇章
数据结构·python·算法·链表·排序算法·图论
anscos1 小时前
Actran声源识别方法连载(二):薄膜模态表面振动识别
人工智能·算法·仿真软件·actran
pianmian11 小时前
arcgis几何与游标(1)
开发语言·python
-曾牛1 小时前
【LangChain4j快速入门】5分钟用Java玩转GPT-4o-mini,Spring Boot整合实战!| 附源码
java·开发语言·人工智能·spring boot·ai·chatgpt
token-go1 小时前
[特殊字符] KoalaAI 1.0.23 震撼升级:GPT-4.1免费畅享,AI革命触手可及!
人工智能
冬天vs不冷1 小时前
SpringBoot条件注解全解析:核心作用与使用场景详解
java·spring boot·python