关于gt_sampling的理解

pcdet/datasets/augmentor/data_augmentor.py

python 复制代码
    def gt_sampling(self, config=None):
        db_sampler = database_sampler.DataBaseSampler(
            root_path=self.root_path,
            sampler_cfg=config,
            class_names=self.class_names,
            logger=self.logger
        )
        return db_sampler

此函数指向DataBaseSampler类,单步调试运行到__call__函数

python 复制代码
    def __call__(self, data_dict):
        """
        Args:
            data_dict:
                gt_boxes: (N, 7 + C) [x, y, z, dx, dy, dz, heading, ...]

        Returns:

        """
        gt_boxes = data_dict['gt_boxes']
        gt_names = data_dict['gt_names'].astype(str)
        existed_boxes = gt_boxes
        total_valid_sampled_dict = []
        sampled_mv_height = []
        sampled_gt_boxes2d = []

        for class_name, sample_group in self.sample_groups.items():
            if self.limit_whole_scene:
                num_gt = np.sum(class_name == gt_names)
                sample_group['sample_num'] = str(int(self.sample_class_num[class_name]) - num_gt)
            if int(sample_group['sample_num']) > 0:
                #取15个groud truth标签
                sampled_dict = self.sample_with_fixed_number(class_name, sample_group)

                # 求出框的坐标信息
                sampled_boxes = np.stack([x['box3d_lidar'] for x in sampled_dict], axis=0).astype(np.float32)

                assert not self.sampler_cfg.get('DATABASE_WITH_FAKELIDAR', False), 'Please use latest codes to generate GT_DATABASE'

                # 碰撞检测
                #iou1针对的是当前样本里的groud truth与从实例库中采样得到的groud truth进行iou计算,iou如果不为0,则两个发生碰撞(现实情况不可能发生)
                iou1 = iou3d_nms_utils.boxes_bev_iou_cpu(sampled_boxes[:, 0:7], existed_boxes[:, 0:7])
                #iou2针对的是实例库中采样得到的groud truth彼此之间进行iou计算,iou如果不为0,则两个发生碰撞(现实情况不可能发生)
                iou2 = iou3d_nms_utils.boxes_bev_iou_cpu(sampled_boxes[:, 0:7], sampled_boxes[:, 0:7])
                iou2[range(sampled_boxes.shape[0]), range(sampled_boxes.shape[0])] = 0
                iou1 = iou1 if iou1.shape[1] > 0 else iou2
                #将iou1和iou2中值为0的标记为ture,满足现实情况
                valid_mask = ((iou1.max(axis=1) + iou2.max(axis=1)) == 0)

                if self.img_aug_type is not None:
                    sampled_boxes2d, mv_height, valid_mask = self.sample_gt_boxes_2d(data_dict, sampled_boxes, valid_mask)
                    sampled_gt_boxes2d.append(sampled_boxes2d)
                    if mv_height is not None:
                        sampled_mv_height.append(mv_height)

                valid_mask = valid_mask.nonzero()[0]
                valid_sampled_dict = [sampled_dict[x] for x in valid_mask]
                valid_sampled_boxes = sampled_boxes[valid_mask]

                existed_boxes = np.concatenate((existed_boxes, valid_sampled_boxes[:, :existed_boxes.shape[-1]]), axis=0)
                total_valid_sampled_dict.extend(valid_sampled_dict)

        sampled_gt_boxes = existed_boxes[gt_boxes.shape[0]:, :]

        if total_valid_sampled_dict.__len__() > 0:
            sampled_gt_boxes2d = np.concatenate(sampled_gt_boxes2d, axis=0) if len(sampled_gt_boxes2d) > 0 else None
            sampled_mv_height = np.concatenate(sampled_mv_height, axis=0) if len(sampled_mv_height) > 0 else None
            '''
            将采样的groud truth标签中的点云数据添加到当前样本点云场景中去
            sampled_gt_boxes采样的groud truth的坐标、大小、偏移角
            total_valid_sampled_dict采样的groud truth,包括其内部的点云数据存储位置,点云数目,bbox,即包括sampled_gt_boxes
            '''
            data_dict = self.add_sampled_boxes_to_scene(
                data_dict, sampled_gt_boxes, total_valid_sampled_dict, sampled_mv_height, sampled_gt_boxes2d
            )

        data_dict.pop('gt_boxes_mask')
        return data_dict

关于采样多少个实例库中的groud truth标签,在配置文件中设置

通过函数sample_with_fixed_number采样实例库中的标签

python 复制代码
   def sample_with_fixed_number(self, class_name, sample_group):
        """
        Args:
            class_name:
            sample_group:
        Returns:

        """
        sample_num, pointer, indices = int(sample_group['sample_num']), sample_group['pointer'], sample_group['indices']
        # 初次运行将pointer设置为0,索引随机打乱
        if pointer >= len(self.db_infos[class_name]):
            indices = np.random.permutation(len(self.db_infos[class_name]))
            pointer = 0

        #按照打乱后的索引顺序取sample_num个数,这里的sample_num设置为15
        sampled_dict = [self.db_infos[class_name][idx] for idx in indices[pointer: pointer + sample_num]]
        pointer += sample_num
        sample_group['pointer'] = pointer
        sample_group['indices'] = indices
        return sampled_dict

注意 :OpenPcDet的实例库在self.db_infos

这里以car类为例,每一个dict内存储了一个groud truth标签,一张图片可能有多个标签,也可能只有一个(可能没有吗?)

从实例库中取出15个groud truth标签后,就可以进行碰撞检测了


有效采样标签掩码(发生碰撞无效)

python 复制代码
valid_mask = ((iou1.max(axis=1) + iou2.max(axis=1)) == 0)

将有效标签数据和当前样本的groud truth标签数据拼接则已完成了数据增强

相关推荐
~山有木兮1 年前
关于PointHeadBox类的理解
openpcdet