【论文阅读】Segment Anything Model for Road Network Graph Extraction

【论文阅读】Segment Anything Model for Road Network Graph Extraction (CVPRW 2024)

Paper链接:https://openaccess.thecvf.com/content/CVPR2024W/SG2RL/html/Hetang_Segment_Anything_Model_for_Road_Network_Graph_Extraction_CVPRW_2024_paper.html

文章目录

  • [【论文阅读】Segment Anything Model for Road Network Graph Extraction (CVPRW 2024)](#【论文阅读】Segment Anything Model for Road Network Graph Extraction (CVPRW 2024))
    • [1. 摘要](#1. 摘要)
    • [2. 方法](#2. 方法)
      • [2.1 整体结构](#2.1 整体结构)
      • [2.2 Image Encoder](#2.2 Image Encoder)
      • [2.3 Mask Decoder](#2.3 Mask Decoder)
      • [2.4 Topology Decoder](#2.4 Topology Decoder)
      • [2.5 Label Generation](#2.5 Label Generation)

1. 摘要

简单来说,本工作将矢量道路线提取的部分流程视为分割任务,利用SAM预训练模型的强大分割能力,实现了SOTA精度和极高的推理速度。

2. 方法

2.1 整体结构

SAM-Road整体由三个部分构成:

  1. Image Encoder:预训练SAM Image Encoder
  2. Geometry Decoder:即图中的Mask Decoder,由4层转置卷积构成,输出分割概率图
  3. Topology Decoder:由Transformer实现拓扑结构中的Message Passing

2.2 Image Encoder

采用最小版本,即ViT-B。训练时采用0.1倍的基础学习率来微调。

2.3 Mask Decoder

为了提升整体以及交叉点的提取精度,Mask Decoder同时输出两个通道数为1的masks ,形状为(H_img, W_img, 2)。

  1. mask_0用于提取graph vertices。首先,道路由连续的mask表示,因此,每个像素点均有可能是graph vertex。为了获取sparse vertices,本工作设计了一种用于抑制多余vertices的NMS算法。

    复制代码
    NMS of Vertices算法
    1. 根据threshold预处理,消除分数低的像素。
    2. 以d_v为抑制距离(类似目标检测NMS中的IoU),半径内保留分数最高的vertex。

    这一步可能出现road vertices分数大于附近intersections的情况,从而出现误消除intersections的情况。

  2. mask_1用于提取intersections。使用同样的NMS算法。

两个masks经处理后,对二者进行join,并将intersections设置较高的分数,再次应用NMS得到最终的graph vertices。

2.4 Topology Decoder

Topology Decoder由3层多头注意力组成,用于将"离散"的vertices连接成拓扑结构。

本方法目的是寻找每个顶点的一阶邻居,并将此视为二分类任务。步骤如下:

  1. 选择一个source vertex;

  2. 在 R n b r R_{nbr} Rnbr范围内选择至多 N n b r N_{nbr} Nnbr个target vertex,构成多个vertex pairs;

    注意,source vertex与每个target vertex都是一阶邻居关系

  3. 对所有选中的顶点计算特征(根据坐标,通过在特征图上进行Bilinear Sample得到顶点特征,即Figure 2中的Source Feat和Target Feat);

  4. 对所有vertex pairs计算offset,得到 d k d_{k} dk;

  5. 拼接Source Feat,Target Feat和 d k d_k dk,得到形状为 ( N n b r , 2 D f e a t + 2 ) (N_{nbr}, 2D_{feat}+2) (Nnbr,2Dfeat+2)的向量,并proj到 ( N n b r , D f e a t ) (N_{nbr}, D_{feat}) (Nnbr,Dfeat)作为query;

  6. 经3层多头注意力后,将query输入线性层得到分类logits,表示vertex pairs相连的概率。

2.5 Label Generation

  • Mask Labels

    1. 使用宽度为3个像素的mask代表道路线段;
    2. 使用半径为3个像素的mask代表intersections;
  • Topology Labels

    • 以教师强制方式训练Topology Decoder

      1. 均匀采样gt mask得到模拟概率图,在此基础上应用NMS Vertices等算法;
      2. 使用高斯分布对gt vertices坐标进行随机扰动;
相关推荐
HollowKnightZ2 小时前
目标姿态估计综述:Deep Learning-Based Object Pose Estimation: A Comprehensive Survey
人工智能·深度学习
加油吧zkf2 小时前
Conda虚拟环境管理:从入门到精通的常用命令
图像处理·深度学习·计算机视觉·conda
小哥谈4 小时前
论文解析篇 | YOLOv12:以注意力机制为核心的实时目标检测算法
人工智能·深度学习·yolo·目标检测·机器学习·计算机视觉
水龙吟啸4 小时前
从零开始搭建深度学习大厦系列-2.卷积神经网络基础(5-9)
人工智能·pytorch·深度学习·cnn·mxnet
HollowKnightZ5 小时前
论文阅读笔记:VI-Net: Boosting Category-level 6D Object Pose Estimation
人工智能·深度学习·计算机视觉
yzx9910135 小时前
AI大模型平台
大数据·人工智能·深度学习·机器学习
Better Rose7 小时前
人工智能与机器学习暑期科研项目招募(可发表论文)
人工智能·深度学习·机器学习·论文撰写
Jamence7 小时前
多模态大语言模型arxiv论文略读(155)
论文阅读·人工智能·计算机视觉·语言模型·论文笔记
慕婉03077 小时前
深度学习中的常见损失函数详解及PyTorch实现
人工智能·pytorch·深度学习
神经星星7 小时前
在线教程丨一句话精准P图,FLUX.1 Kontext可实现图像编辑/风格迁移/文本编辑/角色一致性编辑
人工智能·深度学习·机器学习