graphs 图也可以用 diffusion 耶~
本文提出了第一个去噪扩散模型在 3D 血管图生成中的工作,其是新颖的两阶段生成方法,依次对节点坐标和边进行去噪,在生成多样化、新颖且解剖学上合理的血管图方面性能表现出色。
论文:3D Vessel Graph Generation Using Denoising Diffusion
代码:https://github.com/chinmay5/vessel_diffuse
0、摘要
以 3D 图形表示的血管网络有助于预测疾病生物标志物,模拟血流,并有助于合成图像,有助于临床决策。然而,生成与感兴趣的解剖结构相对应的逼真血管图是具有挑战性的。
以前的方法主要是用自回归的方式生成血管树,不能应用于有循环的血管图,如毛细血管或特定的解剖结构,如Willis环。
为了解决这一问题,本文首次探索了去噪扩散模型在 3D 血管图(vessel graphs)生成中的应用,本文提出一种新的两阶段生成方法,该方法对节点坐标和边缘进行顺序去噪。(注意 graph 与 image 是不同的~)
在两个真实世界的血管数据集进行了实验,包括微观毛细血管和主要脑血管,并证明了本文的方法在生成多样化、新颖和解剖学上合理的血管图方面的普遍性。
1、引言
1.1、血管图
(1)将血管网络作为一种三维空间图进行研究,提供了一种对循环系统的解剖和生理特性的紧凑表示,但标注大规模的样本是昂贵和繁琐的;
(2)合成血管图可以增强下游任务的真实数据集,但目前生成真实的血管图仍然相对缺乏研究;
1.2、基于扩散模型的 graph 生成
(1)扩散模型已用于分子生成、蛋白质设计和材料科学等 graph 生成任务,虽然这些空间图生成任务有一个共同的目标,但它们在描述感兴趣的物理量以及这些量如何相互影响方面存在显著差异;
(2)在分子生成的情况下,两个相邻的原子 (节点) 是碳,这为共价键的存在提供了强烈的偏差,然而,这种节点类形式的归纳偏置不适用于血管生成;
(3)现有的空间图生成方法并不容易适用于血管图生成任务;
(4)根据成像分辨率的不同,血管图可以具有不同的拓扑结构和生理特性;
本文的目标是开发一个灵活的数据驱动模型,能够生成不同的成像分辨率下的血管图
1.3、本文贡献
(1)提出一种新的基于扩散去噪的血管图生成方法,这在该领域是第一次;
(2)通过提出一种定制的方法来解决生成血管图的复杂性,先对节点去噪,然后对边去噪;
(3)对两个真实血管数据集(一个代表毛细血管,另一个代表大血管)的实验,验证所提出方法能够生成多样化、独特和有效的血管图;
2、相关研究
2.1、血管图生成(Vessel Graph Generation)
(1)通过氧气浓度来生成动脉树,这样的模拟是昂贵的;
(2)一些研究尝试生成血管网格(vessel mesh),采用生成对抗网络来合成冠状动脉,利用变分自动编码器生成颅内血管节段,然而,这些生成算法不能在血管图中生成循环,这是许多解剖场景的关键特征;
2.2、基于扩散的图生成(Diffusion-based Graph Generation)
(1)在分子图生成方面有很多工作,但目前还缺乏基于扩散的血管图生成算法;
3、方法
本文想要生成空间图,其中节点嵌入在3D空间中,边具有类别属性。解剖学上,节点代表血管网络的分叉点,边缘代表连接它们的血管。
设 x i ∈ R 3 {x_i \in \mathbb{R}^3} xi∈R3 表示第 i {i} i 个节点的空间三维坐标, e i j ∈ R c {e_{ij} \in \mathbb{R}^c} eij∈Rc 为第 i {i} i 个节点和第 j {j} j 个节点的边, c {c} c 代表边的类别数,故,一个图(graph)由节点坐标 X ∈ R n × 3 {X \in \mathbb{R}^{n×3}} X∈Rn×3 和 分类边邻接矩阵 E ∈ R n × n × c {E \in \mathbb{R}^{n×n×c}} E∈Rn×n×c 编码,其中 n {n} n 为节点数目。
3.1、两阶段去噪
本文认为改变节点坐标会使边的类型预测变得困难,一旦所有的节点位置都已知,寻找可信的边配置就会受益于节点坐标的集体归纳偏置。
因此提出了一种两阶段图生成策略,首先,专注于生成一组似是而非的节点坐标作为点云。在第二阶段,学习边,保持节点坐标固定。
本文方法概览,先生成点再生成边:
3.1.1、节点去噪
使用 DDPM 来实现点生成任务,利用噪声模型 q ( X t ∣ X t − 1 ) {q(X^t|X^{t-1})} q(Xt∣Xt−1) 向点坐标 X 0 {X^0} X0 中添加高斯噪声:
α 1 , α 2 , . . . , α T {α^1,α^2,...,α^T} α1,α2,...,αT 为噪声调度, α 1 ≈ 1 {α^1≈1} α1≈1, α T ≈ 0 {α^T≈0} αT≈0,确保坐标已映射到标准高斯;
随后使用反向扩散过程对节点坐标进行去噪,训练过程中,选择一个随机的时间点 t {t} t ,并训练模型来预测添加的噪声。设 f γ {f_γ} fγ 表示带有参数 γ {γ} γ 的模型,前向获得 X t : = α ‾ t X 0 + 1 − α ‾ t ϵ {X^t:=\sqrt{\overline α^t}X^0 + \sqrt{1 - \overline α^t}ϵ} Xt:=αt X0+1−αt ϵ ,预测噪声 ϵ ^ : = f γ ( X t , t ) {\hat ϵ:=f_γ(X^t,t)} ϵ^:=fγ(Xt,t) ,其中, α ‾ t = ∏ s = 1 t α s {\overline α^t = \prod_{s=1}^t α^s} αt=∏s=1tαs,通过最小化损失函数来优化 f γ {f_γ} fγ:
3.1.2、边去噪
一旦节点去噪模型被训练出来,就可以训练边去噪模型了,为此,采用离散扩散模型,从边属性 E 0 {E^0} E0 开始,使用如下加噪模型:
其中, Q i j t : = q ( e t = j ∣ e t − 1 = i ) {Q_{ij}^t := q(e^t=j|e^{t-1}=i)} Qijt:=q(et=j∣et−1=i) 是边类别状态空间中的马尔可夫转移概率,对于 m {m} m 的边缘分布,得到的转移矩阵构造为 Q t : = α t I + ( 1 − α t 1 c m ′ ) {Q^t := α^tI+(1-α^t1_cm^\prime)} Qt:=αtI+(1−αt1cm′),其中 ′ {^\prime} ′ 表示转置操作, α t {α^t} αt 为噪声调度。
采用带参数 δ {δ} δ 的图 transformer 网络 f δ {f_δ} fδ 来建模边去噪网络;
血管图可以根据感兴趣的解剖结构具有特定的方向,一个旋转等变图生成模型不能捕获这样的数据集属性,因此,本文的模型不是旋转等变的,以保持数据集的方向属性。
训练过程中,选择一个随机的时间点 t {t} t ,使用等式(3)获得 E t {E^t} Et,并预测真实边 E ^ 0 : = f δ ( E t , X 0 , t ) {\hat E^0:=f_δ(E^t,X^0,t)} E^0:=fδ(Et,X0,t) ,模型由交叉熵损失训练:
除了正确的边类别外,边之间的度在血管图中是至关重要的,因此,引入了一种新的节点度损失。对于给定的节点位置,有多个有效的边配置,模型应该预测一个类似于 ground truth 的度分布,在一个小批次上使用预测和目标节点度分布之间的 KL 散度损失。
然而,我们需要邻接矩阵来比较节点度,这需要从预测的边邻接似然中离散采样。这个操作将破坏梯度反向传播,为了解决这个问题,使用 Gumbel-softmax (GS) 技巧进行邻接矩阵采样,节点度损失计算为:
总的边损失为: L e d g e = L C E + L o {\mathcal{L_{edge}}=\mathcal{L}{CE}+\mathcal{L{o}}} Ledge=LCE+Lo,在消融研究中研究了这些损失的贡献。
3.2、图生成
训练好 f γ {f_γ} fγ 和 f δ {f_δ} fδ 后,就可以采样血管图了,首先,从可能的离散值中抽取节点数 n n n,从 X T ∼ N ( 0 , I ) {X^T \sim \mathcal{N}(0,I)} XT∼N(0,I),执行去噪步骤,获得先验 p γ ( X T − 1 ∣ X T ) {p_γ(X^{T-1}|X^{T})} pγ(XT−1∣XT) 如下:
t > 0 {t>0} t>0 时, ϵ ∼ N ( 0 , I ) {ϵ \sim \mathcal{N}(0,I)} ϵ∼N(0,I),否则 ϵ = 0 {ϵ=0} ϵ=0,当拥有了 X 0 {X^0} X0,执行边去噪步骤,采样 E i j T ∼ m {E_{ij}^T \sim m} EijT∼m,首先从 E t {E^t} Et 中得到 E ^ 0 {\hat E^0} E^0 ,并计算边类别的后验分布:
其中, Q ‾ t : = Q 1 . . . Q t {\overline Q^t := Q^1... Q^t} Qt:=Q1...Qt,⊙为点乘,下一步,从边属性中采样类别值: E i j t − 1 ∼ p δ ( E i j t − 1 ∣ E i j t ) {E_{ij}^{t-1} \sim p_δ(E_{ij}^{t-1}|E_{ij}^{t})} Eijt−1∼pδ(Eijt−1∣Eijt)
4、实验与结果
4.1、数据集
(1)显微镜下的毛细血管图像,即 VesSAP,选择了24个带标注的 volumes,并从Voreen中提取了图,在原始分辨率的体素空间中使用了 48 × 48 × 48 的 patch 大小图,使用半径信息基于厚度创建了4个边类别,包括一个背景(无边)类;
(2)Willis 数据集,来自 CROWN (n=300) 和 TopCoW (n=90) 挑战的 390 张 MRA 图像,在 TopCoW 数据集上训练的多类分割工具对 CROWN 数据集进行分割,并使用 Voreen 从分割中提取图,对于 CoW 数据,使用整个 CoW 图作为一个样本,它有14类包含不同动脉标签的边和一个背景类;
4.2、基线
(1)血管图生成方法目前是明显的空白,选择最新的两种分子图生成方法: Congress 和 MiDi;
4.3、评价指标
(1)找到一个好的度量来评估节点图生成血管坐标和边信息的方法是困难的,当生成的图准确地表示真实的数据分布时,它们可以提高下游的性能;
(2)先前的工作使用了生成网格和真实网格之间的半径、总长度和扭曲度的分布差异。沿着这条线,本文采用以下物理参数来计算真实图和生成图之间的KL散度:对于节点,采用三维坐标位置( x , y , z {{x, y, z}} x,y,z)和节点的度( d e g ( V ) deg(\mathcal V) deg(V)),对于边,采用边的数目( ∣ ε ∣ {|\varepsilon|} ∣ε∣)、边的长度( l ε {l_\varepsilon} lε),两条边之间的角度、三轴边的方向、连接分量 β 0 {\beta_0} β0、循环分量 β 1 {\beta_1} β1。
4.4、实施细节
(1)一块 A-6000 GPU;
(2)1000 epochs,batch size = 64;
(3)AdamW 优化器,0.0003 学习率;
模型超参数细节:
4.5、结果与讨论
与现有方法比较:
可视化示例,模型能够学习复杂的结构,例如两个数据集的循环和方向特征:
两个数据集各指标图像比较:
与 MiDi 模型的可视化比较:
4.6、消融实验
在 CoW 数据集上消融旋转方差、两阶段去噪模型、度损失:
动作要快,姿势要帅~