FedGraph: Federated Graph Learning With Intelligent Sampling论文阅读

FedGraph: Federated Graph Learning With Intelligent Sampling

    • Abstract
    • Introduction
    • [2 BACKGROUND AND MOTIVATION](#2 BACKGROUND AND MOTIVATION)
      • [2.1 Federated Learning](#2.1 Federated Learning)
      • [2.2 Graph Convolutional Network](#2.2 Graph Convolutional Network)
      • [2.3 Graph Sampling](#2.3 Graph Sampling)
    • [3 FEDGRAPH DESIGN](#3 FEDGRAPH DESIGN)
      • [3.1 Local GCN Training by Clients](#3.1 Local GCN Training by Clients)
      • [3.1.1 GCN Construction](#3.1.1 GCN Construction)
      • [3.1.2 GCN Training](#3.1.2 GCN Training)
      • [3.2 Global Parameter Update by the Server](#3.2 Global Parameter Update by the Server)
    • [3.3 Security Analysis](#3.3 Security Analysis)
    • [4 INTELLIGENT GRAPH SAMPLING BASED ON DRL](#4 INTELLIGENT GRAPH SAMPLING BASED ON DRL)
      • [4.1 DDPG-Based Problem Formulation](#4.1 DDPG-Based Problem Formulation)
      • [4.2 Sampling Based on DDPG](#4.2 Sampling Based on DDPG)
    • [5 PERFORMANCE EVALUATION](#5 PERFORMANCE EVALUATION)
      • [5.1Experimental Settings](#5.1Experimental Settings)
      • [5.2Experimental Results](#5.2Experimental Results)
    • [6 RELATED WORK](#6 RELATED WORK)
      • [6.1 Federated Learning](#6.1 Federated Learning)
      • [6.2 Graph Convolutional Networks](#6.2 Graph Convolutional Networks)
    • [7 CONCLUSION](#7 CONCLUSION)

联邦图:具有智能采样的联邦图学习

Abstract

联邦学习因其在分布式机器学习中的隐私保护而引起了研究的广泛关注。然而,现有的联邦学习工作主要集中在卷积神经网络(CNN)上,它不能有效地处理在许多应用中流行的图数据。图卷积网络(GCN)被认为是最有前途的图学习技术之一,但其联邦设置很少被探索。在本文中,我们提出了用于多个计算客户端之间的联邦图学习的联邦图,每个计算客户端都包含一个子图 。FedGraph通过解决两个独特的挑战,为跨客户端提供了强大的图形学习能力。首先,传统的GCN训练需要在客户之间进行特征数据共享,从而导致隐私泄露的风险。FedGraph使用一种新的跨客户端卷积操作 来解决这个问题。第二个挑战是由大图大小导致的高GCN训练开销。我们提出了一种基于深度强化学习的智能图采样算法,该算法可以自动收敛到平衡训练速度和准确性的最优采样策略。我们实现了基于PyTorch的FedGraph,并将其部署在一个测试台上进行性能评估。四个流行数据集的实验结果表明,FedGraph通过实现更快的收敛到更高的精度,显著优于现有的工作。

Introduction

联邦学习在实现分布式设备之间的协作机器学习,同时保护它们的数据隐私[1]方面显示出了巨大的前景。关于联邦学习[2],[3]的研究越来越多,但他们研究了卷积神经网络(CNN)模型,该模型在图像和语音数据上显示出优越的学习精度。然而,许多应用程序生成由节点和边组成的图数据(例如,社会图和蛋白质结构),许多证据表明CNN不能有效地处理图学习[4],[5]。图卷积网络(GCN)[6]是通过一种新的图卷积运算来处理图的学习问题。与CNN的过滤少量相邻像素的卷积操作不同,一个图的卷积操作可以过滤相邻节点的特征。不幸的是,现有的联邦学习工作主要集中在CNN上,这使得GCN没有被探索。

近年来,人们对分散数据集上的图学习进行了一些初步的研究。Zhou等人[7]研究了一个在图上的垂直联邦学习场景,其中客户端保持相同的节点,但具有不同的特征和边缘类型。同样,Mei等人的[8]假设图的结构、特征和标签属于不同的来源。这些作品不同于本文中所研究的一般背景。最近的一些研究通过讨论非I.I.D数据分布在联邦图学习[9],[10]中的影响,探讨了图学习和联邦学习的交集。然而,这些工作没有考虑图间连接,这是现实世界[11]中普遍存在的现象。

在本文中,**我们研究了基于分布在多个计算客户端之间的图数据的GCN联合学习,这些计算客户端由于隐私保护而不允许直接的数据共享。每个客户端都有一个子图,它与其他客户端所持有的子图有边连接。**每个图节点都与一些包含私人信息的特性相关联。例如,医院中的医疗记录可以组织为图表,其中每个图表节点代表一个记录,其特征包括个人信息(如年龄、性别和职业)以及健康状况(如疾病)[11]。人们普遍认为,这些特征数据是对隐私敏感的,而且它们不能被暴露出来。给定一些带有标签的节点,图学习的目标是预测其他节点的标签。

GCN 上的联邦学习并不是 CNN 上联邦学习的简单扩展,因为它面临两个独特的挑战。首先,GCN 训练涉及客户端之间的节点特征共享,这会导致隐私泄露的风险。为了利用图结构信息,图卷积操作旨在聚合相邻节点的特征数据。如果某些相邻节点由其他客户端维护,而这些客户端拒绝公开其特征,则此类操作将失败。隐私保护的一个直接解决方案是消除特征共享,但这会严重降低训练准确性,我们的实验结果证实了这一点。第二个挑战是大图规模带来的高训练开销 [12],[13]。例如,Facebook 维护的社交网络包含超过 30 亿用户,相应的图数据大小可能为数百 GB [14]。由于 GCN 模型将原始图堆叠了多层相同的结构,因此模型大小变得非常大,甚至超过了物理内存的限制。

在本文中,我们提出了 FedGraph,这是一个联邦图学习系统,它融合了联邦学习和 GCN 的思想,为隐私保护的分布式图学习开辟了新的机会。FedGraph 特别擅长在具有复杂连接的分布式图上进行学习,并且可以通过解决上述挑战来收敛到较高的训练精度。对于第一个挑战,即特征共享和隐私保护的困境,一种常见的解决方案是使用基于密码学的技术,例如同态加密 [15]、[16],来实现对加密数据的计算。尽管这些技术具有强大的安全保障,但它们的计算开销很高,因此对于追求高训练速度的 FedGraph 来说并不合适。也存在基于硬件的解决方案,例如 SGX [17]、[18],用于隐私保护,但安全硬件容量有限,无法处理大型图数据 [18]。 FedGraph 通过设计跨客户端图卷积操作解决了这一难题,无需繁重的加密操作或专用硬件。FedGraph 不会直接共享节点特征,而是在共享之前将它们嵌入到低维表示中,这样就无法恢复原始特征。

为了减少 GCN 训练开销,图采样已被广泛采用,以随机选择一小批节点进行训练 [12]、[19]、[20]、[21]。GraphSAGE [12] 是一种基于节点邻居关系的图采样方法。它在对每个节点应用图卷积操作时随机选择固定数量的邻居。FastGCN [20] 被提出通过独立选择每个图卷积层的节点来提高采样效率。然而,现有的工作由于三个弱点而无法满足 FedGraph 设计的要求。首先,这些采样方法依赖于一些手工制作的参数,而这些参数严重依赖于领域专家的知识。例如,GraphSAGE 的性能由指定采样邻居数量的参数决定,手动参数调整非常耗时。其次,现有方法忽略了训练速度和训练准确性之间的权衡。采样较少的节点会加速训练,但会降低准确性。第三,参与联合图学习的客户端在图大小和计算能力上是异构的。对所有客户端应用相同的采样策略远非最佳解决方案。

这些弱点使得 FedGraph 中的采样算法设计具有挑战性。我们没有努力改进现有的启发式设计,而是求助于深度强化学习 (DRL) 技术,设计了一种智能采样算法,该算法可以通过共同考虑计算开销、训练准确性和客户端异质性来自动调整采样策略。通过仔细研究各种 DRL 算法,我们选择了深度确定性策略梯度 (DDPG) 并将其投入到联邦图学习中。本文的主要贡献如下。

  1. 我们提出 FedGraph 作为一种新型的联合图学习系统。我们正式介绍了客户端进行本地训练和服务器进行全局参数更新的过程。提出了一种轻量级的跨客户端卷积操作,以实现客户端之间的特征共享,同时避免隐私泄露。
  2. 为 FedGraph 设计了一种基于 DRL 的采样算法,以便它可以自动找到在训练速度和准确性之间取得良好平衡的最佳采样策略。
  3. 我们实现了 FedGraph 的原型并在测试平台上对其进行了评估。性能评估中使用了四个流行的图数据集。实验结果表明,与现有工作相比,FedGraph 的收敛速度至少快 2 倍,准确率高出约 10%。

本文的其余部分组织如下。我们在第 2 节中回顾了 GCN 和联邦学习的一些必要背景。第 3 节介绍了 FedGraph 设计,第 4 节介绍了智能采样策略设计。第 5 节给出了实验结果。第 6 节介绍了相关工作。最后,第 7 节总结了本文。

2 BACKGROUND AND MOTIVATION

在本节中,我们介绍了联邦学习和 GCN 的一些必要背景。此外,我们分析了现有的图采样方法及其弱点,这激发了本文的 FedGraph 设计。

2.1 Federated Learning

联邦学习的目标是在分布式设备之间训练一个共享模型,同时避免暴露它们的训练数据。典型的联邦设置由许多设备组成,每个设备都拥有一个不能暴露给其他设备的数据集。此外,还有一个参数服务器负责在设备之间同步训练结果。联邦学习包含多个训练轮次。在每一轮训练中,设备首先从参数服务器下载最新的全局模型,并独立使用本地数据进行训练。然后,它们将更新的模型或模型差异发送回参数服务器。在收集所有设备的训练结果后,参数服务器将它们集成以创建新的全局模型。在整个训练过程中,设备只共享模型,几乎不可能从这些模型中推断出训练数据。由于对训练数据的保护,联邦学习成为近年来的热门测试主题之一,许多重要的研究工作已经做出,以解决各种挑战[2],[22],[23],[24]。然而,它们都集中在CNN模型上,面向GCN的联邦学习很少被研究。

2.2 Graph Convolutional Network

CNN 在图像和视频等欧几里得数据学习方面取得了巨大成功。然而,实践中大量数据表示为由节点和边组成的图,也称为非欧几里得数据。图卷积网络 (GCN) [6] 已被提出作为图学习最有前途的技术之一。通过堆叠多个图卷积层,GCN 能够利用图结构和节点/边特征的信息来解决各种应用中的节点/边分类问题。具体而言,我们考虑一个定义为 G = (V, E)的无向图,其中集合 V 和 E 分别包括节点和边。相应的图邻接矩阵用 A 表示。每个节点 v ∈ V v \in V v∈V 与一个特征向量 x(v) 相关联。 GCN 包含 L 个卷积层,每个卷积层的结构与原始图 G 相同。在第 l 层中,每个节点 v 由向量 h ( l ) ( v ) h^{(l)}(v) h(l)(v)表示,称为节点嵌入。第一层是输入图,我们有 h ( l ) ( v ) = x ( v ) h^{(l)}(v)=x (v) h(l)(v)=x(v)。如图 1 所示,图卷积操作聚合相邻节点的嵌入,将结果转换为低维表示,最后将它们馈送到激活函数 σ ( . ) \sigma(.) σ(.)(例如 ReLU),以生成下一层的节点嵌入。正式地,GCN 的传播规则可以定义如下:
Z ( l + 1 ) = Q H ( l ) W ( l ) ; H ( l + 1 ) = σ ( Z ( l + 1 ) ) Z^{(l+1)} = QH^{(l)}W^{(l)};H^{(l+1)}=\sigma(Z^{(l+1)}) Z(l+1)=QH(l)W(l);H(l+1)=σ(Z(l+1))

其中, H ( l ) H^{(l)} H(l)包括第l层中的所有节点嵌入,以及 Q = D − 1 2 A D − 1 2 Q=D^{-\frac{1}{2}}AD^{-\frac{1}{2}} Q=D−21AD−21.对于矩阵D,我们有 D i i = ∑ j A i j 和 A − = A + I D_{ii}=\sum_{j}A_{ij}和 \mathop{A}\limits^{-}=A+I Dii=∑jAij和A−=A+I,其中I是一个单位矩阵。 W ( l ) W^{(l)} W(l)中包含的特征权重是可训练的参数。给定一些带有标签的节点,我们可以使用梯度下降算法来训练特征权值矩阵 W ( l ) W^{(l)} W(l)。训练后的参数可以用来对无标签的节点进行分类。

2.3 Graph Sampling

在许多应用中,图非常大,而相应的 GCN 训练具有很高的计算开销。

图采样已被提出来以减小用于 GCN 训练的图的大小,其现有工作可分为两类。一类是逐节点邻居采样,即为每个节点迭代采样固定数量的邻居。如图 2a 所示,给定第 l + 1 层中的某些节点,我们随机选择它们的邻居子集作为第 l 层。这种采样保证了节点嵌入的聚合始终发生在相邻节点之间。逐节点邻居采样的代表性工作是 GraphSAGE [12]。然而,随着构建更多层,采样节点的数量可能会呈指数增长。此外,Huang 等人[25] 指出,它会导致某些节点的嵌入计算冗余,例如图 2a 中的红色节点,它们是其他节点的共享邻居。最近提出了几种方法来提高节点邻居采样的性能,例如 VR-GCN [19] 和 Cluster-GCN [26],但它们不能从根本上解决这个弱点。

另一种方法称为逐层重要性采样。其基本思想是根据基于节点度计算的采样概率为每个 GCN 层独立采样固定数量的节点。

FastGCN [20] 是一种典型的逐层重要性采样方法。然而,由于不同层的节点是独立采样的,一些采样的节点可能与前一层的节点没有联系,比如图 2b 中所示的蓝色标记节点。在图卷积操作期间,一些未链接节点的嵌入可能会丢失,这会降低训练性能。两种采样方法的优缺点促使我们设计一种新的采样策略,该策略可以在采样期间保持相邻关系的同时很好地控制计算开销。

3 FEDGRAPH DESIGN

我们考虑一个典型的联邦图学习设置,它由一组执行本地训练任务的计算客户端 C 和一个负责全局参数更新的服务器组成,如图 3 所示。计算客户端和服务器可能位于不同位置,它们通过广域网连接。每个客户端 i ∈ C i \in C i∈C 维护一个图 G i ( V i , E i ) G_i(V_i, E_i) Gi(Vi,Ei),其中每个节点 v ∈ V i v \in V_i v∈Vi 都与一个不能暴露给其他客户端的特征向量 x ( v ) x(v) x(v)Þ相关联。节点的子集 V i l a b e l ⊆ V i V_i^{label} \subseteq V_i Vilabel⊆Vi 具有用{y(v)\|v \\in V_i\^{label}} 表示的标签,可用作训练数据。边集 Ei 包含 Vi 中节点之间的内部边,以及连接到其他客户端持有的节点的外部边。每个客户端都知道其他客户端维护的邻近节点的存在,但不能直接访问它们的特征向量。

我们假设计算客户端和参数服务器是诚实但好奇的( honest-but-curious),即他们诚实地遵循联邦学习程序,但想要学习其他人持有的特征信息。这是一种典型的威胁模型,已被当前的联邦学习研究广泛使用[15],[27],[28]。下面讨论一些其他更严重的威胁模型。一些恶意客户端可以通过修改发送到参数服务器的模型参数来篡改训练。为了应对这种威胁,我们可以使用可信执行环境(TEE)进行本地训练。TEE 在现代 CPU 上通常可用。它实现了由硬件保证的隔离执行环境,攻击者无法访问 TEE 中的数据和代码。此外,恶意参数服务器可以修改全局模型参数以危害联邦学习。我们可以使用安全多方计算(MPC)或同态加密(HE)来保护模型聚合。此外,TEE 还可用于保护参数服务器上的全局模型聚合。

我们的系统设计如图4所示。我们通过添加新模块来定制参数服务器和客户端,以实现智能采样。参数服务器包含三个主要模块。基于DDPG的采样算法为所有客户端生成采样策略。模型聚合器收集客户端训练的局部特征权重,并将它们聚合起来以生成新的全局特征权重以供下一轮训练。此外,还设计了一个通信模块用于参数服务器和客户端之间的消息交换。该通信模块由基于TCP通信协议的gRPC API实现。每个客户端都有一个GCN构建模块,负责根据采样策略创建GCN模型。设计了一个GCN训练模块来运行训练算法。

在 FedGraph 中,为了预测未标记节点的 y ( v ) y(v) y(v),客户端协作训练全局特征权重 W。训练有多个轮次。在每一轮中,客户端从服务器下载最新的特征权重并构建本地 GCN 来训练这些权重。由于存在外部边缘连接,本地 GCN 训练涉及客户端之间的嵌入共享。之后,他们将更新的特征权重发送到服务器,然后服务器创建新的全局特征权重,用于下一轮训练。虽然 FedGraph 与传统的联邦学习共享类似的流程,但它具有独特的本地训练和全局参数更新流程,如下所示。

3.1 Local GCN Training by Clients

算法 1 中描述了每个客户端 i \\in C 的本地 GCN 训练过程。在每一轮开始时,客户端 i 从服务器下载最新的特征权重 W 以及图采样策略 Pi。本地特征权重初始化为 W i = W − W_i = \mathop{W}\limits^{-} Wi=W−。然后,该客户端启动多个训练迭代以根据本地图数据更新特征权重。具体而言,每次训练迭代由以下两个主要步骤组成。

3.1.1 GCN Construction

我们构建了一个由 L 层组成的 GCN Gi,使用函数 Mod elConstruct() 根据策略 Pi 对节点子集进行采样。基本思路是先随机选择一组带标签的节点,这也称为小批量。然后,对于小批量中的每个节点,我们迭代地聚合距离最多 L − 1 L-1 L−1跳的邻居子集的嵌入。

ModelConstrut() 的伪代码如算法 2 所示。具体来说,采样策略可以用KaTeX parse error: Expected '}', got 'EOF' at end of input: ...,,p_i^{(L-1) \}表示;其中 κ i \kappa_i κi表示小批大小,KaTeX parse error: Expected '}', got 'EOF' at end of input: ...,,p_i^{(L-1) \}分别是L-1层的邻居采样概率。如第1行所示,我们将 κ i \kappa_i κi标记的节点采样为小批处理,它们组成了最终的Lth层。然后,我们迭代地向后构造其他的GCN层。对于第(l+1)层中的每个节点v,我们随机选择其邻居的一个子集 N i ( l ) ( v ) N_i^{(l)}(v) Ni(l)(v)到第1层,概率为pði lÞ。此外,我们创建了一个矩阵 Q i ( l ) Q_i^{(l)} Qi(l)来替换(7)中的Q,其中 V i ( v ) V_i(v) Vi(v)表示原始图Gi中节点v的邻居集。矩阵 Q i ( l ) Q_i^{(l)} Qi(l)描述了采样后更新后的相邻关系,以后将用于特征聚合。第l层的所有采样节点都保持在集合 V i ( l ) V_i^{(l)} Vi(l)中,如最后一行所示。

GCN 构造结合了节点采样和层采样的优势。这些采样概率是独立的,这为层上的细粒度采样提供了机会,就像层采样一样。通过仔细设置这些概率,我们可以避免邻域递归爆炸式扩展带来的高计算成本。同时,由于采样过程基于邻域关系,类似于节点采样,我们可以避免对没有连接的节点进行采样。

3.1.2 GCN Training

构建 GCN 模型后,我们继续基于梯度下降训练此 GCN。跨客户端图卷积操作在算法 1 的第 7-13 行中描述。具体而言,客户端在处理第一个 GCN 层时仅聚合内部邻居的嵌入,如公式 (2) 所示。从第二层开始,我们使客户端能够聚合内部邻居和外部邻居,如公式 (3) 所示。这样的设计可以防止本地原点特征的泄漏,同时实现信息共享。我们将在第 3.3 节中给出安全性分析。

聚合后,应用非线性变换生成下一层的节点嵌入 h i ( l + 1 ) ( v ) h^{(l+1)}_i(v) hi(l+1)(v),如等式所示 (4).其目的是最小化在等式中定义的损失函数(5),我们计算梯度和更新特征权值在等式(6),其中 ϵ \epsilon ϵ是学习率。最后,客户端i将更新后的特性权重(或它们与下载的特性权重的差异)提交给参数服务器。

3.2 Global Parameter Update by the Server

参数服务器更新全局权重的过程如算法3所示。服务器首先初始化随机特征权重W和采样策略 { P 1 , P 2 , . . . P ∣ C ∣ } \{ P_1, P_2, ... P_{|C|} \} {P1,P2,...P∣C∣},然后分别将它们发送给客户端。在接下来的每一轮训练中,它从所有客户端收集更新后的局部特征权重,然后执行两个主要任务。首先,它通过聚合局部权重来创建全局特征权重,如公式(8)所示,其中 κ i \kappa_i κi表示当前训练轮次中客户端i的小批量大小,即标记节点的数量。第二个任务是使用函数GenSampling()更新客户端的采样策略,该函数的详细信息将在下一节中给出。GenSampling()的设计是本文最重要的贡献之一,它依赖于深度强化学习技术来平衡计算开销和模型准确性。最后,服务器将新的全局特征权重和采样策略发送给客户端,开始下一轮训练。

3.3 Security Analysis

为了展示我们提出的算法1如何保护特征数据,我们考虑了两个客户端i和j,它们在训练过程中需要共享节点嵌入,而不丧失通用性。假设客户端i从客户端j聚合嵌入,并希望推断出原始节点特征 h j ( 1 ) h^{(1)}_j hj(1)。请注意, h j ( 1 ) h^{(1)}_j hj(1)是一个包含客户端j所持有的所有节点的特征的矩阵,即, h j ( 1 ) ( v ) = x j ( v ) , v ∈ V j h^{(1)}_j(v)=x_j(v),v \in V_j hj(1)(v)=xj(v),v∈Vj。

我们让 V j i V_j^i Vji表示客户端i在客户端j处的邻近节点。根据算法1,客户端i可以得到 { h j ( 2 ) ( V j i ) W j ( 2 ) , h j ( 3 ) ( V j i ) W j ( 3 ) , . . . , h j ( L ) ( V j i ) W j ( L ) } \{h^{(2)}_j(V_j^i)W_j^{(2)},h^{(3)}_j(V_j^i)W_j^{(3)},...,h^{(L)}_j(V_j^i)W_j^{(L)} \} {hj(2)(Vji)Wj(2),hj(3)(Vji)Wj(3),...,hj(L)(Vji)Wj(L)}的信息。然后,客户端i可以通过使用本地 W i ( l ) W^{(l)}_i Wi(l)近似远程 W j ( l ) W^{(l)}_j Wj(l)来猜测节点嵌入 { h j ( 2 ) , . . . , h j ( L ) } \{h^{(2)}_j,...,h^{(L)}_j \} {hj(2),...,hj(L)},当它们仅从服务器同步全局特征权重时,这是可能的。

然而,客户端 i 很难进一步推断出 h j ( 1 ) ( V j i ) h^{(1)}_j(V^i_j) hj(1)(Vji),因为h^{(2)}_j=\\sigma(Q^{(1)}_j h\^{(1)}_j W\^{(1)}_j) 而客户端 i 没有关于 而客户端 i 没有关于 而客户端i没有关于Q\^{(1)}_j 的信息,即采样后客户端 j 中的相邻矩阵。此外,由于更高层嵌入的维数降低, 的信息,即采样后客户端 j 中的相邻矩阵。此外,由于更高层嵌入的维数降低, 的信息,即采样后客户端j中的相邻矩阵。此外,由于更高层嵌入的维数降低,{h^{(2)}_j,...,h^{(L)}_j }的猜测很难达到高精度。考虑到邻居节点的原始特征可以得到保护,在客户端 j 处获取内部节点的特征将是不可能的。因此,我们可以得出结论,FedGraph 可以在联邦图学习期间保护节点特征的同时实现信息共享。

4 INTELLIGENT GRAPH SAMPLING BASED ON DRL

采样策略 { P 1 , P 2 , . . . P ∣ C ∣ } \{ P_1, P_2, ... P_{|C|} \} {P1,P2,...P∣C∣}决定了 GCN 训练中涉及多少个节点,它们既影响计算开销,又影响训练精度。通过采样更少的节点,我们可以在降低训练精度的同时加速训练过程。另一方面,通过更多的采样节点,我们可以更好地近似原始 GCN 以获得更高的训练精度,但计算成本很高。因此,设计采样策略进行权衡具有重要意义,但现有工作忽略了这一点。同时,由于优化空间很大,采样策略设计很困难,手动调整在实践中几乎不起作用。我们希望自动算法在最少的人为参与下生成良好的采样策略。

通过仔细研究采样策略,我们发现它们对学习性能(训练速度和准确性)的影响无法使用精确的闭式表达式来描述。我们不再苦苦挣扎于启发式算法设计,而是求助于可以自动逼近良好解决方案的深度强化学习(DRL)。DRL 的思想可以以多种方式实现,为具有不同性能的不同应用场景生成一个蓬勃发展的算法系列。通过仔细比较候选 DRL 算法,我们选择使用深度确定性策略梯度(DDPG)算法 [29],它可以有效地处理我们问题的高维和连续动作空间。DDPG 结合了深度 Q 网络和演员评论家方法,从而享受了它们的优势。

4.1 DDPG-Based Problem Formulation

为了应用 DDPG,我们首先将问题表述为马尔可夫决策过程,如下所示。

状态空间。我们将训练轮 t 的系统状态定义为本轮开始时观察到的特征权重,可以表示为 s [ t ] = { W ‾ [ t ] , W 1 [ t ] , W 2 [ t ] , . . . , W ∣ C ∣ [ t ] } s[t]=\{\overline{W}[t],W_1[t],W_2[t],...,W_{|C|}[t] \} s[t]={W[t],W1[t],W2[t],...,W∣C∣[t]}。注意, W ‾ [ t ] \overline{W}[t] W[t]是全局特征权重,Wi½t 表示客户端 i ∈ C i \in C i∈C 的局部特征权重。整个动作空间用 S 表示。由于状态空间巨大,我们利用主成分分析 (PCA) [30] 将高维空间投影到低维空间,同时尽可能保持分布信息完整。

动作空间。在第 t 轮开始时,参数服务器需要为所有客户端确定图采样策略。因此,每轮 t 的动作 a[t] 被定义为相应的采样策略,即 a [ t ] = { P 1 [ t ] , P 2 [ t ] , . . . , P ∣ C ∣ [ t ] } a[t]=\{P_1[t],P_2[t],...,P_{|C|}[t] \} a[t]={P1[t],P2[t],...,P∣C∣[t]}。动作空间用 A 表示。

奖励。由于学习速度和准确率都被视为性能指标,因此应定义奖励来反映它们。我们使用每个训练轮次 t 的完成时间(用 δ [ t ] \delta[t] δ[t] 表示)来评估训练速度。服务器可以通过测量从所有客户端收集本地训练结果的时间消耗来轻松获得 δ [ t ] \delta[t] δ[t]。训练准确率 λ [ t ] \lambda[t] λ[t]是根据参数服务器上的测试集计算的。我们考虑一个典型的联邦设置,其中参数服务器通常是持有测试集的任务发布者。每个客户端都有自己的训练集和验证集,由于隐私问题,这些信息不能公开。利用 δ [ t ] \delta[t] δ[t] 和 λ [ t ] \lambda[t] λ[t]的信息,我们将每轮 t 的奖励定义如下,

其中 Λ \Lambda Λ是目标准确率。常数 Ω \Omega Ω、 α \alpha α 和 β \beta β 可以调整以表达对学习速度和准确率的不同偏好。奖励包含两部分。第一部分评估准确率的提高。我们注意到,随着学习的进行, λ [ t ] \lambda[t] λ[t]表现出非线性的改进。它可以在前几轮训练中快速提高,但后来改进变小。为了使奖励无偏,我们在这里使用指数函数。第二部分以负数形式评估每轮训练的完成时间,以鼓励快速训练。在实践中,客户端的完成时间受许多因素的影响,例如计算硬件或网络延迟。我们通过在 (9) 中添加常数 β \beta β来减轻这些因素的影响,以便我们能够更好地评估不同采样策略的影响。在我们的实验中,我们将时间惩罚,即 α ( δ [ t ] − β ) \alpha (\delta[t]-\beta) α(δ[t]−β) 控制在接近 1 的水平,如 [24] 所述,这可以通过分析轻松实现。

学习策略和目标。我们将问题中的 DRL 学习策略定义为 π θ : S → A \pi_{\theta}:S \to A πθ:S→A,由 θ \theta θ参数化。更准确地说,给定状态 s[t],算法输出确定性动作 a t a_t at。我们基于 DRL 的采样算法的目标是最大化起始状态的预期累积折扣奖励,其定义为

动作价值函数 q π ( s [ t ] , a [ t ] ) q_{\pi}(s[t],a[t]) qπ(s[t],a[t])定义为描述基于策略 π θ \pi_{\theta} πθ 在状态 s[t]下执行动作a[t]后的预期累积折扣奖励,即 q π ( s [ t ] , a [ t ] ) = E [ R [ t ] ∣ S [ t ] = s [ t ] , A [ t ] = π θ ( s [ t ] ) ] q_{\pi}(s[t],a[t])=E[R[t]|S[t]=s[t],A[t]=\pi_{\theta}(s[t])] qπ(s[t],a[t])=E[R[t]∣S[t]=s[t],A[t]=πθ(s[t])]。通常,

我们使用神经网络来近似策略函数 π θ \pi_{\theta} πθ和动作价值函数 q π q_{\pi} qπ。

4.2 Sampling Based on DDPG

基于 DDPG 的采样算法设计如图 5 所示。我们设计了一个参与者网络 μ ( s ∣ θ μ ) \mu(s|\theta_{\mu}) μ(s∣θμ)来预测确定性动作,以及一个评论家网络 q ( s , a ∣ θ q ) q(s,a|\theta_{q}) q(s,a∣θq)来估计动作值函数 q π ( s , a ) q_{\pi}(s,a) qπ(s,a)。同时,我们维护参与者网络和评论家网络的副本,分别表示为 μ ~ ( s ∣ θ ~ μ ) \widetilde\mu(s|\widetilde{\theta}{\mu}) μ (s∣θ μ) 和 q ~ ( s , a ∣ θ ~ q ) \widetilde q(s,a|\widetilde{\theta}{q}) q (s,a∣θ q),也称为目标网络。它们可用于更新原始参与者和评论家网络。

与深度 Q 网络类似,我们维护一个有限大小的重放缓冲区来存储历史转换,定义为(s[t],a[t],r[t],s[t+1])。我们通过从回复缓冲区中抽取一小批转换来更新参与者和评论家网络。当缓冲区已满时,最旧的样本将被丢弃。然后,我们正式介绍基于 DRL 的采样算法,即函数 GenSampling() 的实现细节,并解释它如何学习最佳采样方案。

基于DDPG算法的伪代码如算法4所示。我们在第1-5行初始化四个网络以及系统状态。在第t轮训练开始时,服务器以所有客户端的特征权重形式观察当前状态信息s[t],以及(9)中定义的奖励r[t-1] ,如第9行所示。然后,我们使用PCA方法[30]降低s[t]的维度以获得 s ′ [ t ] s^{\prime}[t] s′[t],然后将转换 ( s ′ [ t − 1 ] , a [ t − 1 ] , r [ t − 1 ] , s ′ [ t ] ) (s^{\prime}[t-1],a[t-1],r[t-1],s^{\prime}[t]) (s′[t−1],a[t−1],r[t−1],s′[t])存储到重放缓冲区中。之后,我们随机选择K个转换的小批量来更新评论家网络,通过最小化损失函数

其中 η μ \eta_{\mu} ημ 是行动者网络的学习率。第 14 行更新了两个目标网络的参数,其中 ϕ ≪ 1 \phi \ll 1 ϕ≪1。最后,我们得到了动作 a[t] ,表示基于更新的网络的采样策略。

5 PERFORMANCE EVALUATION

5.1Experimental Settings

我们使用 PyTorch 和 Deep Graph Library (DGL) [31](一个专用于图形深度学习的 Python 包)实现 FedGraph。我们在 20 个计算客户端上部署了 FedGraph,这些客户端配备 Intel i7-10700 CPU、32 GB 内存和 Geforce RTX 2080 GPU。我们考虑了 4 个流行的图形数据集:Cora、Citeseer、PubMed 和 Reddit,它们已广泛用于 GCN 研究 [12]、[19]、[20]、[21]、[25]、[26]。表 1 总结了这些数据集的一些统计信息。由于某些图(例如 Cora 和 Citeseer)的大小有限,我们使用以下方法基于这些数据集合成大型图。给定表 1 中的数据集,每个客户端 i 随机选择一定比例 ξ i \xi_i ξi的节点作为其局部图数据,并且 ξ 1 , ξ 2 , . . . , ξ ∣ C ∣ \xi_1,\xi_2,...,\xi_{|C|} ξ1,ξ2,...,ξ∣C∣值为 0.8 的正态分布。生成的局部图可能会在某些节点上重叠,尤其是对于 Cora 和 Citeseer 等小型图数据集。对于大型图,我们会仔细控制局部图生成以避免重叠。即使某些节点在合成数据集中重叠,我们也会将它们视为不同的节点,并且不会影响训练性能。[32] 采用了类似的图合成方法。对于局部数据集,我们随机选择一组节点来生成训练集、验证集和测试集。根据原始图维护跨客户端的边连接。对于局部图学习,每个客户端构建一个 3 层 GCN,包括一个输入层和两个卷积层。我们为 Cora、Citeseer 和 PubMed 设置了 16 个隐藏单元、50% 的 dropout 率和 0.01 的学习率。对于 Reddit,有 128 个隐藏单元,dropout 率为 20%,学习率为 0.0001。我们将 Cora、Citeseer 和 Reddit 的批处理大小设置为 256,将 PubMed 的批处理大小设置为 1,024[20]。我们使用 ADAM 优化器进行局部 GCN 训练。对于奖励函数 (9),我们在实验中将指数函数的底数,即 Ω \Omega Ω设置为 128。由于 FedGraph 依赖于奖励函数的指数性质,因此底数对 FedGraph 的影响很小。此外,训练准确率 λ [ t ] \lambda[t] λ[t]和目标准确率 L 之间的差异会影响每轮 t 的奖励。对于每个数据集,我们选择现有工作报告的最佳准确率。即使我们不知道最佳准确率,我们也可以根据经验进行估计。由于 FedGraph 仅依赖于奖励函数的指数性质,因此这种估计对 Fed Graph 影响不大。常数 α \alpha α和 β \beta β 都旨在平衡准确率的提高和时间成本。在我们的实验中,我们将时间惩罚 α ( δ [ t ] − β ) \alpha(\delta[t]-\beta) α(δ[t]−β) 控制在接近 1 的水平,类似于 [24] 中的设置。为了进行比较,我们扩展了以下三种用于联邦图学习的图采样方案。

1)全批次:我们不进行图采样,而是使用原始图来构造GCN。

2)GraphSAGE:一种典型的逐节点邻居采样方法,迭代采样固定数量的邻居。两个卷积层的邻居采样大小分别设置为25和10,与[12],[20],[26]中的设置相同。

3)FastGCN:一种典型的逐层重要性采样方法,对每层独立采样固定数量的节点(也称为层大小)。Cora和Citeseer的层大小设置为256,Reddit和PubMed的层大小设置为8,192,这是[21]提倡的设置。

在 FedGraph 基于 DRL 的采样算法中,参与者网络和批评者网络均有 2 个隐藏层,分别为 512 个和 256 个单元。我们使用工具 sklearn.decomposition.PCA [33] 将特征权重压缩为 20 维。

5.2Experimental Results

基于 DRL 的采样的收敛性。我们让 FedGraph 训练 300 个情节,并在图 6 中显示四个数据集下的累积回报。我们将目标准确率设置为 Cora 的 90.16%、PubMed 的 78.7%、Citeseer 的 87.9% 和 Reddit 的 96.27%。我们观察到四个数据集的累积折现回报可以在不到 100 个情节的时间内收敛到稳定值,特别是最大的数据集 Reddit 在 50 个情节后几乎收敛,如图 6d 所示。这些事实证明了我们提出的基于 DRL 的采样方案具有良好的收敛性。

训练准确率结果。不同采样方案的准确率收敛如图 7 所示,可以看出 FedGraph 可以以更快的速度收敛并获得更高的准确率。为了公平比较,我们使用物理时间而不是训练轮次作为评估不同方案训练速度的指标。这是因为客户端的图大小不同,并且它们在每轮训练中消耗的时间成本也不同。具体来说,FedGraph 在 Cora 上大约需要 5 秒就能达到 75% 的准确率,但其他三种算法需要 10 秒以上才能达到类似的准确率。在 PubMed 中,FedGraph 需要大约 15 秒才能达到 73% 的准确率,但 GraphSAGE 和全批次方案需要 2 倍以上的时间才能收敛。在最大的数据集 Reddit 中,FedGraph 的优势更加明显,如图 7d 所示。我们总结原因如下。 GraphSAGE 存在严重的计算冗余问题,导致训练耗时增加。FastGCN 无法从其他客户端获取足够的嵌入信息,因为有些采样节点没有边连接。全批量方案需要计算所有节点的嵌入,这会导致较高的计算成本,尤其是在较大的图 PubMed 和 Reddit 上。Fed Graph 很好地解决了上述方法的弱点,从而获得了更高的性能。请注意,训练总轮数固定为 300,FastGCN 可以更早完成训练,因为它采样的训练节点较少。此外,为了评估 FedGraph 的可扩展性,我们将实验规模扩大到 50 个客户端,并在图 8 中显示相应的结果。我们可以发现 FedGraph 仍然优于其他采样方案。

图异质性的影响。我们通过改变 ξ i \xi_i ξi的方差来研究图异质性的影响。我们考虑三个异质性水平,相应的方差分别为 0.1(低)、0.5(中)和 1(高)。为了更好地理解,我们计算了最小图大小与最大图大小之间的比率,结果分别约为 0.2、0.4 和 0.6。我们测量训练时间以收敛到大多数采样方案可以实现的目标准确度。在 PubMed 中,我们将目标准确度设置为 72%,但 FastGCN 只能收敛到 68.6%。如图 9 所示,在所有数据集下,随着图变得更加异质,所有采样方案的训练时间都会增加。然而,FedGraph 对时间增长有更好的控制,因为它基于 DRL 的采样同时考虑了训练速度和准确性。

跨客户端嵌入共享的效果。FedGraph 使用跨客户端图卷积操作来实现客户端之间的嵌入共享,同时在本地 GCN 训练期间隐藏本地特征。为了进行比较,我们考虑了两种替代方法,一种方法(称为 FedGraph_allShare)是从第一层共享嵌入以最大化信息共享,另一种方法(称为 Fed Graph_nonShare)是放弃跨客户端共享以简化设计。我们在图 10 中展示了这三种设计的准确率收敛。总训练轮数设置为 300。我们可以发现 Fed Graph 的曲线接近 FedGraph_allShare 的曲线,这表明 FedGraph 虽然消除了第一层的嵌入共享,但信息损失却很小。这是因为高层嵌入包含有关原始特征的信息。因此,FedGraph 可以有效地从跨客户端嵌入共享中学习,而无需交换原始特征。同时,FedGraph 在所有数据集下的表现都明显优于 FedGraph_nonShare。在 Cora 和 Citeseer 中,跨客户端卷积操作可以将训练准确率提高约 10%。在 PubMed 中,两种设计的最终准确率相似,但 FedGraph 可以快速收敛。Reddit 对跨客户端嵌入共享比其他数据集更敏感,FedGraph_nonShare 收敛到约 70% 的准确率,而 FedGraph 可以收敛到约 90%。这是因为 Reddit 具有丰富的边连接(如表 1 所示),忽略跨客户端边会严重破坏整个图结构。注意,FedGraph_nonShare 更早完成 300 轮训练,因为它消除了嵌入共享。

GCN 深度的影响。我们通过改变图卷积层的数量来研究 GCN 深度的影响。结果如图 11 所示。我们可以看到,对于所有数据集,随着层数从 2 增加到 4,时间复杂度明显增加。同时,准确率变化不大。特别是,由于过度平滑问题,Citeseer 的准确率随着 GCN 层数的增加而降低 [6]、[34]、[35]。

非 IID 数据的影响。图 12 展示了 FedGraph 处理非 IID 数据的有效性。我们通过为每个局部图选择节点类型的子集来生成非 IID 数据分布。实验结果表明,FedGraph 仍然优于其他方案。

6.1 Federated Learning

联邦学习因其在实现隐私保护分布式机器学习方面的巨大潜力而引起了广泛的研究关注[3],[22]。赵等人[2]用数学方法证明了非IID数据在联邦学习中的影响,并提出了一种向每个客户端发送一组均匀分布数据以减少非IID数据影响的方法。

最近,有几篇论文研究了与本文不同的联邦设置下的GNN。Suzumura等人,[36]开发了一个联邦学习平台,用于检测多个金融机构的金融犯罪活动。他们通过图分析方法而不是图神经网络将全局图信息提取到欧几里得数据中。此外,他们假设全局图属于所有客户端。相比之下,我们研究非欧几里得数据的GCN,每个客户端都拥有一个本地图。

Jiang 等人 [37] 提出了一种基于 GNN 和联邦学习的新型分布式监控系统。这项工作与我们的论文有两个关键区别。首先,他们考虑跨设备联合设置,涉及大量具有有限计算和通信能力的摄像头。相比之下,我们研究跨孤岛联合设置,通常涉及少量客户端。其次,他们旨在保护训练有素的模型。但是,我们探索客户端间连接并保护节点特征。

Mei 等人 [8] 研究具有垂直联合设置的联合隐私保护图神经网络,即假设图结构、特征和标签属于不同的来源。但是,我们考虑水平联合设置,即每个本地客户端都维护一个具有自己的图结构、节点特征和标签的完整图数据集。

6.2 Graph Convolutional Networks

由于其出色的性能,GCN 已广泛应用于许多图学习应用中,例如节点分类 [6]、[38]、链接预测 [39] 和推荐系统 [40]。最近,一些研究将 GCN 应用于自然语言处理任务,如机器翻译 [41] 和关系分类 [42]。为了加速 GCN 训练,NeuGraph [43] 被提出作为一种新框架,支持图上高效且可扩展的并行神经网络计算。NeuGraph 不仅可以支持单 GPU 训练,还可以支持多 GPU 上的并行处理。Scardapane 等人 [13] 提出了基于消息传递交换的分布式 GCN 训练。然而,这项工作忽略了联邦学习场景所必需的隐私保护。

图采样可以有效减少 GCN 训练开销。Hamilton 等人 [12] 提出了 GraphSAGE,通过对邻近节点的子集进行采样来构建简化的 GCN。然而,GraphSAGE 在某些节点作为公共邻居时会产生冗余计算 [25]。尽管已经提出了几项通过减小采样节点的大小来减轻冗余计算的工作,如 VR-GCN [19] 和 Cluster-GCN [26],但它们在训练非常大且深的 GCN 时仍然不能很好地解决这个问题。为了解决这个问题,已经提出了分层采样方法,如 FastGCN [20] 和 LADIES [21],它们对每一层节点进行独立采样,而不是对每个节点进行邻居采样。这种采样方法可以有效地降低计算成本,但由于独立采样,一些采样节点可能没有连接,这会降低训练精度。此外,上述所有采样方法都依赖于需要手动调整的手工制作参数。现有工作的弱点促使本文设计了具有智能采样的 FedGraph。

7 CONCLUSION

在本文中,我们提出了 FedGraph 作为一种新颖的联邦图系统,以实现隐私保护的分布式 GCN 学习。与传统的联邦学习不同,FedGraph 更具挑战性,因为 GCN 训练过程涉及客户端之间的嵌入共享。为了应对这一挑战,FedGraph 使用一种新颖的跨客户端图卷积操作在共享之前压缩嵌入,以便可以很好地隐藏私人信息。此外,为了减少 GCN 训练开销,FedGraph 采用了基于 DRL 的采样方案,可以很好地平衡训练速度和准确性。在 20 个客户端测试平台上的实验结果表明,FedGraph 明显优于现有方案。

相关推荐
菜菜子爱学习2 小时前
Nginx学习笔记(八)—— Nginx缓存集成
笔记·学习·nginx·缓存·运维开发
chillxiaohan3 小时前
GO学习记录五——数据库表的增删改查
数据库·学习·golang
CV实验室3 小时前
ICCV 2025 | 4相机干掉480机位?CMU MonoFusion高斯泼溅重构4D人体!
人工智能·数码相机·计算机视觉·论文
憨憨の大鸭鸭4 小时前
python爬虫学习(2)
爬虫·学习
_hermit:7 小时前
【从零开始java学习|第六篇】运算符的使用与注意事项
java·学习
rannn_1118 小时前
【Linux学习|黑马笔记|Day4】IP地址、主机名、网络请求、下载、端口、进程管理、主机状态监控、环境变量、文件的上传和下载、压缩和解压
linux·笔记·后端·学习
Moonnnn.9 小时前
【51单片机学习】定时器、串口、LED点阵屏、DS1302实时时钟、蜂鸣器
笔记·单片机·学习·51单片机
ai绘画-安安妮10 小时前
零基础学LangChain:核心概念与基础组件解析
人工智能·学习·ai·程序员·langchain·大模型·转行
云边小贩11 小时前
C++ STL学习 之 泛型编程
开发语言·c++·学习·类与对象
菜菜子爱学习12 小时前
Nginx学习笔记(九)—— Nginx Rewrite深度解析
linux·运维·笔记·学习·nginx