GraphSAGE 学习笔记

1. GraphSAGE 是什么

GraphSAGE 是一种典型的图神经网络模型,它的全称可以理解为 Graph Sample and Aggregate,也就是"图上的采样与聚合"。

它的核心思想是:

复制代码
对每个节点,从它的邻居中采样一部分节点,然后聚合这些邻居的信息,再和节点自身的信息进行融合,最终得到新的节点表示。

如果用一句话概括 GraphSAGE:

复制代码
GraphSAGE = 邻居采样 + 邻居聚合 + 自身信息融合

它和 GCN 一样,都属于消息传递类图神经网络。节点会通过边从邻居那里接收信息,然后更新自己的表示。

但是 GraphSAGE 和普通 GCN 的一个重要区别是:

复制代码
GCN 通常聚合所有邻居;
GraphSAGE 通常从邻居中采样固定数量的节点进行聚合。

这使得 GraphSAGE 更适合大规模图数据,因为它可以控制每次训练时涉及的邻居数量,避免计算量和显存占用随着图规模急剧增加。


2. 为什么需要 GraphSAGE

在 GCN 中,节点更新通常可以写成类似下面的形式:

复制代码
H' = A_hat H W

其中:

复制代码
H

表示节点特征矩阵;

复制代码
A_hat

表示经过处理后的邻接矩阵;

复制代码
W

表示可学习权重矩阵。

这种方式本质上是通过邻接矩阵把邻居信息聚合到当前节点上。

对于 Cora 这种小规模图来说,可以一次性把整张图送入模型:

复制代码
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])

但是如果图很大,例如有几十万甚至几百万个节点,那么每次都对整张图做消息传播就会非常消耗显存和计算资源。

GraphSAGE 的提出就是为了解决这个问题之一。

它不要求每次都使用所有邻居,而是从邻居中采样固定数量的节点。例如:

复制代码
每个节点只采样 10 个一阶邻居;
每个一阶邻居再采样 10 个二阶邻居。

这样模型每次看到的是一个局部子图,而不是完整大图。


3. GraphSAGE 的基本流程

对于一个节点 v,GraphSAGE 的一层传播通常包含三个步骤。


3.1 第一步:采样邻居

假设节点 v 的邻居集合是:

复制代码
N(v) = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}

GraphSAGE 不一定使用所有邻居,而是从中采样一部分,例如:

复制代码
S(v) = {2, 5, 8}

也就是说,GraphSAGE 会从节点 v 的邻居中选出一部分节点参与本次计算。

这样做的好处是可以控制计算规模。


3.2 第二步:聚合邻居特征

假设采样到的邻居节点是 2、5、8,它们的特征分别是:

复制代码
h_2, h_5, h_8

GraphSAGE 会使用一个聚合函数把这些邻居特征合成为一个邻居表示:

复制代码
h_N(v) = AGGREGATE({h_2, h_5, h_8})

最简单的聚合方式是求平均:

复制代码
h_N(v) = mean(h_2, h_5, h_8)

可以理解为:

复制代码
把多个邻居节点的信息压缩成一个邻居向量。

3.3 第三步:融合自身信息和邻居信息

GraphSAGE 不只看邻居,也会保留节点自身的信息。

所以它会把节点自己的表示和邻居聚合表示结合起来:

复制代码
h_v' = σ(W · CONCAT(h_v, h_N(v)))

其中:

复制代码
h_v

表示节点 v 自身的表示;

复制代码
h_N(v)

表示邻居聚合后的表示;

复制代码
CONCAT

表示拼接;

复制代码
W

表示可学习权重矩阵;

复制代码
σ

表示激活函数。

这一步的本质是:

复制代码
把自身信息和邻居信息放在一起,再让模型学习如何融合它们。

4. GraphSAGE 中为什么是拼接,而不是相加

GraphSAGE 中常见的更新方式是:

复制代码
h_v' = σ(W · CONCAT(h_v, h_N(v)))

这里有一个很重要的问题:为什么要把自身表示和邻居表示拼接起来,而不是直接相加?


4.1 相加会提前混合信息

假设节点自身表示是:

复制代码
h_v = [1, 2, 3]

邻居聚合表示是:

复制代码
h_N = [4, 5, 6]

如果直接相加:

复制代码
h_v + h_N = [5, 7, 9]

相加之后,模型就很难分清楚哪些信息来自节点自己,哪些信息来自邻居。

例如第一个维度的 5 可能来自:

复制代码
1 + 4

也可能来自其他组合。

也就是说,相加会把自身信息和邻居信息提前混在一起。


4.2 拼接可以保留信息来源

如果使用拼接:

复制代码
CONCAT(h_v, h_N) = [1, 2, 3, 4, 5, 6]

那么前半部分仍然表示节点自身信息,后半部分仍然表示邻居信息。

这样模型可以清楚地区分:

复制代码
哪些信息来自自己;
哪些信息来自邻居。

之后再通过可学习矩阵 W 进行融合,模型就可以自己学习:

复制代码
应该更重视自身信息,还是更重视邻居信息。

例如,模型可以学到:

复制代码
0.8 × 自身信息 + 0.2 × 邻居信息

也可以学到:

复制代码
0.3 × 自身信息 + 0.7 × 邻居信息

甚至可以学习更复杂的跨维度组合。


4.3 拼接后维度会变大

假设:

复制代码
h_v 的维度是 128
h_N 的维度也是 128

那么拼接后维度会变成:

复制代码
256

然后线性层 W 会把它重新映射到目标维度,比如重新变成 128 维。

所以可以理解为:

复制代码
拼接先保留更多信息,再由线性层学习如何压缩和融合。

4.4 一句话理解拼接

GraphSAGE 使用拼接而不是相加,是因为:

复制代码
拼接可以保留"自身信息"和"邻居信息"的区别,然后让后面的可学习线性层决定如何融合;
相加则会提前把两类信息混在一起,降低模型表达能力。

可以记成:

复制代码
拼接 = 先保留信息,再学习融合
相加 = 先固定融合,再继续学习

5. GraphSAGE 的几种聚合方式

GraphSAGE 中的 AGGREGATE 可以有不同实现方式。


5.1 Mean Aggregator

Mean Aggregator 是最容易理解的方式。

它直接对邻居节点的特征求平均:

复制代码
h_N(v) = mean({h_u, u ∈ N(v)})

可以理解为:

复制代码
用邻居节点的平均特征代表邻居整体信息。

5.2 LSTM Aggregator

LSTM Aggregator 会把邻居特征输入到 LSTM 中,再得到聚合后的邻居表示。

不过图中的邻居本身没有天然顺序,而 LSTM 对输入顺序敏感,所以这个方法理解起来稍微别扭,实际学习时不需要优先掌握。


5.3 Pooling Aggregator

Pooling Aggregator 会先对每个邻居特征进行一次非线性变换,然后再进行池化操作,例如 max pooling。

大致形式是:

复制代码
h_u' = σ(W h_u + b)
h_N(v) = max({h_u'})

可以理解为:

复制代码
先提取每个邻居的特征,再从这些邻居特征中选出最显著的信息。

6. PyG 中的 GraphSAGE 模型代码

在 PyTorch Geometric 中,可以使用 SAGEConv 实现 GraphSAGE。

一个简单的两层 GraphSAGE 模型如下:

复制代码
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv


class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x

其中:

复制代码
SAGEConv(in_channels, hidden_channels)

表示第一层 GraphSAGE,把输入特征维度变成隐藏层维度。

复制代码
SAGEConv(hidden_channels, out_channels)

表示第二层 GraphSAGE,把隐藏层维度变成输出维度。

如果是节点分类任务,out_channels 通常等于类别数。


7. GraphSAGE 的输入

GraphSAGE 的输入和 GCN 类似,主要包括:

复制代码
x
edge_index

7.1 x:节点特征矩阵

复制代码
x

表示节点特征矩阵。

假设图中有 2708 个节点,每个节点有 1433 维特征,那么:

复制代码
x.shape = [2708, 1433]

每一行表示一个节点的特征。


7.2 edge_index:边结构

复制代码
edge_index

表示图中的边结构。

它的形状通常是:

复制代码
edge_index.shape = [2, num_edges]

例如:

复制代码
edge_index = [
    [0, 0, 1, 2],
    [1, 2, 2, 3]
]

表示边:

复制代码
0 -> 1
0 -> 2
1 -> 2
2 -> 3

从数学理解上,可以把它看成邻接矩阵的稀疏表示;但在工程实现中,PyG 通常不会真的构造完整的 N × N 邻接矩阵,而是基于 edge_index 做消息传递。


8. 小图上的普通训练方式

如果图比较小,例如 Cora,可以不使用邻居采样,而是直接整图训练:

复制代码
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])

这表示:

复制代码
模型对整张图所有节点都输出预测结果;
但计算 loss 时,只使用训练节点。

其中:

复制代码
out[data.train_mask]

表示训练节点的预测结果。

复制代码
data.y[data.train_mask]

表示训练节点的真实标签。


9. NeighborLoader 是什么

NeighborLoader 不是 GraphSAGE 的模型层,而是 PyTorch Geometric 中用于图数据采样的加载器。

它的作用是:

复制代码
在训练时,每次从训练节点中取出一批中心节点;
然后围绕这些中心节点采样一定数量的邻居节点;
构造出一个局部子图;
再把这个子图送入 GraphSAGE 模型进行训练。

也就是说,NeighborLoader 每次返回的不是一批孤立节点,而是一个围绕中心节点采样出来的小图。

它可以理解为图神经网络中的特殊 DataLoader。

普通图像分类任务中的 DataLoader 每次返回一批图片,而 NeighborLoader 每次返回一批中心节点及其邻居组成的子图。


10. NeighborLoader 代码示例

常见写法如下:

复制代码
from torch_geometric.loader import NeighborLoader

train_loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],
    batch_size=128,
    input_nodes=data.train_mask,
    shuffle=True
)

这段代码的意思是:

复制代码
从 data.train_mask 指定的训练节点中,
每次取 128 个节点作为中心节点;
由于模型是两层 GraphSAGE,所以采样两跳邻居;
每一跳最多采样 10 个邻居;
最后把中心节点和采样到的邻居节点组成一个局部子图,用于模型训练。

11. NeighborLoader 参数解释

11.1 data

复制代码
data

表示原始整张图。

其中通常包含:

复制代码
data.x

节点特征矩阵。

复制代码
data.edge_index

边结构。

复制代码
data.y

节点标签。

复制代码
data.train_mask

训练节点掩码。


11.2 input_nodes=data.train_mask

复制代码
input_nodes=data.train_mask

表示只从训练节点中选择中心节点。

data.train_mask 通常是一个布尔向量,例如:

复制代码
[True, True, False, False, True, ...]

其中 True 表示该节点属于训练集,False 表示该节点不属于训练集。

所以 NeighborLoader 每次会从这些值为 True 的节点中选出一批节点作为当前 batch 的中心节点。


11.3 batch_size=128

复制代码
batch_size=128

表示每次选取 128 个训练节点作为中心节点。

需要注意:

复制代码
这里的 batch_size 不是最终子图中的节点数量,而是中心节点数量。

因为每个中心节点还会采样它的邻居,所以最终返回的子图节点数通常会大于 128。

例如:

复制代码
batch_size=128
num_neighbors=[10, 10]

最终采样出来的小图可能包含几百个甚至上千个节点。


11.4 num_neighbors=[10, 10]

复制代码
num_neighbors=[10, 10]

表示进行两层邻居采样,通常对应两层 GraphSAGE 模型。

可以这样理解:

复制代码
第一层:采样中心节点的一阶邻居,每个中心节点最多采样 10 个邻居;
第二层:再为这些一阶邻居采样它们的邻居,也就是中心节点的二阶邻居,每个节点最多采样 10 个邻居。

所以对于一个中心节点来说,理论上最多会涉及:

复制代码
1 + 10 + 10 × 10 = 111 个节点

其中:

复制代码
1 表示中心节点自己;
10 表示一阶邻居;
10 × 10 表示二阶邻居。

实际采样节点数通常会小于这个数,因为不同节点之间可能有重复邻居。


12. 为什么两层 GraphSAGE 需要采样两跳邻居

如果模型是两层 GraphSAGE:

复制代码
self.conv1 = SAGEConv(in_channels, hidden_channels)
self.conv2 = SAGEConv(hidden_channels, out_channels)

那么中心节点最终会进行两次消息传递。

第一层让节点聚合一阶邻居的信息。

第二层让中心节点继续聚合一阶邻居更新后的表示,而这些一阶邻居的表示中已经包含了二阶邻居的信息。

因此,两层 GraphSAGE 最终可以让中心节点获得二阶邻居的信息。

所以:

复制代码
一层 GraphSAGE 通常对应 num_neighbors=[10]
两层 GraphSAGE 通常对应 num_neighbors=[10, 10]
三层 GraphSAGE 通常对应 num_neighbors=[10, 10, 10]

13. NeighborLoader 每次返回什么

当我们写:

复制代码
for batch in train_loader:
    print(batch)
    break

NeighborLoader 每次会返回一个 batch

这个 batch 本质上也是一个 PyG 的 Data 对象,表示一个采样出来的小图。


13.1 batch.x

复制代码
batch.x

表示当前采样子图中所有节点的特征。

注意,它不只是中心节点的特征,还包括采样到的邻居节点特征。

因此它的形状可能是:

复制代码
batch.x.shape = [采样子图节点数, 特征维度]

例如:

复制代码
batch.x.shape = [843, 1433]

表示当前小图中一共有 843 个节点,每个节点有 1433 维特征。


13.2 batch.edge_index

复制代码
batch.edge_index

表示当前采样子图中的边结构。

它不是原始整张图的完整边结构,而是当前采样小图中的边结构。

它的形状仍然是:

复制代码
batch.edge_index.shape = [2, num_edges]

例如:

复制代码
batch.edge_index.shape = [2, 3265]

表示当前小图中有 3265 条边。


13.3 batch.y

复制代码
batch.y

表示当前采样子图中所有节点的标签。

注意,它也包含中心节点和邻居节点的标签。

但是在训练时,通常只对中心节点计算 loss,邻居节点只是用于提供消息传播的信息。


13.4 batch.batch_size

复制代码
batch.batch_size

表示当前 batch 中中心节点的数量。

如果设置:

复制代码
batch_size=128

那么多数情况下:

复制代码
batch.batch_size == 128

但是最后一个 batch 可能不足 128。


13.5 batch.n_id

复制代码
batch.n_id

表示当前子图中的节点在原始整张图中的真实编号。

因为 NeighborLoader 会对采样出来的小图重新编号。

例如,原图中的节点编号可能是:

复制代码
[0, 5, 9, 20, 33, 100]

但是在当前采样子图中,它们会被重新编号成:

复制代码
[0, 1, 2, 3, 4, 5]

batch.n_id 就用来记录当前子图节点编号和原图节点编号之间的对应关系。


14. 为什么 loss 只计算前 batch_size 个节点

使用 NeighborLoader 训练时,常见写法是:

复制代码
for batch in train_loader:
    batch = batch.to(device)

    out = model(batch.x, batch.edge_index)

    loss = F.cross_entropy(
        out[:batch.batch_size],
        batch.y[:batch.batch_size]
    )

这里最重要的是:

复制代码
out[:batch.batch_size]

因为 NeighborLoader 返回的 batch 中:

复制代码
前 batch.batch_size 个节点是本次真正要训练的中心节点;
后面的节点是被采样进来的邻居节点。

也就是说:

复制代码
batch.x 中的节点可以分成两部分:

前 batch.batch_size 个节点:中心节点,需要计算 loss;
后面的节点:邻居节点,只用于消息传播,不计算 loss。

所以训练时只对中心节点的预测结果计算损失。

如果对整个 out 计算 loss,就会把邻居节点也当成训练目标,这通常不是我们想要的。


15. 完整训练代码示例

下面是一个使用 Cora 数据集训练 GraphSAGE 的简单例子:

复制代码
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import SAGEConv
from torch_geometric.loader import NeighborLoader


dataset = Planetoid(root='data/Cora', name='Cora')
data = dataset[0]


class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return x


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = GraphSAGE(
    in_channels=dataset.num_features,
    hidden_channels=64,
    out_channels=dataset.num_classes
).to(device)

data = data.to(device)

train_loader = NeighborLoader(
    data,
    num_neighbors=[10, 10],
    batch_size=128,
    input_nodes=data.train_mask,
    shuffle=True
)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)


for epoch in range(1, 101):
    model.train()
    total_loss = 0

    for batch in train_loader:
        batch = batch.to(device)

        optimizer.zero_grad()

        out = model(batch.x, batch.edge_index)

        loss = F.cross_entropy(
            out[:batch.batch_size],
            batch.y[:batch.batch_size]
        )

        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f'Epoch: {epoch:03d}, Loss: {total_loss:.4f}')

16. 训练过程理解

假设原图中有 1000 个训练节点,并且设置:

复制代码
batch_size=128
num_neighbors=[10, 10]

那么一个 epoch 的训练过程大致是:

复制代码
第一步:从 1000 个训练节点中取出 128 个中心节点。
第二步:为这 128 个中心节点采样一阶邻居和二阶邻居。
第三步:把中心节点和采样到的邻居节点组成一个小图。
第四步:把这个小图送入 GraphSAGE 模型,执行前向传播。
第五步:只对这 128 个中心节点的预测结果计算 loss。
第六步:反向传播,更新模型参数。
第七步:继续取下一批 128 个训练节点,重复上述过程。

直到所有训练节点都被遍历一遍,就完成了一个 epoch。


17. 与普通 DataLoader 的对比

普通图像分类中的 DataLoader:

复制代码
for images, labels in train_loader:
    out = model(images)
    loss = criterion(out, labels)

每次返回的是一批图片和对应标签。

GraphSAGE 中的 NeighborLoader:

复制代码
for batch in train_loader:
    out = model(batch.x, batch.edge_index)
    loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])

每次返回的是一批中心节点和它们采样得到的局部子图。

所以可以这样对比:

复制代码
普通 DataLoader:
一批样本 = 一批图片

NeighborLoader:
一批样本 = 一批中心节点 + 它们的局部邻居子图

18. 容易混淆的地方

18.1 batch_size 不是子图节点总数

复制代码
batch_size=128

表示每次选择 128 个中心节点,不表示采样子图中只有 128 个节点。

由于还要采样邻居,最终子图节点数通常会大于 128。


18.2 num_neighbors=[10, 10] 不是总共采样 20 个邻居

复制代码
num_neighbors=[10, 10]

表示两层邻居采样,不是总共采样 20 个邻居。

可以理解为:

复制代码
第一跳:每个中心节点最多采样 10 个邻居;
第二跳:每个一阶邻居最多再采样 10 个邻居。

18.3 batch.x 包含中心节点和邻居节点

batch.x 中不只有中心节点,也包含采样到的邻居节点。

模型需要这些邻居节点的信息来完成消息传播。


18.4 loss 只对中心节点计算

虽然 batch.x 中包含很多节点,但是训练目标是当前 batch 的中心节点。

因此 loss 通常写成:

复制代码
loss = F.cross_entropy(
    out[:batch.batch_size],
    batch.y[:batch.batch_size]
)

19. 最终总结

GraphSAGE 的核心思想是:

复制代码
从邻居中采样一部分节点,聚合这些邻居的信息,再和节点自身信息融合,得到新的节点表示。

它相比 GCN 更适合大图,因为它不必每次都聚合所有邻居,而是通过邻居采样控制计算规模。

GraphSAGE 中自身信息和邻居信息通常使用拼接,而不是直接相加。原因是:

复制代码
拼接可以保留自身信息和邻居信息的区别,
然后让后面的线性层学习如何融合;
而相加会提前把两类信息混在一起,降低表达能力。

NeighborLoader 是 GraphSAGE 训练中非常重要的工具。它的作用是:

复制代码
选择中心节点 + 采样邻居 + 构造小图 + 分批训练

最核心的训练代码是:

复制代码
for batch in train_loader:
    batch = batch.to(device)

    out = model(batch.x, batch.edge_index)

    loss = F.cross_entropy(
        out[:batch.batch_size],
        batch.y[:batch.batch_size]
    )

    loss.backward()
    optimizer.step()

其中:

复制代码
batch.x

表示采样子图中的节点特征。

复制代码
batch.edge_index

表示采样子图中的边结构。

复制代码
batch.batch_size

表示中心节点数量。

复制代码
out[:batch.batch_size]

表示中心节点的预测结果。

复制代码
batch.y[:batch.batch_size]

表示中心节点的真实标签。

最后可以把 GraphSAGE 和 NeighborLoader 的关系记成一句话:

复制代码
GraphSAGE 负责学习"如何聚合邻居信息",NeighborLoader 负责训练时"采样哪些邻居来聚合"。
相关推荐
AI科技星1 小时前
全域数学版木牛流马(融合仿生兽+古制复原终版优化方案)【乖乖数学】
人工智能·算法·数学建模·数据挖掘·量子计算
richard_yuu1 小时前
数据结构精讲:图的最短路径与关键路径
数据结构·算法
佳xuan1 小时前
神经网络解析
人工智能·深度学习·神经网络
智者知已应修善业1 小时前
【51单片机一个按键切合初始流水灯按一下对半闪烁按一下显示时间】2023-10-16
c++·经验分享·笔记·算法·51单片机
沪漂阿龙1 小时前
面试题:激活函数是什么?为什么必须非线性,Sigmoid、ReLU、Softmax 怎么选,一文讲透深度学习高频考点
人工智能·深度学习
沪漂阿龙1 小时前
AI大模型面试题:模型求解和优化全解析——梯度下降、BGD、SGD、MBGD、学习率、Batch Size、损失函数、优化器一文讲透
人工智能·学习·机器学习
晚风叙码1 小时前
堆排序建堆策略对比:向上调整与向下调整的时间复杂度分析
算法
lsjweiyi2 小时前
WSL2 + ROCm + PyTorch 深度学习环境配置全记录
人工智能·pytorch·深度学习
洛水水2 小时前
【力扣100题】28. 翻转二叉树
算法·leetcode