CO-DETR: Tracking the Loss During Training

Preface


Building on my earlier posts, which documented the CO-DETR training/inference workflow and converting my own dataset to COCO format so CO-DETR can train on it: since the AP after training is not high, I now want to locate where the loss is computed and logged, save the best .pth checkpoint, and plot loss against epoch to see whether the loss is decreasing steadily or fluctuating sharply.



Locating the loss-related code

In train.py, jump into train_detector, then into the run called at runner.run(data_loaders, cfg.workflow); this lands in class EpochBasedRunner(BaseRunner) (part of mmcv). Add the loss-recording code inside its train method and the plotting code at the very end of its run method. The full code is as follows:

python
    def train(self, data_loader, **kwargs):
        # Keep the tracking state on the runner so it survives across epochs;
        # a local best_loss would reset to inf every epoch, so every epoch
        # would overwrite best_model.pth.
        if not hasattr(self, 'epoch_losses'):
            self.epoch_losses = []
        if not hasattr(self, 'best_loss'):
            self.best_loss = float('inf')

        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition

        epoch_loss = 0.0  # accumulated loss for this epoch

        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)

            
            self.call_hook('after_train_iter')
            del self.data_batch
            self._iter += 1

            # run_iter stores the total loss in self.outputs['log_vars']
            if 'loss' in self.outputs['log_vars']:
                epoch_loss += self.outputs['log_vars']['loss']

        # Average loss of the current epoch
        avg_loss = epoch_loss / len(data_loader)
        self.epoch_losses.append(avg_loss)  # keep per-epoch losses for plotting
        if avg_loss < self.best_loss:
            self.best_loss = avg_loss
            self.save_checkpoint(self.work_dir, 'best_model.pth')

        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self.data_batch = data_batch
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')
            del self.data_batch
        self.call_hook('after_val_epoch')

    def run(self,
            data_loaders: List[DataLoader],
            workflow: List[Tuple[str, int]],
            max_epochs: Optional[int] = None,
            **kwargs) -> None:
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g, [('train', 2), ('val', 1)] means
                running 2 epochs for training and 1 epoch for validation,
                iteratively.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_epochs is not None:
            warnings.warn(
                'setting max_epochs in run is deprecated, '
                'please set max_epochs in runner_config', DeprecationWarning)
            self._max_epochs = max_epochs

        assert self._max_epochs is not None, (
            'max_epochs must be specified during instantiation')

        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if mode == 'train':
                self._max_iters = self._max_epochs * len(data_loaders[i])
                break

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('Hooks will be executed in the following order:\n%s',
                         self.get_hook_info())
        self.logger.info('workflow: %s, max: %d epochs', workflow,
                         self._max_epochs)
        self.call_hook('before_run')

        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):  # self.train()
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        'mode in workflow must be a str, but got {}'.format(
                            type(mode)))

                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')
        # Plot the per-epoch training loss (this needs
        # `import matplotlib.pyplot as plt` added at the top of the file).
        plt.figure(figsize=(10, 5))
        plt.plot(range(1, len(self.epoch_losses) + 1), self.epoch_losses, marker='o', label='Training Loss')
        plt.title('Training Loss per Epoch')
        plt.xlabel('Epochs')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid()
        plt.savefig(osp.join(self.work_dir, 'training_loss.png'))  # save the figure to work_dir
        plt.show()  # only useful with a display; safe to remove on a headless server
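
If you prefer not to modify mmcv's runner at all, a similar loss-vs-epoch curve can usually be produced after the fact from the JSON log that the default MMDetection logger writes into work_dir (one JSON dict per line, with training lines carrying mode, epoch and loss fields). The sketch below assumes that default log format; the log path in the example is hypothetical and should be replaced with your own file.

python
import json
from collections import defaultdict

import matplotlib.pyplot as plt


def plot_loss_from_json_log(log_path, out_png='training_loss.png'):
    """Average the per-iteration 'loss' of each training epoch and plot it."""
    per_epoch = defaultdict(list)
    with open(log_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            # Only training lines carry a 'loss' entry; the first line of the
            # log (environment/config info) has no 'mode' key and is skipped.
            if record.get('mode') == 'train' and 'loss' in record:
                per_epoch[record['epoch']].append(record['loss'])

    epochs = sorted(per_epoch)
    avg_losses = [sum(per_epoch[e]) / len(per_epoch[e]) for e in epochs]

    plt.figure(figsize=(10, 5))
    plt.plot(epochs, avg_losses, marker='o', label='Training Loss')
    plt.title('Training Loss per Epoch')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid()
    plt.savefig(out_png)


# Hypothetical path -- point this at the .log.json file in your work_dir:
# plot_loss_from_json_log('work_dirs/co_dino/20240101_120000.log.json')

Since the JSON log is written while training runs, this also works for runs that have already finished, without retraining or re-running anything.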