【Motion Forecasting】SIMPL:简单且高效的自动驾驶运动预测Baseline

SIMPL: A Simple and Efficient Multi-agent Motion Prediction Baseline for Autonomous Driving

这项工作发布于2024年,前一段时间我已经对这篇文章的摘要和结论进行了学习和总结,这一部分详见https://blog.csdn.net/Coffeemaker88/article/details/141687294?spm=1001.2014.3001.5501

SIMPL当中也使用了与HDGT或GoRela类似的视角不变性编码方法,并使用Transformer架构对场景中元素之间的关联进行建模,本篇博客主要对该文章第三部分Methodology进行学习和记录。

Methodology

Problem Formulation

针对自动驾驶预测任务,每一篇文章当中对问题的描述基本是一致的。可以总结为使用观测时间段内代理的历史运动轨迹加上高精地图当中所提供的地图信息来对代理未来时刻的多模态轨迹进行预测。

SIMPL当中对于Problem Formulation这一部分的描述较为详细,此处不妨对原文对应部分直接进行翻译:

轨迹预测任务包括基于代理观测到的历史运动轨迹以及周围地图信息来对目标代理可能的未来轨迹进行预测。具体来说,在包含 N a N_a Na个移动代理(包括对数据进行采样的自动驾驶汽车)的驾驶场景当中,使用 M M M来表示地图信息,使用 X = { x 0 , . . . , x N a } X=\{x_0, ..., x_{N_a}\} X={x0,...,xNa}来表示所有代理的历史观测轨迹,其中 x i = { x i , − H + 1 , . . . , x i , 0 } x_i = \{x_{i, -H+1}, ..., x_{i, 0}\} xi={xi,−H+1,...,xi,0}表示的是代理 i i i在过去 H H H个时刻的历史轨迹。

为了不失一般性,多代理运动预测器会生成场景中所有 N a N_a Na个代理的未来轨迹,记作 Y = { y 0 , . . . , y N a } Y=\{y_0, ..., y_{N_a}\} Y={y0,...,yNa}。对于每一个单独的代理 i i i,轨迹预测器会生成其未来 K K K条可能的轨迹,以及这些未来轨迹对应的概率分数,概率分数用于捕捉多模态轨迹的分布。多模态轨迹被表示为 y i = { y i 1 , . . . , y i K } y_i = \{y^1_i, ..., y^K_i\} yi={yi1,...,yiK},其中 y i k = { y i , 1 k , . . . , y i , T k } y^k_i = \{y^k_{i, 1}, ..., y^k_{i, T}\} yik={yi,1k,...,yi,Tk}表示的是代理 i i i的第 k k k条未来轨迹,轨迹时间长度为 T T T,而概率分数被记作 α i = { α i 1 , . . . , α i K } \alpha_i = \{\alpha^1_i, ..., \alpha^K_i\} αi={αi1,...,αiK}。

综上所述,代理 i i i的多模态轨迹预测可以被视为对以下混合分布进行估计:
P ( y i ∣ X , M ) = ∑ k = 1 K α i k P ( y i k ∣ x , M ) P(y_i | X, M) = \sum^K_{k=1}\alpha^k_i P(y^k_i|x, M) P(yi∣X,M)=∑k=1KαikP(yik∣x,M)

需要注意的是,本文方法主要关注边缘运动预测任务(对每一个代理的未来轨迹进行单独的预测),但是SIMPL可以平滑地通过使用场景级别的损失函数(scene-level loss functions)来将模型拓展到多代理联合预测任务上。这一部分将会作为SIMPL的未来工作。

Framework Overview

SIMPL的整体架构如下图所示:

与目前大多数的主流方法一致,SIMPL在对场景进行编码时采用向量化的表示方法。

对于每一个语义实例(semantic instance),比如轨迹或是车道段,作者构建了一个局部参考坐标系(local reference frame),来对实例自身的特征和实例之间的相对特征进行解耦。

之后,代理和地图特征将会使用简单的编码器来进行特征提取,而实例之间的相对姿态将会进行成对的计算,并使用MLP来进行编码,最终得到相对位置编码(Relative Positional Embedding,RPE)。

实例特征对应的tokens以及RPE将会被输入到对称融合Transformer(Symmetric Fusion Transformer,SFT)当中,它是一个紧致并且便捷的特征融合模块,会对称地对特征进行更新。

最后,SIMPL使用Bernstein基多项式来对代理的多模态未来轨迹进行参数化,SIMPL将会使用一个简单的解码器来同时输出所有目标代理的多模态未来轨迹。

Instance-centric Scene Representation

这一部分也是我最感兴趣的部分,即以实例为中心的场景表示方法。这一部分内容是近几年的研究热点,2023年TPAMI的HDGT、2024年TPAMI的MTR++、2023年CVPR的QCNet以及2024年CVPR的HPNet均在场景编码阶段使用了"局部参考坐标系"的概念,这种编码方法会使得模型能够学习到运动预测任务的视角不变性,如果进一步对局部坐标系进行细化,将局部坐标系建立在每一个时刻的每一个代理位置上,那么模型将支持流式输入,为真实场景的部署提供了可能。

SIMPL同样使用了类似的场景编码方法。

与scene-centric表示方法不同,场景还可以使用每一个实例局部坐标系下的向量特征以及实例之间的相对姿态来进行表示。在SIMPL的场景编码当中,将会为每一个语义实例建立一个局部参考坐标系,来对每一个语义实例的空间属性进行标准化,本文将上述表示方法称为"instance-centric",即以实例为中心的表示方法。

为了不失一般性,本文方法将代理历史运动轨迹的局部坐标系建立在代理的当前时刻状态之上(与HDGT、GoRela的方式相同)。而对于静态的地图元素,比如车道段,本文方法以折线的中点作为锚点,并使用端点之间构成向量的角度作为折线的方向。

局部坐标系统可以被视为实例的"锚姿态(anchor pose)",基于局部坐标系统可以方便地对实例之间成对的相对空间关系进行计算。

具体来说,针对场景当中的某个实例 i i i,它在全局坐标系统下的锚姿态可以使用它的位置 p i ∈ R 2 p_i \in R^2 pi∈R2和方向 v i ∈ R 2 v_i \in R^2 vi∈R2来定义。按照GoRela当中的方式,本文方法同样使用三个变量来描述实例 i i i和实例 j j j之间的相对姿态,分别是航向角差值(heading difference) α i → j \alpha_{i \rightarrow j} αi→j、相对方位(relative azimuth) β i → j \beta_{i \rightarrow j} βi→j以及相对距离(relative distance) ∣ ∣ d i → j ∣ ∣ ||d_{i \rightarrow j}|| ∣∣di→j∣∣。为了加强数值的稳定性,角度将会使用它们的sine和cosine值进行表示。

航向角差值被表示为:
sin ⁡ ( α i → j ) = v i × v j ∣ ∣ v i ∣ ∣ ∣ ∣ v j ∣ ∣ , cos ⁡ ( α i → j ) = v i ⋅ v j ∣ ∣ v i ∣ ∣ ∣ ∣ v j ∣ ∣ \sin(\alpha_{i \rightarrow j}) = {{v_i \times v_j} \over {||v_i||||v_j||}}, \cos(\alpha_{i \rightarrow j}) = {{v_i \cdot v_j} \over {||v_i||||v_j||}} sin(αi→j)=∣∣vi∣∣∣∣vj∣∣vi×vj,cos(αi→j)=∣∣vi∣∣∣∣vj∣∣vi⋅vj

而相对方位被表示为:
sin ⁡ ( β i → j ) = d i → j × v j ∣ ∣ d i → j ∣ ∣ ∣ ∣ v j ∣ ∣ , cos ⁡ ( β i → j ) = d i → j ⋅ v j ∣ ∣ d i → j ∣ ∣ ∣ ∣ v j ∣ ∣ \sin(\beta_{i \rightarrow j}) = {{d_{i \rightarrow j} \times v_j} \over {||d_{i \rightarrow j} || || v_j ||}}, \cos(\beta_{i \rightarrow j}) = {{d_{i \rightarrow j} \cdot v_j} \over {||d_{i \rightarrow j} || || v_j ||}} sin(βi→j)=∣∣di→j∣∣∣∣vj∣∣di→j×vj,cos(βi→j)=∣∣di→j∣∣∣∣vj∣∣di→j⋅vj

为了简便,本文方法省略了GoRela中为相对距离所加入的位置编码,实例之间的相对空间信息可以使用维度为 5 5 5的向量 r i → j r_{i \rightarrow j} ri→j来进行表示,即 r i → j = [ sin ⁡ ( α i → j ) , cos ⁡ ( α i → j ) , sin ⁡ ( β i → j ) , cos ⁡ ( β i → j ) , ∣ ∣ d i → j ∣ ∣ ] r_{i \rightarrow j} = [\sin(\alpha_{i \rightarrow j}), \cos(\alpha_{i \rightarrow j}), \sin(\beta_{i \rightarrow j}), \cos(\beta_{i \rightarrow j}), ||d_{i \rightarrow j}||] ri→j=[sin(αi→j),cos(αi→j),sin(βi→j),cos(βi→j),∣∣di→j∣∣]。在实际实现中,可以使用PyTorch或是NumPy当中的broadcasting机制来方便地实现相对姿态向量的并行计算。

综上所述,给定一个包含 N = N a + N m N = N_a + N_m N=Na+Nm个语义实例的场景,表示相对位置信息的张量维度为 [ N , N , 5 ] [N, N, 5] [N,N,5],其中表示 i i i和 j j j之间相对信息的向量 r i → j r_{i \rightarrow j} ri→j位于张量的第 j j j行第 i i i列。

Context Feature Encoding

在获取了以实例为中心的场景特征表示以及实例之间的相对位置编码之后,使用相应的编码器(encoders,也可以被称为tokenizers)来将上述信息转化为特征向量。

为了保持SIMPL是简单的,使用1D CNN来处理代理的历史运动信息,使用PointNet-based编码器来提取静态地图特征。为了不失一般性,本文方法令所有的隐层特征维度为 D D D。因此,所得的代理和地图特征的维度为 [ N a , D ] [N_a, D] [Na,D]和 [ N m , D ] [N_m, D] [Nm,D]。对于上述的实现细节,详情可参考VectorNet以及LaneGCN当中的做法。

此外,相对姿态可以进一步使用MLP来进行处理,从而产生维度为 [ N , N , D ] [N, N, D] [N,N,D]的相对位置嵌入表示(Relative Positional Embedding,RPE)。

Symmetric Fusion Transformer

在获取了instance tokens以及相应的RPE之后(instance tokens指的就是代理的历史运动信息特征以及静态地图的特征,RPE指的就是上文所说的使用MLP对相对位置向量进行处理之后得到的特征),SIMPL使用对称融合Transformer(Symmetric Fusion Transformer,下文直接记作SFT)来以视角不变的方式对instance tokens进行更新,如下图所示:

SFT由若干个堆叠起来的SFT layers组成,与标准的Transformer类似。

具体来说,可以将驾驶场景视为一个带有自环的完全有向图,输入的instance-centric features可以被视为结点,而使用RPE来描述边的信息。在对结点特征进行更新的过程中,结点特征只受与目标结点相连接的边的影响,确保了融合的特征仍然具有视角不变性。

更具体地说,将实例 i i i和 j j j对应的tokens分别记作 f i , f j f_i, f_j fi,fj。将以 f i f_i fi为源结点, f j f_j fj为目标结点的RPE记作 r i → j ′ r'{i \rightarrow j} ri→j′。三元组 ( f i , f j , r i → j ′ ) (f_i, f_j, r'{i \rightarrow j}) (fi,fj,ri→j′)包含着从结点 i i i变换到结点 j j j的信息,因此,SIMPL使用简单的MLP来对上述三元组进行编码:并得到实例 i i i之于实例 j j j的上下文特征:

c i → j = ϕ ( f i ⊞ f j ⊞ r i → j ′ ) c_{i \rightarrow j} = \phi(f_i \boxplus f_j \boxplus r'_{i \rightarrow j}) ci→j=ϕ(fi⊞fj⊞ri→j′)

其中 ⊞ \boxplus ⊞表示concat操作,而 ϕ : R 3 D → R D \phi: R^{3D} \rightarrow R^D ϕ:R3D→RD为MLP,由线性层、LayerNorm和ReLU组成。随后对目标结点及其上下文特征计算Cross-Attention,来对目标结点的特征进行更新:

f j ′ = M H A ( Q u e r y : f j , K e y : C j , V a l u e : C j ) f'_j = MHA(Query: f_j, Key: C_j, Value: C_j) fj′=MHA(Query:fj,Key:Cj,Value:Cj)

M H A ( ⋅ , ⋅ , ⋅ ) MHA(\cdot, \cdot, \cdot) MHA(⋅,⋅,⋅)是标准的多头注意力,而 C j = { c i → j } i ∈ { 1 , . . . , N } C_j = \{c_{i \rightarrow j}\}{i \in \{1, ..., N\}} Cj={ci→j}i∈{1,...,N}为实例 j j j的token的上下文特征向量,它同时也包含 c j → j c{j \rightarrow j} cj→j自身。

与标准的Transformer相同,在计算MHA之后,会使用逐点的前馈网络来对注意力的结果进行聚合。此外,在每一层当中, r i → j ′ r'_{i \rightarrow j} ri→j′将会使用MLP来进行重新编码,并与输入的RPE通过残差连接相加。

在实操中,SIMPL提供了一个更为高效的实现方法,以向量化的形式来完成上述的特征融合。首先,给定输入实例的tokens F ∈ R N × D F \in R^{N \times D} F∈RN×D,将其沿着不同的维度进行扩展,对它进行 N N N次复制,来构件源结点和目标结点的张量,二者的维度都是 [ N , N , D ] [N, N, D] [N,N,D]。回忆一下,RPE的维度也是 [ N , N , D ] [N, N, D] [N,N,D],将上述三个张量进行concat,就可以得到 ( f i , f j , r i → j ′ ) (f_i, f_j, r'_{i \rightarrow j}) (fi,fj,ri→j′),它的维度是 [ N , N , 3 D ] [N, N, 3D] [N,N,3D],施加 ϕ \phi ϕ之后即可得到上下文张量 C ∈ R N × N × D C \in R^{N \times N \times D} C∈RN×N×D。上下文张量 C C C当中的第 j j j行指的就是实例 j j j的上下文 C j C_j Cj。

综上,可以使用 C C C作为key和vaue,将拓展后的 F ′ ∈ R N × 1 × D F' \in R^{N \times 1 \times D} F′∈RN×1×D作为query,来计算MHA。

值得注意的是,本文所提出的SFT layers与近期时兴的Query-Centric方法非常的类似,但SIMPL计算的是全局注意力(相当于在全局所有instances之上构建了一个全联通的有向图,其弊端是将会带来庞大的计算开销)和并且RPE会得到更新(即边的特征会得到更新,但是是简单的re-encoding + residual connection的方式,与HDGT中使用结点特征对边特征进行更新不同)。

Multimodal Continuous Trajectory Decoder

在对称全局特征融合之后,更新的代理特征将会被聚合,并输入到多模态运动解码器当中,来生成所有代理的未来轨迹。

SIMPL将会为每一个代理预测 K K K条可能的轨迹,对于每一个模式,都使用一个简单的MLP进行轨迹预测,它同时包含一个regression head输出轨迹和一个classification head + softmax输出概率分数。

对于regression head,与过去方法中直接对轨迹的2D坐标点进行预测不同,SIMPL使用Bernstein基多项式对轨迹进行连续的参数化表示。参数化的多项式具有连续的表示,使得未来的每一个时刻轨迹点相应的运动都是平滑的,并且均具有高阶导数。

过去的研究结果表明,使用单一多项式对轨迹进行参数化,会使得模型的性能大幅度下降。作者认为性能下降的原因可能是所预测的系数存在数值不稳定性,导致回归任务变得很难。

为了充分地利用参数化轨迹的优势,并避免性能下降,SIMPL引入了Bernstein基多项式,它的系数是具有具体空间含义的控制点,使得模型可以更好地收敛。

具体来说,自由度为 n n n的Bernstein基多项式可以被写作:

其中 b n i ( t ) b^i_n(t) bni(t)是 i i i阶Bernstein基, ( n , i ) T (n, i)^T (n,i)T是二项系数, t t t是参数化曲线的变量, p i p_i pi是控制点。需要注意的是,对于 n n n阶Bernstein基多项式,共有 n + 1 n+1 n+1个控制点,第一个和最后一个控制点总是曲线的端点。

由于Bernstein基多项式被定义在 t ∈ ( 0 , 1 ] t \in (0, 1] t∈(0,1]上,SIMPL对真实的时间进行了标准化,得到 τ ∈ [ 0 , τ max ⁡ ] \tau \in [0, \tau_{\max}] τ∈[0,τmax],从而使得参数化的曲线可以被写作 f ( t ) = ∑ i = 0 n b n i ( τ τ max ⁡ ) p i f(t) = \sum^n_{i=0}b^i_n({{\tau} \over {\tau_{\max}}})p_i f(t)=∑i=0nbni(τmaxτ)pi。此外,由于hodograph property, n n n阶Bernstein基多项式的导数仍然是Bernstein基多项式,控制点被定义为 p i ( 1 ) = n ( p i + 1 − p i ) p^{(1)}i = n(p{i+1} - p_i) pi(1)=n(pi+1−pi),轨迹的velocity profile可以被计算为:

在实操中,使用MLP作为regression head,来将融合的代理特征映射为控制点。之后,每条预测轨迹的坐标点 Y p o s ∈ R T × 2 Y_{pos} \in R^{T \times 2} Ypos∈RT×2可以通过将constant basis矩阵 B ∈ R T × ( n + 1 ) B \in R^{T \times (n+1)} B∈RT×(n+1)与相应的2D控制点 P ∈ R ( n + 1 ) × 2 P \in R^{(n+1) \times 2} P∈R(n+1)×2进行相乘而得到:

其中 T T T是要预测的时间长度,而 t i = τ i τ max ⁡ t_i = {{\tau_i} \over {\tau_{\max}}} ti=τmaxτi是标准化后的时间点。

最后,预测的轨迹将会根据代理相应的锚姿态变换回全局坐标系统当中。

Training

SIMPL以端到端的方式进行训练,总体的损失函数是回归损失和分类损失的加权和:

L = ω L r e g + ( 1 − ω ) L c l s L = \omega L_{reg} + (1-\omega)L_{cls} L=ωLreg+(1−ω)Lcls

其中 ω ∈ [ 0 , 1 ] \omega \in [0, 1] ω∈[0,1]是平衡两项损失的权重,在SIMPL中被设置为0.8来强调回归任务的重要性。采用winner-takes-all策略来处理轨迹的多模态性。即对于每一个代理,在 K K K条所预测的轨迹中选取FDE最小的那条轨迹来计算损失。

对于分类随时,使用max-margin loss来区分正模式和其它模式。对于轨迹回归任务,SIMPL除了会对轨迹的坐标点进行回归之外,还额外引入了一个航向角损失,作为额外的监督,得到:

其中 Y ˉ ( ⋅ ) \bar{Y}_{(\cdot)} Yˉ(⋅)是Ground Truth。使用smooth L1 loss作为位置回归损失,而航向角回归损失定义为:

总结

SIMPL的亮点可以总结为两个,即Symmetric Fusion Transformer的使用以及使用Bernstein基多项式对未来轨迹进行编码的方法。由于我对Bernstein基多项式不了解,无法对这一部分进行总结,因此可以多说一说我对SFT及SIMPL中场景编码方法的思考。

在场景编码阶段,SIMPL的做法让我看到了许多HDGT和GoRela的影子,三者都是使用当前时刻场景中代理的位置和行驶方向作为每条轨迹的局部坐标系,并使用锚姿态对整条轨迹当中的位置和方向进行标准化,使得每一个代理的历史运动信息都处于它们各自当前时刻的局部坐标系统当中。三者都进一步在每个局部坐标系统之间构建起成对的相对姿态信息,这一信息可以作为边的特征。

具体来说,HDGT和GoRela将场景图建模为异构图,而SIMPL将场景建模为instance token-level的同质图。SIMPL当中的图是全连接的,结点特征是使用1D CNN或PointNet-based encoder编码的代理历史运动信息特征和地图特征,而边特征是相对姿态信息。

在对结点特征进行更新时,相对姿态信息以及该边所对应的两个结点的特征将会作为目标结点上下文特征的一部分进行编码,目标结点所有邻居结点的上下文特征组合构成的张量将会作为目标结点最终的上下文特征,与目标结点计算Cross-Attention,从而对目标结点的特征进行更新。作者在文章中也给出了基于矩阵乘法的高效实现,非常值得参考。

总得来说,我个人认为场景编码以及实例特征更新这一部分仍然具有改进的空间,比如将场景图的异构性作为先验信息引入到模型当中作为引导,或是根据场景信息定义某种边连接的轨迹,以降低计算开销。

相关推荐
七夜星七夜月1 天前
时间序列预测论文阅读和相关代码库
论文阅读·python·深度学习
WenBoo-2 天前
HIPT论文阅读
论文阅读
chnyi6_ya2 天前
论文笔记:Buffer of Thoughts: Thought-Augmented Reasoning with Large Language Models
论文阅读·人工智能·语言模型
Jude_lennon2 天前
【论文笔记】结合:“integrate“ 和 “combine“等
论文阅读
LuH11242 天前
【论文阅读笔记】HunyuanVideo: A Systematic Framework For Large Video Generative Models
论文阅读·笔记
lalahappy2 天前
Swin transformer 论文阅读记录 & 代码分析
论文阅读·深度学习·transformer
开心星人2 天前
【论文阅读】Trigger Hunting with a Topological Prior for Trojan Detection
论文阅读
图学习的小张2 天前
论文笔记:是什么让多模态学习变得困难?
论文阅读·神经网络·机器学习
Maker~2 天前
28、论文阅读:基于像素分布重映射和多先验Retinex变分模型的水下图像增强
论文阅读·深度学习
小嗷犬3 天前
【论文笔记】CLIP-guided Prototype Modulating for Few-shot Action Recognition
论文阅读·人工智能·深度学习·神经网络·多模态