学习记录6 增加速度

任务一:关于csv内容的补充

测试1:调低置信度是否能识别出来结果?

可以看到船被识别成了pier/buoy,猜想和训练集中的pier/buoy样式类似,所以检测效果很差

测试2:调大无人船的速度,是否还会识别为pier/buoy?

buoy的输出变大了10倍,说明雷达的点云和目标投影是正确的,但是识别结果依旧是错的

测试3:是否需要调整归一化?感觉实际影响并不大

测试4:微调算法

复制代码
labelImg

打标签

修改utils_fit.py

复制代码
import os
import torch
from tqdm import tqdm
from utils.utils import get_lr
from loss.segmentation_loss import (CE_Loss, Dice_loss, Focal_Loss,
                                     weights_init)
from utils_seg.utils import get_lr
from utils_seg.utils_metrics import f_score
from loss.multitaskloss import HUncertainty
from loss.mgda import MGDA
from loss.pc_seg_loss import NllLoss
import torch.nn.functional as F


def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, loss_history_seg, loss_history_seg_wl, loss_history_seg_pc, eval_callback, eval_callback_seg, eval_callback_seg_w, eval_callback_seg_pc, optimizer, epoch, epoch_step,
                  epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, dice_loss, focal_loss, cls_weights, cls_weights_wl, num_class_seg, local_rank=0, is_radar_pc_seg=False):
    """Run one training + validation epoch of the multi-task network.

    Tasks: object detection (YOLO head), semantic segmentation, waterline
    segmentation and — when ``is_radar_pc_seg`` is True — radar point-cloud
    segmentation.  As an ablation, the radar RVP maps (``radars``) are still
    unpacked from every batch but are deliberately NOT fed to the model.

    Fixes versus the previous revision:
      * ``logg_seg_w`` was only assigned inside ``if dice_loss:`` and then
        used unconditionally, raising ``NameError`` whenever ``dice_loss``
        was False.  The waterline loss is now consistently ``loss_seg_w``.
      * ``total_loss`` served both as the per-batch tensor that is
        back-propagated and as the running progress-bar accumulator
        (``total_loss += total_loss_det + ...`` summed the *running sums*
        every iteration).  A separate ``batch_loss`` tensor is now used for
        backward and ``total_loss`` stays a plain float accumulator.
      * The FP32 and FP16 paths combined the task losses differently for
        ``is_radar_pc_seg=False`` (HUncertainty vs. a plain sum); both paths
        now apply the same HUncertainty weighting with a list argument,
        matching the working ``is_radar_pc_seg`` branch.
      * Checkpoint filenames divided validation losses by ``epoch_step``
        (number of TRAIN batches) instead of ``epoch_step_val``.
      * Validation re-creates the class-weight tensors instead of relying on
        leftovers from the training loop, and the pointless
        ``optimizer.zero_grad()`` under ``torch.no_grad()`` was removed.
    """
    total_loss_det = 0
    total_loss_seg = 0
    total_loss_seg_w = 0
    total_loss_seg_pc = 0
    total_f_score = 0
    total_f_score_w = 0

    val_loss_det = 0
    val_loss_seg = 0
    val_loss_seg_w = 0
    val_loss_seg_pc = 0
    val_f_score = 0
    val_f_score_w = 0

    # Running sum of the combined (task-weighted) batch losses, kept as a
    # float so it never aliases the tensor used for backward.
    total_loss = 0
    val_total_loss = 0

    if fp16:
        # Hoisted out of the batch loop (was re-imported every iteration).
        from torch.cuda.amp import autocast

    if local_rank == 0:
        print('Start Train')
        pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)
    model_train.train()
    for iteration, batch in enumerate(gen):
        if iteration >= epoch_step:
            break

        # ----- batch unpacking; `radars` (RVP maps) is ignored (ablation) ---
        if is_radar_pc_seg:
            # 10-element batches: the extra items are per-point features,
            # per-point labels and the point-cloud evidence maps.
            images, targets, radars, pngs, pngs_w, seg_labels, seg_w_labels, \
                radar_pc_features, radar_pc_labels, pc_evidence_maps = batch[:10]
        else:
            # 8-element batches; the last item is the pc evidence maps.
            images, targets, radars, pngs, pngs_w, seg_labels, seg_w_labels, \
                pc_evidence_maps = batch[:8]

        with torch.no_grad():
            weights = torch.from_numpy(cls_weights)
            weights_wl = torch.from_numpy(cls_weights_wl)

            if cuda:
                images = images.cuda(local_rank)
                targets = [ann.cuda(local_rank) for ann in targets]
                # radars intentionally NOT moved to GPU / used
                pngs = pngs.cuda(local_rank)
                pngs_w = pngs_w.cuda(local_rank)
                seg_labels = seg_labels.cuda(local_rank)
                seg_w_labels = seg_w_labels.cuda(local_rank)
                weights = weights.cuda(local_rank)
                weights_wl = weights_wl.cuda(local_rank)
                pc_evidence_maps = pc_evidence_maps.cuda(local_rank)
                if is_radar_pc_seg:
                    radar_pc_features = radar_pc_features.cuda(local_rank)
                    radar_pc_labels = radar_pc_labels.cuda(local_rank)

        optimizer.zero_grad()
        if not fp16:
            # [Ablation 1/4]: radars are not passed to the model (non-FP16).
            if is_radar_pc_seg:
                outputs, outputs_seg, outputs_seg_w, outputs_seg_pc = model_train(images, radar_pc_features, pc_evidence_maps)
                # Per-point NLL over the class dim; -1 marks padded points.
                loss_pc_seg = F.nll_loss(
                    F.log_softmax(outputs_seg_pc, dim=1).permute(0, 2, 1),
                    radar_pc_labels.squeeze(-1),
                    ignore_index=-1
                )
            else:
                outputs, outputs_seg, outputs_seg_w = model_train(images, pc_evidence_maps)

            if focal_loss:
                loss_seg = Focal_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
                loss_seg_w = Focal_Loss(outputs_seg_w, pngs_w, weights_wl, num_classes=2)
            else:
                loss_seg = CE_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
                loss_seg_w = CE_Loss(outputs_seg_w, pngs_w, weights_wl, num_classes=2)

            if dice_loss:
                # Dice terms are optional add-ons to the pixel-wise losses.
                loss_seg = loss_seg + Dice_loss(outputs_seg, seg_labels)
                loss_seg_w = loss_seg_w + Dice_loss(outputs_seg_w, seg_w_labels)

            loss_det = yolo_loss(outputs, targets)
            # NOTE(review): HUncertainty is re-instantiated every batch, so
            # any learnable task weights it holds are never optimized —
            # confirm this is intentional.
            mtl = HUncertainty(task_num=3)
            batch_loss = mtl([loss_seg, loss_seg_w, loss_det])
            if is_radar_pc_seg:
                batch_loss = batch_loss + loss_pc_seg

            with torch.no_grad():
                train_f_score = f_score(outputs_seg, seg_labels)
                train_f_score_w = f_score(outputs_seg_w, seg_w_labels)

            batch_loss.backward()
            optimizer.step()
        else:
            with autocast():
                # [Ablation 2/4]: radars are not passed to the model (FP16).
                if is_radar_pc_seg:
                    outputs, outputs_seg, outputs_seg_w, outputs_seg_pc = model_train(images, radar_pc_features, pc_evidence_maps)
                    loss_pc_seg = F.nll_loss(
                        F.log_softmax(outputs_seg_pc, dim=1).permute(0, 2, 1),
                        radar_pc_labels.squeeze(-1),
                        ignore_index=-1
                    )
                else:
                    outputs, outputs_seg, outputs_seg_w = model_train(images, pc_evidence_maps)

                if focal_loss:
                    loss_seg = Focal_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
                    loss_seg_w = Focal_Loss(outputs_seg_w, pngs_w, weights_wl, num_classes=2)
                else:
                    loss_seg = CE_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
                    loss_seg_w = CE_Loss(outputs_seg_w, pngs_w, weights_wl, num_classes=2)

                if dice_loss:
                    loss_seg = loss_seg + Dice_loss(outputs_seg, seg_labels)
                    loss_seg_w = loss_seg_w + Dice_loss(outputs_seg_w, seg_w_labels)

                loss_det = yolo_loss(outputs, targets)
                # Same HUncertainty weighting as the FP32 path (previously the
                # FP16 non-pc branch used a plain unweighted sum).
                mtl = HUncertainty(task_num=3)
                batch_loss = mtl([loss_seg, loss_seg_w, loss_det])
                if is_radar_pc_seg:
                    batch_loss = batch_loss + loss_pc_seg

                with torch.no_grad():
                    train_f_score = f_score(outputs_seg, seg_labels)
                    train_f_score_w = f_score(outputs_seg_w, seg_w_labels)

            scaler.scale(batch_loss).backward()
            # Unscale before clipping so the norm threshold is in real units.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model_train.parameters(), max_norm=10.0)
            scaler.step(optimizer)
            scaler.update()

        if ema:
            ema.update(model_train)

        total_loss_det += loss_det.item()
        total_loss_seg += loss_seg.item()
        total_loss_seg_w += loss_seg_w.item()
        if is_radar_pc_seg:
            total_loss_seg_pc += loss_pc_seg.item()
        # Accumulate the combined batch loss (float), not the running sums.
        total_loss += batch_loss.item()
        total_f_score += train_f_score.item()
        total_f_score_w += train_f_score_w.item()

        if local_rank == 0:
            if is_radar_pc_seg:
                pbar.set_postfix(**{'detection loss': total_loss_det / (iteration + 1),
                                'se seg loss': total_loss_seg / (iteration + 1),
                                'wl seg loss': total_loss_seg_w / (iteration + 1),
                                'pc seg loss': total_loss_seg_pc / (iteration + 1),
                                'total loss': total_loss / (iteration + 1),
                                'f score se': total_f_score / (iteration + 1),
                                'f score wl': total_f_score_w / (iteration + 1),
                                'lr': get_lr(optimizer)})
            else:
                pbar.set_postfix(**{'detection loss': total_loss_det / (iteration + 1),
                                    'se seg loss': total_loss_seg / (iteration + 1),
                                    'wl seg loss': total_loss_seg_w / (iteration + 1),
                                    'total loss': total_loss / (iteration + 1),
                                    'f score se': total_f_score / (iteration + 1),
                                    'f score wl': total_f_score_w / (iteration + 1),
                                    'lr': get_lr(optimizer)})
            pbar.update(1)

    if local_rank == 0:
        pbar.close()
        print('Finish Train')
        print('Start Validation')
        pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)

    # Evaluate with the EMA shadow weights when available.
    if ema:
        model_train_eval = ema.ema
    else:
        model_train_eval = model_train.eval()

    for iteration, batch in enumerate(gen_val):
        if iteration >= epoch_step_val:
            break

        # ----- batch unpacking (validation); `radars` is ignored ------------
        if is_radar_pc_seg:
            images, targets, radars, pngs, pngs_w, seg_labels, seg_w_labels, \
                radar_pc_features, radar_pc_labels, pc_evidence_maps = batch[:10]
        else:
            images, targets, radars, pngs, pngs_w, seg_labels, seg_w_labels, \
                pc_evidence_maps = batch[:8]

        with torch.no_grad():
            # Re-create the class-weight tensors here: do not rely on
            # leftovers from the training loop (which may run zero batches).
            weights = torch.from_numpy(cls_weights)
            weights_wl = torch.from_numpy(cls_weights_wl)
            if cuda:
                images = images.cuda(local_rank)
                targets = [ann.cuda(local_rank) for ann in targets]
                # radars intentionally NOT moved to GPU / used
                pngs = pngs.cuda(local_rank)
                pngs_w = pngs_w.cuda(local_rank)
                seg_labels = seg_labels.cuda(local_rank)
                seg_w_labels = seg_w_labels.cuda(local_rank)
                weights = weights.cuda(local_rank)
                weights_wl = weights_wl.cuda(local_rank)
                pc_evidence_maps = pc_evidence_maps.cuda(local_rank)
                if is_radar_pc_seg:
                    radar_pc_features = radar_pc_features.cuda(local_rank)
                    radar_pc_labels = radar_pc_labels.cuda(local_rank)

            # [Ablation 3/4]: radars are not passed to the model (validation).
            if is_radar_pc_seg:
                outputs, outputs_seg, outputs_seg_w, outputs_seg_pc = model_train_eval(images, radar_pc_features, pc_evidence_maps)
                loss_pc_seg = F.nll_loss(
                    F.log_softmax(outputs_seg_pc, dim=1).permute(0, 2, 1),
                    radar_pc_labels.squeeze(-1),
                    ignore_index=-1
                )
            else:
                outputs, outputs_seg, outputs_seg_w = model_train_eval(images, pc_evidence_maps)

            if focal_loss:
                loss_seg = Focal_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
                loss_seg_w = Focal_Loss(outputs_seg_w, pngs_w, weights_wl, num_classes=2)
            else:
                loss_seg = CE_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
                loss_seg_w = CE_Loss(outputs_seg_w, pngs_w, weights_wl, num_classes=2)

            if dice_loss:
                loss_seg = loss_seg + Dice_loss(outputs_seg, seg_labels)
                loss_seg_w = loss_seg_w + Dice_loss(outputs_seg_w, seg_w_labels)

            _f_score = f_score(outputs_seg, seg_labels)
            _f_score_w = f_score(outputs_seg_w, seg_w_labels)

            loss_value = yolo_loss(outputs, targets)
            val_f_score += _f_score.item()
            val_f_score_w += _f_score_w.item()

        val_loss_det += loss_value.item()
        val_loss_seg += loss_seg.item()
        val_loss_seg_w += loss_seg_w.item()
        if is_radar_pc_seg:
            val_loss_seg_pc += loss_pc_seg.item()
        # Running sum over all task losses; the pc term stays 0 when
        # is_radar_pc_seg is False.
        val_total_loss = val_loss_det + val_loss_seg + val_loss_seg_w + val_loss_seg_pc

        if local_rank == 0:
            if is_radar_pc_seg:
                pbar.set_postfix(**{'detection val_loss': val_loss_det / (iteration + 1),
                                    'se seg val_loss': val_loss_seg / (iteration + 1),
                                    'wl seg val_loss': val_loss_seg_w / (iteration + 1),
                                    'pc seg val_loss': val_loss_seg_pc / (iteration + 1),
                                    'val loss': val_total_loss / (iteration + 1),
                                    'f_score se': val_f_score / (iteration + 1),
                                    'f_score wl': val_f_score_w / (iteration + 1),
                                    })
            else:
                pbar.set_postfix(**{'detection val_loss': val_loss_det / (iteration + 1),
                                    'se seg val_loss': val_loss_seg / (iteration + 1),
                                    'wl seg val_loss': val_loss_seg_w / (iteration + 1),
                                    'val loss': val_total_loss / (iteration + 1),
                                    'f_score se': val_f_score / (iteration + 1),
                                    'f_score wl': val_f_score_w / (iteration + 1),
                                    })
            pbar.update(1)

    # ----- epoch-end logging, eval callbacks and checkpointing -------------
    if local_rank == 0:
        pbar.close()
        print('Finish Validation')
        loss_history.append_loss(epoch + 1, total_loss_det / epoch_step, val_loss_det / epoch_step_val)
        loss_history_seg.append_loss(epoch + 1, total_loss_seg / epoch_step, val_loss_seg / epoch_step_val)
        loss_history_seg_wl.append_loss(epoch + 1, total_loss_seg_w / epoch_step, val_loss_seg_w / epoch_step_val)
        if is_radar_pc_seg:
            loss_history_seg_pc.append_loss(epoch + 1, total_loss_seg_pc / epoch_step, val_loss_seg_pc / epoch_step_val)
        eval_callback.on_epoch_end(epoch + 1, model_train_eval)
        eval_callback_seg.on_epoch_end(epoch + 1, model_train_eval)
        eval_callback_seg_w.on_epoch_end(epoch + 1, model_train_eval)
        if is_radar_pc_seg:
            eval_callback_seg_pc.on_epoch_end(epoch + 1, model_train_eval)
        print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
        if is_radar_pc_seg:
            print(
                'Total Loss: %.3f || Val Loss Det: %.3f  || Val Loss Seg: %.3f || Val Loss Seg L: %.3f || Val Loss Seg PC: %.3f' % (
                (total_loss / epoch_step,
                 val_loss_det / epoch_step_val,
                 val_loss_seg / epoch_step_val,
                 val_loss_seg_w / epoch_step_val,
                 val_loss_seg_pc / epoch_step_val)))
        else:
            print(
                'Total Loss: %.3f || Val Loss Det: %.3f  || Val Loss Seg: %.3f || Val Loss Seg L: %.3f' % (
                    (total_loss / epoch_step,
                     val_loss_det / epoch_step_val,
                     val_loss_seg / epoch_step_val,
                     val_loss_seg_w / epoch_step_val,
                     )))

        if ema:
            save_state_dict = ema.ema.state_dict()
        else:
            save_state_dict = model.state_dict()

        if is_radar_pc_seg:
            if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
                # Validation losses averaged over epoch_step_val (previously
                # the total was wrongly divided by epoch_step, the number of
                # TRAIN batches).
                torch.save(save_state_dict, os.path.join(save_dir,
                                                         "ep%03d-loss%.3f-det_val_loss%.3f-seg_val_loss%.3f-seg_wl_val_loss%.3f-seg_pc_val_loss%.3f.pth" % (
                                                             epoch + 1, val_total_loss / epoch_step_val,
                                                             val_loss_det / epoch_step_val,
                                                             val_loss_seg / epoch_step_val,
                                                             val_loss_seg_w / epoch_step_val,
                                                             val_loss_seg_pc / epoch_step_val)))

            # NOTE(review): the "best" criterion compares the current total
            # val loss against min(det history) + min(seg history) — confirm
            # this mixed threshold is intentional.
            if len(loss_history.val_loss) <= 1 or (val_total_loss / epoch_step_val) <= min(loss_history.val_loss) + min(
                    loss_history_seg.val_loss):
                print('Save best model to best_epoch_weights.pth')
                torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth"))

        else:
            if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
                torch.save(save_state_dict, os.path.join(save_dir,
                                                         "ep%03d-loss%.3f-det_val_loss%.3f-seg_val_loss%.3f-seg_wl_val_loss%.3f.pth" % (
                                                             epoch + 1, val_total_loss / epoch_step_val,
                                                             val_loss_det / epoch_step_val,
                                                             val_loss_seg / epoch_step_val,
                                                             val_loss_seg_w / epoch_step_val)))

            if len(loss_history.val_loss) <= 1 or (val_total_loss / epoch_step_val) <= min(loss_history.val_loss) + min(
                    loss_history_seg.val_loss):
                print('Save best model to best_epoch_weights.pth')
                torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth"))
                # Saves the full (possibly wrapped) training module, not just
                # the state dict.
                torch.save(model_train, os.path.join(save_dir, "best_epoch_weights.pt"))

        torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth"))
        torch.save(model_train, os.path.join(save_dir, "last_epoch_weights.pt"))

修改后代码:

复制代码
import os
import torch
from tqdm import tqdm
from utils.utils import get_lr
from loss.segmentation_loss import (CE_Loss, Dice_loss, Focal_Loss,
                                     weights_init)
from utils_seg.utils import get_lr
from utils_seg.utils_metrics import f_score
from loss.multitaskloss import HUncertainty
from loss.mgda import MGDA
from loss.pc_seg_loss import NllLoss
import torch.nn.functional as F


def fit_one_epoch(model_train, model, ema, yolo_loss, loss_history, loss_history_seg, loss_history_seg_wl, loss_history_seg_pc, eval_callback, eval_callback_seg, eval_callback_seg_w, eval_callback_seg_pc, optimizer, epoch, epoch_step,
                  epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, dice_loss, focal_loss, cls_weights, cls_weights_wl, num_class_seg, local_rank=0, is_radar_pc_seg=False):
    total_loss_det = 0
    total_loss_seg = 0
    total_loss_seg_w = 0
    total_loss_seg_pc = 0
    total_f_score = 0
    total_f_score_w = 0

    val_loss_det = 0
    val_loss_seg = 0
    val_loss_seg_w = 0
    val_loss_seg_pc = 0
    val_f_score = 0
    val_f_score_w = 0

    total_loss = 0
    val_total_loss = 0

    if local_rank == 0:
        print('Start Train')
        pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)
    model_train.train()
    for iteration, batch in enumerate(gen):
        if iteration >= epoch_step:
            break
        
        # ======================= [数据解包] =======================
        if is_radar_pc_seg:
            images, targets, radars, pngs, pngs_w, seg_labels, seg_w_labels, radar_pc_features, radar_pc_labels, pc_evidence_maps = \
                batch[0], batch[1], batch[2], batch[3], batch[4], batch[5], batch[6], batch[7], batch[8], batch[9]

        else:
            images, targets, radars, pngs, pngs_w, seg_labels, seg_w_labels, pc_evidence_maps = batch[0], batch[1], batch[2], batch[3], \
                                                                              batch[4], batch[5], batch[6], batch[7]
        # ========================================================================

        with torch.no_grad():
            weights = torch.from_numpy(cls_weights)
            weights_wl = torch.from_numpy(cls_weights_wl)

            if cuda:
                images = images.cuda(local_rank)
                targets = [ann.cuda(local_rank) for ann in targets]
                pngs = pngs.cuda(local_rank)
                pngs_w = pngs_w.cuda(local_rank)
                seg_labels = seg_labels.cuda(local_rank)
                seg_w_labels = seg_w_labels.cuda(local_rank)
                weights = weights.cuda(local_rank)
                weights_wl = weights_wl.cuda(local_rank)
                pc_evidence_maps = pc_evidence_maps.cuda(local_rank)
                if is_radar_pc_seg:
                    radar_pc_features = radar_pc_features.cuda(local_rank)
                    radar_pc_labels = radar_pc_labels.cuda(local_rank)

        optimizer.zero_grad()
        if not fp16:
            if is_radar_pc_seg:
                outputs, outputs_seg, outputs_seg_w, outputs_seg_pc = model_train(images, radar_pc_features, pc_evidence_maps)
                loss_pc_seg = F.nll_loss(
                        F.log_softmax(outputs_seg_pc, dim=1).permute(0, 2, 1),
                        radar_pc_labels.squeeze(-1),
                        ignore_index=-1
                    )
            else:
                outputs, outputs_seg, outputs_seg_w = model_train(images, pc_evidence_maps)

            if focal_loss:
                loss_seg = Focal_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
                loss_seg_w = Focal_Loss(outputs_seg_w, pngs_w, weights_wl, num_classes=2)
            else:
                loss_seg = CE_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
                loss_seg_w = CE_Loss(outputs_seg_w, pngs_w, weights_wl, num_classes=2)

            if dice_loss:
                main_dice = Dice_loss(outputs_seg, seg_labels)
                main_dice_w = Dice_loss(outputs_seg_w, seg_w_labels)
                loss_seg = loss_seg + main_dice
                logg_seg_w = loss_seg_w + main_dice_w

            loss_det = yolo_loss(outputs, targets)
            
            # =================================================================================
            # 【微调核心修改 1/3】: 训练期屏蔽分割任务,总 Loss 仅由目标检测决定
            # =================================================================================
            loss_seg = loss_seg * 0.0
            logg_seg_w = logg_seg_w * 0.0
            
            if is_radar_pc_seg:
                loss_pc_seg = loss_pc_seg * 0.0
                total_loss = loss_det  # 抛弃 MTL,完全只看检测 Loss
            else:
                total_loss = loss_det  # 抛弃 MTL,完全只看检测 Loss
            # =================================================================================

            with torch.no_grad():
                train_f_score = f_score(outputs_seg, seg_labels)
                train_f_score_w = f_score(outputs_seg_w, seg_w_labels)

            total_loss.backward()
            optimizer.step()
        else:
            from torch.cuda.amp import autocast
            with autocast():
                if is_radar_pc_seg:
                    outputs, outputs_seg, outputs_seg_w, outputs_seg_pc = model_train(images, radar_pc_features, pc_evidence_maps)
                    loss_pc_seg = F.nll_loss(
                            F.log_softmax(outputs_seg_pc, dim=1).permute(0, 2, 1),
                            radar_pc_labels.squeeze(-1),
                            ignore_index=-1
                        )                   
                else:
                    outputs, outputs_seg, outputs_seg_w = model_train(images, pc_evidence_maps)

                if focal_loss:
                    loss_seg = Focal_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
                    loss_seg_w = Focal_Loss(outputs_seg_w, pngs_w, weights_wl, num_classes=2)
                else:
                    loss_seg = CE_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
                    loss_seg_w = CE_Loss(outputs_seg_w, pngs_w, weights_wl, num_classes=2)

                if dice_loss:
                    main_dice = Dice_loss(outputs_seg, seg_labels)
                    main_dice_w = Dice_loss(outputs_seg_w, seg_w_labels)
                    loss_seg = loss_seg + main_dice
                    logg_seg_w = loss_seg_w + main_dice_w

                loss_det = yolo_loss(outputs, targets)
                
                # =================================================================================
                # 【微调核心修改 2/3】: FP16模式下,同样屏蔽分割任务
                # =================================================================================
                loss_seg = loss_seg * 0.0
                logg_seg_w = logg_seg_w * 0.0
                
                if is_radar_pc_seg:
                    loss_pc_seg = loss_pc_seg * 0.0
                    total_loss = loss_det  # 抛弃 MTL
                else:
                    total_loss = loss_det  # 抛弃 MTL
                # =================================================================================

                with torch.no_grad():
                    train_f_score = f_score(outputs_seg, seg_labels)
                    train_f_score_w = f_score(outputs_seg_w, seg_w_labels)

            scaler.scale(total_loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model_train.parameters(), max_norm=10.0)
            scaler.step(optimizer)
            scaler.update()

        if ema:
            ema.update(model_train)
        
        total_loss_det += loss_det.item()
        total_loss_seg += loss_seg.item()
        total_loss_seg_w += logg_seg_w.item() if dice_loss else loss_seg_w.item()
        if is_radar_pc_seg:
            total_loss_seg_pc += loss_pc_seg.item()
        total_loss += total_loss_det + total_loss_seg + total_loss_seg_w + total_loss_seg_pc
        total_f_score += train_f_score.item()
        total_f_score_w += train_f_score_w.item()

        if local_rank == 0:
            if is_radar_pc_seg:
                pbar.set_postfix(**{'detection loss': total_loss_det / (iteration + 1),
                                'se seg loss': total_loss_seg / (iteration + 1),
                                'wl seg loss': total_loss_seg_w / (iteration + 1),
                                'pc seg loss': total_loss_seg_pc / (iteration + 1),
                                'total loss': total_loss / (iteration + 1),
                                'f score se': total_f_score / (iteration + 1),
                                'f score wl': total_f_score_w / (iteration + 1),
                                'lr': get_lr(optimizer)})
            else:
                pbar.set_postfix(**{'detection loss': total_loss_det / (iteration + 1),
                                    'se seg loss': total_loss_seg / (iteration + 1),
                                    'wl seg loss': total_loss_seg_w / (iteration + 1),
                                    'total loss': total_loss / (iteration + 1),
                                    'f score se': total_f_score / (iteration + 1),
                                    'f score wl': total_f_score_w / (iteration + 1),
                                    'lr': get_lr(optimizer)})
            pbar.update(1)

    if local_rank == 0:
        pbar.close()
        print('Finish Train')
        print('Start Validation')
        pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)

    if ema:
        model_train_eval = ema.ema
    else:
        model_train_eval = model_train.eval()

    for iteration, batch in enumerate(gen_val):
        if iteration >= epoch_step_val:
            break
        
        if is_radar_pc_seg:
            images, targets, radars, pngs, pngs_w, seg_labels, seg_w_labels, radar_pc_features, radar_pc_labels, pc_evidence_maps = \
                batch[0], batch[1], batch[2], batch[3], batch[4], batch[5], batch[6], batch[7], batch[8], batch[9]

        else:
            images, targets, radars, pngs, pngs_w, seg_labels, seg_w_labels, pc_evidence_maps = batch[0], batch[1], batch[2], batch[3], \
                                                                              batch[4], batch[5], batch[6], batch[7]
        with torch.no_grad():
            if cuda:
                images = images.cuda(local_rank)
                targets = [ann.cuda(local_rank) for ann in targets]
                pngs = pngs.cuda(local_rank)
                pngs_w = pngs_w.cuda(local_rank)
                seg_labels = seg_labels.cuda(local_rank)
                seg_w_labels = seg_w_labels.cuda(local_rank)
                weights = weights.cuda(local_rank)
                pc_evidence_maps = pc_evidence_maps.cuda(local_rank)
                if is_radar_pc_seg:
                    radar_pc_features = radar_pc_features.cuda(local_rank)
                    radar_pc_labels = radar_pc_labels.cuda(local_rank)
            
            optimizer.zero_grad()
            
            if is_radar_pc_seg:
                outputs, outputs_seg, outputs_seg_w, outputs_seg_pc = model_train_eval(images, radar_pc_features, pc_evidence_maps)
                loss_pc_seg = F.nll_loss(
                        F.log_softmax(outputs_seg_pc, dim=1).permute(0, 2, 1),
                        radar_pc_labels.squeeze(-1),
                        ignore_index=-1
                    )                
            else:
                outputs, outputs_seg, outputs_seg_w = model_train_eval(images, pc_evidence_maps)

            if focal_loss:
                loss_seg = Focal_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
                loss_seg_w = Focal_Loss(outputs_seg_w, pngs_w, weights_wl, num_classes=2)
            else:
                loss_seg = CE_Loss(outputs_seg, pngs, weights, num_classes=num_class_seg)
                loss_seg_w = CE_Loss(outputs_seg_w, pngs_w, weights_wl, num_classes=2)

            if dice_loss:
                main_dice = Dice_loss(outputs_seg, seg_labels)
                main_dice_w = Dice_loss(outputs_seg_w, seg_w_labels)
                loss_seg = loss_seg + main_dice
                loss_seg_w = loss_seg_w + main_dice_w
            
            _f_score = f_score(outputs_seg, seg_labels)
            _f_score_w = f_score(outputs_seg_w, seg_w_labels)
            
            loss_value = yolo_loss(outputs, targets)
            
            # =================================================================================
            # 【微调核心修改 3/3】: 验证集的分割 Loss 也强制清零
            # 这是为了防止在保存"最好模型(Best Model)"时,受到无效分割 Loss 的干扰
            # =================================================================================
            loss_value_seg = loss_seg * 0.0
            loss_value_seg_w = loss_seg_w * 0.0
            if is_radar_pc_seg:
                loss_value_seg_pc = loss_pc_seg * 0.0
            
            # 验证集的总 Loss 现在完完全全只看目标检测(Bbox)的效果!
            if is_radar_pc_seg:
                val_total_loss_current = loss_value + loss_value_seg + loss_value_seg_w + loss_value_seg_pc
            else:
                val_total_loss_current = loss_value + loss_value_seg + loss_value_seg_w
            # =================================================================================

            val_f_score += _f_score.item()
            val_f_score_w += _f_score_w.item()

        val_loss_det += loss_value.item()
        val_loss_seg += loss_value_seg.item()
        val_loss_seg_w += loss_value_seg_w.item()
        if is_radar_pc_seg:
            val_loss_seg_pc += loss_value_seg_pc.item()
        val_total_loss = val_loss_det + val_loss_seg + val_loss_seg_w + val_loss_seg_pc if is_radar_pc_seg else val_loss_det + val_loss_seg + val_loss_seg_w
        
        if local_rank == 0:
            if is_radar_pc_seg:
                pbar.set_postfix(**{'detection val_loss': val_loss_det / (iteration + 1),
                                    'se seg val_loss': val_loss_seg / (iteration + 1),
                                    'wl seg val_loss': val_loss_seg_w / (iteration + 1),
                                    'pc seg val_loss': val_loss_seg_pc / (iteration + 1),
                                    'val loss': val_total_loss / (iteration + 1),
                                    'f_score se': val_f_score / (iteration + 1),
                                    'f_score wl': val_f_score_w / (iteration + 1),
                                    })
            else:
                pbar.set_postfix(**{'detection val_loss': val_loss_det / (iteration + 1),
                                    'se seg val_loss': val_loss_seg / (iteration + 1),
                                    'wl seg val_loss': val_loss_seg_w / (iteration + 1),
                                    'val loss': val_total_loss / (iteration + 1),
                                    'f_score se': val_f_score / (iteration + 1),
                                    'f_score wl': val_f_score_w / (iteration + 1),
                                    })
            pbar.update(1)
            
    if local_rank == 0:
        pbar.close()
        print('Finish Validation')
        loss_history.append_loss(epoch + 1, total_loss_det / epoch_step, val_loss_det / epoch_step_val)
        loss_history_seg.append_loss(epoch + 1, total_loss_seg / epoch_step, val_loss_seg / epoch_step_val)
        loss_history_seg_wl.append_loss(epoch + 1, total_loss_seg_w / epoch_step, val_loss_seg_w / epoch_step_val)
        if is_radar_pc_seg:
            loss_history_seg_pc.append_loss(epoch + 1, total_loss_seg_pc / epoch_step, val_loss_seg_pc / epoch_step_val)
        eval_callback.on_epoch_end(epoch + 1, model_train_eval)
        eval_callback_seg.on_epoch_end(epoch + 1, model_train_eval)
        eval_callback_seg_w.on_epoch_end(epoch + 1, model_train_eval)
        if is_radar_pc_seg:
            eval_callback_seg_pc.on_epoch_end(epoch + 1, model_train_eval)
        print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
        if is_radar_pc_seg:
            print(
                'Total Loss: %.3f || Val Loss Det: %.3f  || Val Loss Seg: %.3f || Val Loss Seg L: %.3f || Val Loss Seg PC: %.3f' % (
                (total_loss / epoch_step,
                 val_loss_det / epoch_step_val,
                 val_loss_seg / epoch_step_val,
                 val_loss_seg_w / epoch_step_val,
                 val_loss_seg_pc / epoch_step_val)))
        else:
            print(
                'Total Loss: %.3f || Val Loss Det: %.3f  || Val Loss Seg: %.3f || Val Loss Seg L: %.3f' % (
                    (total_loss / epoch_step,
                     val_loss_det / epoch_step_val,
                     val_loss_seg / epoch_step_val,
                     val_loss_seg_w / epoch_step_val,
                     )))

        if ema:
            save_state_dict = ema.ema.state_dict()
        else:
            save_state_dict = model.state_dict()

        if is_radar_pc_seg:
            if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
                torch.save(save_state_dict, os.path.join(save_dir,
                                                         "ep%03d-loss%.3f-det_val_loss%.3f-seg_val_loss%.3f-seg_wl_val_loss%.3f-seg_pc_val_loss%.3f.pth" % (
                                                             epoch + 1, val_total_loss / epoch_step,
                                                             val_loss_det / epoch_step_val,
                                                             val_loss_seg / epoch_step_val,
                                                             val_loss_seg_w / epoch_step_val,
                                                             val_loss_seg_pc / epoch_step_val)))

            if len(loss_history.val_loss) <= 1 or (val_total_loss / epoch_step_val) <= min(loss_history.val_loss) + min(
                    loss_history_seg.val_loss):
                print('Save best model to best_epoch_weights.pth')
                torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth"))

        else:
            if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
                torch.save(save_state_dict, os.path.join(save_dir,
                                                         "ep%03d-loss%.3f-det_val_loss%.3f-seg_val_loss%.3f-seg_wl_val_loss%.3f.pth" % (
                                                             epoch + 1, val_total_loss / epoch_step,
                                                             val_loss_det / epoch_step_val,
                                                             val_loss_seg / epoch_step_val,
                                                             val_loss_seg_w / epoch_step_val)))

            if len(loss_history.val_loss) <= 1 or (val_total_loss / epoch_step_val) <= min(loss_history.val_loss) + min(
                    loss_history_seg.val_loss):
                print('Save best model to best_epoch_weights.pth')
                torch.save(save_state_dict, os.path.join(save_dir, "best_epoch_weights.pth"))
                torch.save(model_train, os.path.join(save_dir, "best_epoch_weights.pt"))

        torch.save(save_state_dict, os.path.join(save_dir, "last_epoch_weights.pth"))
        torch.save(model_train, os.path.join(save_dir, "last_epoch_weights.pt"))

修改dataloader.py

1、修改分割为全黑假图

2、注释掉加载npz文件的部分

3、修改labels的索引方式,置0

训练到第5个epoch时,验证集损失就开始上升了

Epoch 4/30: 100%|█| 2/2 [00:05<00:00, 2.82s/it, detection val_loss=1.6, f_score se=0.369, f_score wl=0.5, pc seg val_loss=0, se seg val_loss

Finish Validation

Epoch:4/30

Total Loss: 1.642 || Val Loss Det: 1.605 || Val Loss Seg: 0.000 || Val Loss Seg L: 0.000 || Val Loss Seg PC: 0.000

Save best model to best_epoch_weights.pth

Start Train

Epoch 5/30: 100%|█| 18/18 [01:52<00:00, 6.28s/it, detection loss=1.36, f score se=0.289, f score wl=0.5, lr=0.00049, pc seg loss=0, se seg l

Finish Train

Start Validation

Epoch 5/30: 100%|█| 2/2 [00:05<00:00, 2.79s/it, detection val_loss=1.73, f_score se=0.313, f_score wl=0.5, pc seg val_loss=0, se seg val_los

Finish Validation

任务二:预测

相关推荐
風清掦2 小时前
【江科大STM32学习笔记-10】I2C通信协议 - 10.2 硬件 I2C 读写MPU6050
笔记·stm32·单片机·嵌入式硬件·学习
峥嵘life2 小时前
Android + Kiro AI软件开发实战教程
android·后端·学习
Engineer邓祥浩2 小时前
JVM学习笔记(10) 第三部分 虚拟机执行子系统 第9章 类加载及执行子系统的案例与实战
jvm·笔记·学习
自信150413057592 小时前
重生之从0开始学习c++之内存管理
c++·学习
m0_716765233 小时前
数据结构--单链表的插入、删除、查找详解
c语言·开发语言·数据结构·c++·笔记·学习·visual studio
_李小白3 小时前
【OSG学习笔记】Day 53: Text3D( 三维文字)
笔记·学习·3d
CompaqCV3 小时前
OpencvSharp 算子学习教案之 - Cv2.Subtract 重载1
学习·c#·opencvsharp算子·opencv教程
zhangrelay3 小时前
智能时代机器人工程师・云原生 + 大模型 + 智能体 全栈成长计划(2026 版)
笔记·学习
华阙之梦4 小时前
【GIS课堂】
学习