SIRA-PCR: Sim-to-Real Adaptation for 3D Point Cloud Registration 论文解读

目录

一、导言

[二、 相关工作](#二、 相关工作)

1、三维点云配准工作

2、无监督域适应

三、SIRA-PCR

1、FlyingShape数据集

2、Sim-to-real自适应方法

3、配准

4、损失函数


一、导言

该论文来自于ICCV2023,论文提出了一种新的方法SIRA-PCR,通过利用合成数据FlyingShapes解决现有数据稀缺问题。

数据稀缺原因:现有的基于数据驱动的深度学习方法一般依赖于两类数据,一是单一物体级别如ModelNet40,ShapeNet,二是对于室内场景水平如3DMatch。虽然单一物体级别数据集有较强的几何形状,但很难推广到真实的室内场景,室内场景的数据集来训练性能很好,但捕获真实的室内场景十分耗时且估计的相机姿态存在错误标签。所以针对场景级的合成数据集仍然数据稀缺,所以本文创建了一个合成场景数据集FlyingShapes。

(1)构建了第一个大规模的室内合成数据集FlyingShapes,使用基于物理和随机的策略将ShapeNet对象插入3D-FRONT场景中。

(2)设计了一个称为SIRA的管道,包括一个自适应重采样模块,目的是缓解合成数据和真实数据之间的域差距(分布差异),从而提高点云配准的性能。

二、 相关工作

1、三维点云配准工作

一种是直接配准,一种是基于对应的方法(与以前的博文类似)

2、无监督域适应

无监督自适应是一种解决源域(合成数据)与目标域(真实数据)之间分布差异的方法,旨在利用源域的标注数据和目标域的无标注数据,学习一个可以将源域数据映射到目标域的模型,从而提高目标域任务的性能,通过这样的方式,利用合成数据集来解决真实数据集标注不足的问题。

对于以往的点云配准任务来说,一般通过利用GAN来执行域对抗训练,使得两个域的点云特征难以被鉴别器区分。还有方法通过自监督学习里解决域对齐。

三、SIRA-PCR

SIRA-PCR框架,首先通过SIRA结构基于FlyingShapes数据集来训练合成数据到真实数据的域自适应,之后基于3DMatch或3DLoMatch(实验工作)使用GeoTransformer来训练配准工作。

**SIRA-PCR框架的目的:**就是通过先训练一遍我们改进的合成数据集FlyingShape,来学习到一些真实数据集中应该有的特征,这样优化了真实数据集中由于标注错误,数据样本少而训练不足的问题,其实从本质上就是增加更多的样本量,完全可以把做合成数据集到合成2真实的域适应的工作看成一种增加样本量的方式。

1、FlyingShapes数据集

3D-FRONT数据集是一个由专业设计师设计的室内场景和家具布局的场景数据集,但由于结构过于简单,FlyingShape数据集从3D-FRONT场景数据集中添加一定的家具模型来模拟真实的场景。

FlyingShapes数据集考虑了三个因素进行优化:

(1)几何增强

由于现实场景中家具摆放不一定合理,可能存在一定的随机性。所以使用两种方式添加对象数据集(单一对象),分别是基于物理意义的(重力要求),随机放置(无视重力)。

由于3D-FRONT数据集中存在大量的地板、墙壁、天花板等简单的平面,对结构的泛化效果有限,从而也使得网络过分关注这些简单的平面结构,所以以50%的概率去除这些平坦平面,进而提高对于对象数据集中物体的几何结构平衡。

(2)高质量的视角选择

视角选择着重考虑,点云数据集中,每一团点云结构应该包含较为足够数量的对象(大于5个对象结构),并且视角保持人类视角,且移动保持人类转动行进的速度(设置高度为1.6m,水平视角360°,垂直方向仰角0到45°,以30°或15°来均匀采样视图)

(3)数据准备

为模拟RGB-D摄像机的深度信息,我们通过虚拟摄像机绘制深度图并转换为点云,保留重叠范围在30%以上的点云,提高模型的泛化能力。

2、Sim-to-real自适应方法

这一部分就是做无监督域适应工作,其实本质来说也是一个GAN网络。

生成器(ResampleGAN):

生成器采用Encoder-Decoder结构,Encoder部分是ResampleKPConvEncoder结构(KPConv+ARM自适应重采样模块+FPN,用于提取特征),Decoder部分是MLPDecoder结构(本质是三层1维卷积通过LeakyReLU激活函数相连,用于将特征转换为三维坐标)

Encoder部分又可以看做一个KPConv与FPN的结合,并且每一层都会添加一个ARM重采样结构。KPConv负责从顶到底提取特征,FPN部分负责多层次的特征再次提取。

**ARM重采样:**Adaptive Re-sample Module,实现了基于注意力机制的点的局部重采样,利用点的特征和邻点的信息来计算一个点的加权平均坐标,保证了每一步的卷积操作输出后都有针对邻点特征的优化。

**ARM算法:**输入所有点的坐标(坐标矩阵),构建每个点周围的一个局部patch,并得到该点的特征和该点邻点特征,并通过加权邻点特征和该点特征,得到新的坐标点,作为调整点位置。

对于论文图3的解释如下:

我们举点云中任何一个点为例子,实际是直接用点云的坐标矩阵来进行下面计算。

对于一个点的特征(d维列向量)与周围K个点的邻点特征(K*d维矩阵)的转置相乘,并除以进行归一化,之后通过Softmax函数得到权重系数(K维列向量),乘以K个邻近点坐标矩阵 (K*3维矩阵)得到重采样点

公式解释: (其中,就是将点运算转换为矩阵运算,可忽略看)

生成器代码如下:

python 复制代码
#重采样代码
class PatchResampleBlock(nn.Module):

    def __init__(self, feat_channels) -> None:
        r"""Initialize a patch resample block.

        Args:
            feat_channels: dimension of input features
        """
        super(PatchResampleBlock, self).__init__()
        self.feat_channels = feat_channels
        self.feat_proj = nn.Linear(self.feat_channels, self.feat_channels)

    def forward(self, points, feats, neighbor_indices):
        point_num, neighbor_limit = neighbor_indices.shape

        # adjust neighbor_indices (stand still when the patch has few neighbors)
        point_indices = torch.arange(point_num,
                                     device=neighbor_indices.device).reshape(
                                         (point_num, 1)).repeat(
                                             (1, neighbor_limit))  # (N ,K)
        neighbor_indices = torch.where(neighbor_indices < point_num,
                                       neighbor_indices, point_indices)

        # feature projection
        feats = self.feat_proj(feats)
        neighbor_feats = feats[neighbor_indices]  # (N, K, d)
        neighbor_points = points[neighbor_indices]  # (N, K, 3)

        neighbor_weights = torch.einsum(
            "nd,nkd->nk", feats,
            neighbor_feats)  # (N, d) x (N, K, d) -> (N, K)
        neighbor_weights = nn.functional.softmax(neighbor_weights /
                                                 self.feat_channels**0.5,
                                                 dim=-1)  # (N, K) -> (N, K)

        output_points = torch.einsum(
            "nk,nkp->np", neighbor_weights,
            neighbor_points)  # (N, K) x (N, K, 3) -> (N, 3)

        return output_points

#KPconv+FPN+ARM作为Encoder
class ResampleKPConvEncoder(nn.Module):

    def __init__(self, input_dim, output_dim, init_dim, kernel_size,
                 init_radius, init_sigma, group_norm):
        super(ResampleKPConvEncoder, self).__init__()

        self.encoder1_1 = ConvBlock(input_dim, init_dim, kernel_size,
                                    init_radius, init_sigma, group_norm)
        self.encoder1_2 = ResidualBlock(init_dim, init_dim * 2, kernel_size,
                                        init_radius, init_sigma, group_norm)

        self.encoder2_1 = ResidualBlock(init_dim * 2,
                                        init_dim * 2,
                                        kernel_size,
                                        init_radius,
                                        init_sigma,
                                        group_norm,
                                        strided=True)
        self.encoder2_2 = ResidualBlock(init_dim * 2, init_dim * 4,
                                        kernel_size, init_radius * 2,
                                        init_sigma * 2, group_norm)
        self.encoder2_3 = ResidualBlock(init_dim * 4, init_dim * 4,
                                        kernel_size, init_radius * 2,
                                        init_sigma * 2, group_norm)

        self.encoder3_1 = ResidualBlock(init_dim * 4,
                                        init_dim * 4,
                                        kernel_size,
                                        init_radius * 2,
                                        init_sigma * 2,
                                        group_norm,
                                        strided=True)
        self.encoder3_2 = ResidualBlock(init_dim * 4, init_dim * 8,
                                        kernel_size, init_radius * 4,
                                        init_sigma * 4, group_norm)
        self.encoder3_3 = ResidualBlock(init_dim * 8, init_dim * 8,
                                        kernel_size, init_radius * 4,
                                        init_sigma * 4, group_norm)

        self.encoder4_1 = ResidualBlock(init_dim * 8,
                                        init_dim * 8,
                                        kernel_size,
                                        init_radius * 4,
                                        init_sigma * 4,
                                        group_norm,
                                        strided=True)
        self.encoder4_2 = ResidualBlock(init_dim * 8, init_dim * 16,
                                        kernel_size, init_radius * 8,
                                        init_sigma * 8, group_norm)
        self.encoder4_3 = ResidualBlock(init_dim * 16, init_dim * 16,
                                        kernel_size, init_radius * 8,
                                        init_sigma * 8, group_norm)

        self.resample4 = PatchResampleBlock(feat_channels=init_dim * 16)
        self.decoder4 = UnaryBlock(init_dim * 16 + 3, init_dim * 16,
                                   group_norm)

        self.resample3 = PatchResampleBlock(feat_channels=init_dim * 8)
        self.decoder3 = UnaryBlock(init_dim * 24 + 3, init_dim * 8, group_norm)

        self.resample2 = PatchResampleBlock(feat_channels=init_dim * 4)
        self.decoder2 = UnaryBlock(init_dim * 12 + 3, init_dim * 4, group_norm)

        self.resample1 = PatchResampleBlock(feat_channels=init_dim * 2)
        self.decoder1 = UnaryBlock(init_dim * 6 + 3, output_dim, group_norm)

        self.outputlayer = LastUnaryBlock(output_dim, output_dim)

    def forward(self, data_dict):
        points_list = data_dict['points']
        neighbors_list = data_dict['neighbors']
        subsampling_list = data_dict['subsamples']
        upsampling_list = data_dict['upsamples']

        feats_s1 = torch.ones((points_list[0].shape[0], 1),
                              dtype=torch.float32).to(points_list[0])
        feats_s1 = self.encoder1_1(feats_s1, points_list[0], points_list[0],
                                   neighbors_list[0])
        feats_s1 = self.encoder1_2(feats_s1, points_list[0], points_list[0],
                                   neighbors_list[0])

        feats_s2 = feats_s1
        feats_s2 = self.encoder2_1(feats_s2, points_list[1], points_list[0],
                                   subsampling_list[0])
        feats_s2 = self.encoder2_2(feats_s2, points_list[1], points_list[1],
                                   neighbors_list[1])
        feats_s2 = self.encoder2_3(feats_s2, points_list[1], points_list[1],
                                   neighbors_list[1])

        feats_s3 = feats_s2
        feats_s3 = self.encoder3_1(feats_s3, points_list[2], points_list[1],
                                   subsampling_list[1])
        feats_s3 = self.encoder3_2(feats_s3, points_list[2], points_list[2],
                                   neighbors_list[2])
        feats_s3 = self.encoder3_3(feats_s3, points_list[2], points_list[2],
                                   neighbors_list[2])

        feats_s4 = feats_s3
        feats_s4 = self.encoder4_1(feats_s4, points_list[3], points_list[2],
                                   subsampling_list[2])
        feats_s4 = self.encoder4_2(feats_s4, points_list[3], points_list[3],
                                   neighbors_list[3])
        feats_s4 = self.encoder4_3(feats_s4, points_list[3], points_list[3],
                                   neighbors_list[3])

        resampled_points4 = self.resample4(points_list[3], feats_s4,
                                           neighbors_list[3])
        latent_s4 = torch.cat([feats_s4, resampled_points4],
                              dim=1)  # (N4, 64*16+3)
        latent_s4 = self.decoder4(latent_s4)

        resampled_points3 = self.resample3(points_list[2], feats_s3,
                                           neighbors_list[2])
        latent_s3 = nearest_upsample(latent_s4, upsampling_list[2])
        latent_s3 = torch.cat([latent_s3, feats_s3, resampled_points3],
                              dim=1)  # (N3, 64*16+64*8+3)
        latent_s3 = self.decoder3(latent_s3)

        resampled_points2 = self.resample2(points_list[1], feats_s2,
                                           neighbors_list[1])
        latent_s2 = nearest_upsample(latent_s3, upsampling_list[1])
        latent_s2 = torch.cat([latent_s2, feats_s2, resampled_points2],
                              dim=1)  # (N2, 64*8+64*4+3)
        latent_s2 = self.decoder2(latent_s2)  # (N1, 256)

        resampled_points1 = self.resample1(points_list[0], feats_s1,
                                           neighbors_list[0])
        latent_s1 = nearest_upsample(latent_s2, upsampling_list[0])
        latent_s1 = torch.cat([latent_s1, feats_s1, resampled_points1],
                              dim=1)  # (N1, 64*4+64*2+3)
        latent_s1 = self.decoder1(latent_s1)  # (N1, 256)

        output = self.outputlayer(latent_s1)  # (N1, 256)

        return output

#三层卷积作为Decoder
class MLPDecoder(nn.Module):

    def __init__(self, dimoffeat=256):
        super(MLPDecoder, self).__init__()
        self.sharedmlp = nn.Sequential(
            nn.Conv1d(dimoffeat, int(dimoffeat / 2), 1),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv1d(int(dimoffeat / 2), int(dimoffeat / 8), 1),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv1d(int(dimoffeat / 8), 3, 1))

    def forward(self, feat):
        feat = feat.transpose(-1, -2)  # (N, d) -> (d, N)
        feat = feat.unsqueeze(0)  # (d, N) -> (batch, d, N)
        output = self.sharedmlp(feat)  # (batch, d, N) -> (batch, 3, N)
        output = output.squeeze(0)  # (batch, 3, N) -> (3, N)
        output = output.transpose(-1, -2)  # (3, N) -> (N, 3)

        return output

#生成器
class ResampleGAN(nn.Module):

    def __init__(self, input_dim, dimofbottelneck, init_dim, kernel_size,
                 init_radius, init_sigma, group_norm):
        super(ResampleGAN, self).__init__()
        self.encoder = ResampleKPConvEncoder(input_dim, dimofbottelneck,
                                             init_dim, kernel_size,
                                             init_radius, init_sigma,
                                             group_norm)
        self.decoder = MLPDecoder(dimoffeat=dimofbottelneck)

    def forward(self, data_dict):
        feat = self.encoder(data_dict)
        points_recovered = self.decoder(feat)

        return points_recovered

判别器(ResampleGAN):

判别器输入点云中的坐标矩阵和邻点坐标信息,考虑到生成器采用不同的ARM重采样自适应,会导致改变局部的密度,所以在多尺度下(小、中、大尺度,长度分别是5,10,20)进行特征提取(特征提取使用PointNet网络),并进行特征融合,判别器的目的是判断不同层次的点是真实还是合成的。

python 复制代码
class MultiScalePointNet(nn.Module):

    def __init__(self, dimoffeat=256, multiscale=[5, 10, 20]) -> None:
        super(MultiScalePointNet, self).__init__()
        self.multiscale = multiscale
        self.patchpointnets = nn.ModuleList()
        for scale in multiscale:
            self.patchpointnets.append(
                nn.Sequential(
                    nn.Conv1d(3, dimoffeat // 4, kernel_size=1, stride=1),
                    nn.LeakyReLU(negative_slope=0.2),
                    nn.Conv1d(dimoffeat // 4,
                              dimoffeat // 2,
                              kernel_size=1,
                              stride=1), nn.LeakyReLU(negative_slope=0.2),
                    nn.Conv1d(dimoffeat // 2,
                              dimoffeat,
                              kernel_size=1,
                              stride=1)))
        self.maxpool = nn.AdaptiveMaxPool1d(output_size=1)
        self.sharedmlp = nn.Sequential(
            nn.Conv1d(len(multiscale) * dimoffeat,
                      dimoffeat,
                      kernel_size=1,
                      stride=1), nn.LeakyReLU(negative_slope=0.2),
            nn.Conv1d(dimoffeat, dimoffeat // 2, kernel_size=1, stride=1),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv1d(dimoffeat // 2, dimoffeat // 4, kernel_size=1, stride=1),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv1d(dimoffeat // 4, dimoffeat // 8, kernel_size=1, stride=1),
            nn.LeakyReLU(negative_slope=0.2),
            nn.Conv1d(dimoffeat // 8, 1, kernel_size=1, stride=1))

    def forward(self, points, neighbor_indices):
        patchfeats_list = []
        for idx in range(len(self.multiscale)):
            neighbors = points[
                neighbor_indices[:, :self.multiscale[idx]]]  # shape (N, K, 3)
            neighbors = neighbors - points[:, None, :]
            neighbors = neighbors.transpose(1, 2)  # (N, K, 3) -> (N, 3, K)

            feats = self.patchpointnets[idx](
                neighbors)  # (N, 3, K) -> (N, d, K)
            patchfeats = self.maxpool(feats)  # (N, d, K) -> (N, d, 1)
            patchfeats = patchfeats.squeeze(-1)  # (N, d, 1) -> (N, d)
            patchfeats = patchfeats.transpose(0, 1)  # (N, d) -> (d, N)
            patchfeats_list.append(patchfeats)

        multiscalefeats = torch.cat(patchfeats_list,
                                    dim=0)  # m x (d, N) -> (md, N)

        multiscalefeats = multiscalefeats.unsqueeze(
            0)  # (md, N) -> (batch, md, N)
        out = self.sharedmlp(
            multiscalefeats)  # (batch, md, N) -> (batch, 1, N)
        out = out.squeeze(0)  # (batch, 1, N) -> (1, N)
        out = out.squeeze(0)

        return out

3、配准

配准工作backbone使用的是GeoTransformer结构作为配准模块,并且使用GeoTransformer中提到的局部到全局的配准策略LGR策略来替换RANSAC。

LGR可以利用所有的对应关系来估计变换矩阵,而不像RANSAC使用一部分内点而存在内点分布不均匀的影响,LGR可以更好地处理重复几何结构和低重叠比例情况。

4、损失函数

损失函数包含两个模块,对于SIRA和点云配准两个部分分别进行训练。

对于SIRA模块,使用倒角损失,生成器损失和判别器损失三部分来训练域自适应。

对于配准模块,使用重叠圆损失(点对应损失)和全局点对损失两部分训练(使用的就是Geotrans的配准模块,损失也一模一样)。

四、实验

1、对比不同框架

提前获得了一部分数据集(合成数据集)的情况下,其实效果相比GeoTransformer提升并没有很大,说明这个方法idea很好,但其实用在配准工作其实一般。

2、点云配准训练中,不同数量的数据集对性能指标的影响

在不同的backbone下,显然Samples越少,对其他backbone影响越来越明显,对SIRA-PCR影响则很微小。性能指标包括:特征匹配召回,离群比、配准召回。不得不说这不就是因为提前吃了一遍数据集,学到了特征的影响吗。

3、其他实验

对于FlyingShapes数据集的建立,引入structure3D效果提升,这个不太了解情况,证明了删除平面确实有效。

消融实验中对于SIRA的不同组件以及FlyingShapes的object data和scene data进行消融。

论文参考:ICCV 2023 Open Access Repository

相关推荐
古希腊掌管学习的神32 分钟前
[机器学习]XGBoost(3)——确定树的结构
人工智能·机器学习
ZHOU_WUYI1 小时前
4.metagpt中的软件公司智能体 (ProjectManager 角色)
人工智能·metagpt
靴子学长2 小时前
基于字节大模型的论文翻译(含免费源码)
人工智能·深度学习·nlp
AI_NEW_COME3 小时前
知识库管理系统可扩展性深度测评
人工智能
海棠AI实验室3 小时前
AI的进阶之路:从机器学习到深度学习的演变(一)
人工智能·深度学习·机器学习
hunteritself3 小时前
AI Weekly『12月16-22日』:OpenAI公布o3,谷歌发布首个推理模型,GitHub Copilot免费版上线!
人工智能·gpt·chatgpt·github·openai·copilot
IT古董4 小时前
【机器学习】机器学习的基本分类-强化学习-策略梯度(Policy Gradient,PG)
人工智能·机器学习·分类
centurysee4 小时前
【最佳实践】Anthropic:Agentic系统实践案例
人工智能
mahuifa4 小时前
混合开发环境---使用编程AI辅助开发Qt
人工智能·vscode·qt·qtcreator·编程ai
四口鲸鱼爱吃盐4 小时前
Pytorch | 从零构建GoogleNet对CIFAR10进行分类
人工智能·pytorch·分类