提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
文章目录
前言
提示:这里可以添加本文要记录的大概内容:
参考之前的博客(记录了 CO-DETR 的训练和推理过程,并将自己的数据集转换为 COCO 格式以便 CO-DETR 训练)。现在训练后 AP 不高,需要找到 loss 的记录位置,保存 loss 最低的 pth 文件,并画出 loss 随 epoch 变化的曲线,观察 loss 是在稳定下降还是波动比较剧烈。需要注意:训练 loss 最低的模型不一定是验证集 AP 最高的模型,最终选模型时仍应以验证集上的 AP 为准。
提示:以下是本篇文章正文内容,下面案例可供参考
寻找损失函数相关位置
在 train.py 中点进 train_detector,再点进其中 runner.run(data_loaders, cfg.workflow) 的 run,会跳转到 class EpochBasedRunner(BaseRunner) 这个类。在该类的 train 函数里添加记录 loss 的代码,在 run 函数的最后添加画图的代码,具体代码如下:
python
def train(self, data_loader, **kwargs):
    """Run one training epoch and track the mean training loss.

    Besides the usual per-iteration hook sequence, this averages
    ``self.outputs['log_vars']['loss']`` over the iterations that reported
    it, appends that mean to ``self.epoch_losses`` and, when it is lower
    than every previous epoch's mean (kept in ``self.best_loss``), saves a
    checkpoint named ``best_model.pth`` via ``self.save_checkpoint``.

    Args:
        data_loader: Sized iterable of batches for this epoch.
        **kwargs: Forwarded unchanged to ``self.run_iter``.
    """
    # Tracking state must live on the runner, not in locals: a local
    # ``best_loss`` is reset to +inf every epoch, which made every epoch
    # overwrite the "best" checkpoint unconditionally.
    # NOTE(review): this state is not restored on resume, so the first epoch
    # after resuming will always re-save best_model.pth.
    if not hasattr(self, 'epoch_losses'):
        self.epoch_losses = []
    if not hasattr(self, 'best_loss'):
        self.best_loss = float('inf')

    self.model.train()
    self.mode = 'train'
    self.data_loader = data_loader
    self._max_iters = self._max_epochs * len(self.data_loader)
    self.call_hook('before_train_epoch')
    time.sleep(2)  # Prevent possible deadlock during epoch transition

    loss_sum = 0.0
    loss_count = 0  # only iterations that actually reported a 'loss'
    for i, data_batch in enumerate(self.data_loader):
        self.data_batch = data_batch
        self._inner_iter = i
        self.call_hook('before_train_iter')
        self.run_iter(data_batch, train_mode=True, **kwargs)
        self.call_hook('after_train_iter')
        del self.data_batch
        self._iter += 1
        # Outputs may carry no 'log_vars'; don't crash on that.
        log_vars = self.outputs.get('log_vars', {})
        if 'loss' in log_vars:
            loss_sum += float(log_vars['loss'])
            loss_count += 1

    # Skip bookkeeping when no loss was seen (e.g. empty loader) instead of
    # dividing by zero or recording a bogus 0.0 "best" loss.
    if loss_count:
        avg_loss = loss_sum / loss_count
        self.epoch_losses.append(avg_loss)
        if avg_loss < self.best_loss:
            self.best_loss = avg_loss
            # Write from a single process only; getattr keeps this usable on
            # runners without a ``rank`` attribute.
            if getattr(self, 'rank', 0) == 0:
                # create_symlink=False so 'latest.pth' keeps pointing at the
                # regular periodic checkpoint rather than the best one.
                self.save_checkpoint(
                    self.work_dir, 'best_model.pth', create_symlink=False)

    self.call_hook('after_train_epoch')
    self._epoch += 1
@torch.no_grad()
def val(self, data_loader, **kwargs):
    """Evaluate the model over ``data_loader`` for one epoch.

    Gradient tracking is disabled by the ``torch.no_grad`` decorator. While
    a batch is being processed it is exposed to hooks through
    ``self.data_batch`` and ``self._inner_iter``. ``kwargs`` is accepted but
    not forwarded to ``run_iter``.
    """
    self.model.eval()
    self.mode, self.data_loader = 'val', data_loader
    self.call_hook('before_val_epoch')
    # Prevent possible deadlock during epoch transition.
    time.sleep(2)
    for step, batch in enumerate(self.data_loader):
        self.data_batch, self._inner_iter = batch, step
        self.call_hook('before_val_iter')
        self.run_iter(batch, train_mode=False)
        self.call_hook('after_val_iter')
        del self.data_batch
    self.call_hook('after_val_epoch')
def run(self,
        data_loaders: List[DataLoader],
        workflow: List[Tuple[str, int]],
        max_epochs: Optional[int] = None,
        **kwargs) -> None:
    """Start running.

    Args:
        data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
            and validation.
        workflow (list[tuple]): A list of (phase, epochs) to specify the
            running order and epochs. E.g, [('train', 2), ('val', 1)] means
            running 2 epochs for training and 1 epoch for validation,
            iteratively.

    After the workflow finishes and the ``after_run`` hooks have fired, the
    per-epoch training losses in ``self.epoch_losses`` (if any) are plotted
    and saved to ``<work_dir>/training_loss.png``.
    """
    assert isinstance(data_loaders, list)
    assert mmcv.is_list_of(workflow, tuple)
    assert len(data_loaders) == len(workflow)
    if max_epochs is not None:
        warnings.warn(
            'setting max_epochs in run is deprecated, '
            'please set max_epochs in runner_config', DeprecationWarning)
        self._max_epochs = max_epochs

    assert self._max_epochs is not None, (
        'max_epochs must be specified during instantiation')

    for i, flow in enumerate(workflow):
        mode, epochs = flow
        if mode == 'train':
            self._max_iters = self._max_epochs * len(data_loaders[i])
            break

    work_dir = self.work_dir if self.work_dir is not None else 'NONE'
    self.logger.info('Start running, host: %s, work_dir: %s',
                     get_host_info(), work_dir)
    self.logger.info('Hooks will be executed in the following order:\n%s',
                     self.get_hook_info())
    self.logger.info('workflow: %s, max: %d epochs', workflow,
                     self._max_epochs)
    self.call_hook('before_run')

    while self.epoch < self._max_epochs:
        for i, flow in enumerate(workflow):
            mode, epochs = flow
            if isinstance(mode, str):  # self.train()
                if not hasattr(self, mode):
                    raise ValueError(
                        f'runner has no method named "{mode}" to run an '
                        'epoch')
                epoch_runner = getattr(self, mode)
            else:
                raise TypeError(
                    'mode in workflow must be a str, but got {}'.format(
                        type(mode)))

            for _ in range(epochs):
                if mode == 'train' and self.epoch >= self._max_epochs:
                    break
                epoch_runner(data_loaders[i], **kwargs)

    time.sleep(1)  # wait for some hooks like loggers to finish
    self.call_hook('after_run')

    # Plot the training-loss curve.
    losses = getattr(self, 'epoch_losses', None)
    # Skip when there is nothing to plot (e.g. no 'train' phase ran), when
    # there is no work_dir to save into (osp.join(None, ...) would raise at
    # the very end of training), and on non-zero ranks so only one process
    # writes the file.
    if losses and self.work_dir is not None and getattr(self, 'rank', 0) == 0:
        fig, ax = plt.subplots(figsize=(10, 5))
        try:
            ax.plot(range(1, len(losses) + 1), losses, marker='o',
                    label='Training Loss')
            ax.set_title('Training Loss per Epoch')
            ax.set_xlabel('Epochs')
            ax.set_ylabel('Loss')
            ax.legend()
            ax.grid()
            loss_plot_path = osp.join(self.work_dir, 'training_loss.png')
            fig.savefig(loss_plot_path)
        finally:
            # plt.show() is intentionally not called: in a training process
            # it either blocks until the window is closed or only warns on a
            # headless backend. Close the figure so it is not leaked.
            plt.close(fig)
        self.logger.info('Training loss curve saved to %s', loss_plot_path)