1. GraphSAGE 是什么
GraphSAGE 是一种典型的图神经网络模型,它的全称可以理解为 Graph Sample and Aggregate,也就是"图上的采样与聚合"。
它的核心思想是:
对每个节点,从它的邻居中采样一部分节点,然后聚合这些邻居的信息,再和节点自身的信息进行融合,最终得到新的节点表示。
如果用一句话概括 GraphSAGE:
GraphSAGE = 邻居采样 + 邻居聚合 + 自身信息融合
它和 GCN 一样,都属于消息传递类图神经网络。节点会通过边从邻居那里接收信息,然后更新自己的表示。
但是 GraphSAGE 和普通 GCN 的一个重要区别是:
GCN 通常聚合所有邻居;
GraphSAGE 通常从邻居中采样固定数量的节点进行聚合。
这使得 GraphSAGE 更适合大规模图数据,因为它可以控制每次训练时涉及的邻居数量,避免计算量和显存占用随着图规模急剧增加。
2. 为什么需要 GraphSAGE
在 GCN 中,节点更新通常可以写成类似下面的形式:
H' = A_hat H W
其中:
H
表示节点特征矩阵;
A_hat
表示经过处理后的邻接矩阵;
W
表示可学习权重矩阵。
这种方式本质上是通过邻接矩阵把邻居信息聚合到当前节点上。
对于 Cora 这种小规模图来说,可以一次性把整张图送入模型:
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
但是如果图很大,例如有几十万甚至几百万个节点,那么每次都对整张图做消息传播就会非常消耗显存和计算资源。
GraphSAGE 的提出就是为了解决这个问题之一。
它不要求每次都使用所有邻居,而是从邻居中采样固定数量的节点。例如:
每个节点只采样 10 个一阶邻居;
每个一阶邻居再采样 10 个二阶邻居。
这样模型每次看到的是一个局部子图,而不是完整大图。
3. GraphSAGE 的基本流程
对于一个节点 v,GraphSAGE 的一层传播通常包含三个步骤。
3.1 第一步:采样邻居
假设节点 v 的邻居集合是:
N(v) = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
GraphSAGE 不一定使用所有邻居,而是从中采样一部分,例如:
S(v) = {2, 5, 8}
也就是说,GraphSAGE 会从节点 v 的邻居中选出一部分节点参与本次计算。
这样做的好处是可以控制计算规模。
3.2 第二步:聚合邻居特征
假设采样到的邻居节点是 2、5、8,它们的特征分别是:
h_2, h_5, h_8
GraphSAGE 会使用一个聚合函数把这些邻居特征合成为一个邻居表示:
h_N(v) = AGGREGATE({h_2, h_5, h_8})
最简单的聚合方式是求平均:
h_N(v) = mean(h_2, h_5, h_8)
可以理解为:
把多个邻居节点的信息压缩成一个邻居向量。
3.3 第三步:融合自身信息和邻居信息
GraphSAGE 不只看邻居,也会保留节点自身的信息。
所以它会把节点自己的表示和邻居聚合表示结合起来:
h_v' = σ(W · CONCAT(h_v, h_N(v)))
其中:
h_v
表示节点 v 自身的表示;
h_N(v)
表示邻居聚合后的表示;
CONCAT
表示拼接;
W
表示可学习权重矩阵;
σ
表示激活函数。
这一步的本质是:
把自身信息和邻居信息放在一起,再让模型学习如何融合它们。
4. GraphSAGE 中为什么是拼接,而不是相加
GraphSAGE 中常见的更新方式是:
h_v' = σ(W · CONCAT(h_v, h_N(v)))
这里有一个很重要的问题:为什么要把自身表示和邻居表示拼接起来,而不是直接相加?
4.1 相加会提前混合信息
假设节点自身表示是:
h_v = [1, 2, 3]
邻居聚合表示是:
h_N = [4, 5, 6]
如果直接相加:
h_v + h_N = [5, 7, 9]
相加之后,模型就很难分清楚哪些信息来自节点自己,哪些信息来自邻居。
例如第一个维度的 5 可能来自:
1 + 4
也可能来自其他组合。
也就是说,相加会把自身信息和邻居信息提前混在一起。
4.2 拼接可以保留信息来源
如果使用拼接:
CONCAT(h_v, h_N) = [1, 2, 3, 4, 5, 6]
那么前半部分仍然表示节点自身信息,后半部分仍然表示邻居信息。
这样模型可以清楚地区分:
哪些信息来自自己;
哪些信息来自邻居。
之后再通过可学习矩阵 W 进行融合,模型就可以自己学习:
应该更重视自身信息,还是更重视邻居信息。
例如,模型可以学到:
0.8 × 自身信息 + 0.2 × 邻居信息
也可以学到:
0.3 × 自身信息 + 0.7 × 邻居信息
甚至可以学习更复杂的跨维度组合。
4.3 拼接后维度会变大
假设:
h_v 的维度是 128
h_N 的维度也是 128
那么拼接后维度会变成:
256
然后线性层 W 会把它重新映射到目标维度,比如重新变成 128 维。
所以可以理解为:
拼接先保留更多信息,再由线性层学习如何压缩和融合。
4.4 一句话理解拼接
GraphSAGE 使用拼接而不是相加,是因为:
拼接可以保留"自身信息"和"邻居信息"的区别,然后让后面的可学习线性层决定如何融合;
相加则会提前把两类信息混在一起,降低模型表达能力。
可以记成:
拼接 = 先保留信息,再学习融合
相加 = 先固定融合,再继续学习
5. GraphSAGE 的几种聚合方式
GraphSAGE 中的 AGGREGATE 可以有不同实现方式。
5.1 Mean Aggregator
Mean Aggregator 是最容易理解的方式。
它直接对邻居节点的特征求平均:
h_N(v) = mean({h_u, u ∈ N(v)})
可以理解为:
用邻居节点的平均特征代表邻居整体信息。
5.2 LSTM Aggregator
LSTM Aggregator 会把邻居特征输入到 LSTM 中,再得到聚合后的邻居表示。
不过图中的邻居本身没有天然顺序,而 LSTM 对输入顺序敏感,所以这个方法理解起来稍微别扭,实际学习时不需要优先掌握。
5.3 Pooling Aggregator
Pooling Aggregator 会先对每个邻居特征进行一次非线性变换,然后再进行池化操作,例如 max pooling。
大致形式是:
h_u' = σ(W h_u + b)
h_N(v) = max({h_u'})
可以理解为:
先提取每个邻居的特征,再从这些邻居特征中选出最显著的信息。
6. PyG 中的 GraphSAGE 模型代码
在 PyTorch Geometric 中,可以使用 SAGEConv 实现 GraphSAGE。
一个简单的两层 GraphSAGE 模型如下:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
class GraphSAGE(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = SAGEConv(in_channels, hidden_channels)
self.conv2 = SAGEConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
return x
其中:
SAGEConv(in_channels, hidden_channels)
表示第一层 GraphSAGE,把输入特征维度变成隐藏层维度。
SAGEConv(hidden_channels, out_channels)
表示第二层 GraphSAGE,把隐藏层维度变成输出维度。
如果是节点分类任务,out_channels 通常等于类别数。
7. GraphSAGE 的输入
GraphSAGE 的输入和 GCN 类似,主要包括:
x
edge_index
7.1 x:节点特征矩阵
x
表示节点特征矩阵。
假设图中有 2708 个节点,每个节点有 1433 维特征,那么:
x.shape = [2708, 1433]
每一行表示一个节点的特征。
7.2 edge_index:边结构
edge_index
表示图中的边结构。
它的形状通常是:
edge_index.shape = [2, num_edges]
例如:
edge_index = [
[0, 0, 1, 2],
[1, 2, 2, 3]
]
表示边:
0 -> 1
0 -> 2
1 -> 2
2 -> 3
从数学理解上,可以把它看成邻接矩阵的稀疏表示;但在工程实现中,PyG 通常不会真的构造完整的 N × N 邻接矩阵,而是基于 edge_index 做消息传递。
8. 小图上的普通训练方式
如果图比较小,例如 Cora,可以不使用邻居采样,而是直接整图训练:
out = model(data.x, data.edge_index)
loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
这表示:
模型对整张图所有节点都输出预测结果;
但计算 loss 时,只使用训练节点。
其中:
out[data.train_mask]
表示训练节点的预测结果。
data.y[data.train_mask]
表示训练节点的真实标签。
9. NeighborLoader 是什么
NeighborLoader 不是 GraphSAGE 的模型层,而是 PyTorch Geometric 中用于图数据采样的加载器。
它的作用是:
在训练时,每次从训练节点中取出一批中心节点;
然后围绕这些中心节点采样一定数量的邻居节点;
构造出一个局部子图;
再把这个子图送入 GraphSAGE 模型进行训练。
也就是说,NeighborLoader 每次返回的不是一批孤立节点,而是一个围绕中心节点采样出来的小图。
它可以理解为图神经网络中的特殊 DataLoader。
普通图像分类任务中的 DataLoader 每次返回一批图片,而 NeighborLoader 每次返回一批中心节点及其邻居组成的子图。
10. NeighborLoader 代码示例
常见写法如下:
from torch_geometric.loader import NeighborLoader
train_loader = NeighborLoader(
data,
num_neighbors=[10, 10],
batch_size=128,
input_nodes=data.train_mask,
shuffle=True
)
这段代码的意思是:
从 data.train_mask 指定的训练节点中,
每次取 128 个节点作为中心节点;
由于模型是两层 GraphSAGE,所以采样两跳邻居;
每一跳最多采样 10 个邻居;
最后把中心节点和采样到的邻居节点组成一个局部子图,用于模型训练。
11. NeighborLoader 参数解释
11.1 data
data
表示原始整张图。
其中通常包含:
data.x
节点特征矩阵。
data.edge_index
边结构。
data.y
节点标签。
data.train_mask
训练节点掩码。
11.2 input_nodes=data.train_mask
input_nodes=data.train_mask
表示只从训练节点中选择中心节点。
data.train_mask 通常是一个布尔向量,例如:
[True, True, False, False, True, ...]
其中 True 表示该节点属于训练集,False 表示该节点不属于训练集。
所以 NeighborLoader 每次会从这些值为 True 的节点中选出一批节点作为当前 batch 的中心节点。
11.3 batch_size=128
batch_size=128
表示每次选取 128 个训练节点作为中心节点。
需要注意:
这里的 batch_size 不是最终子图中的节点数量,而是中心节点数量。
因为每个中心节点还会采样它的邻居,所以最终返回的子图节点数通常会大于 128。
例如:
batch_size=128
num_neighbors=[10, 10]
最终采样出来的小图可能包含几百个甚至上千个节点。
11.4 num_neighbors=[10, 10]
num_neighbors=[10, 10]
表示进行两层邻居采样,通常对应两层 GraphSAGE 模型。
可以这样理解:
第一层:采样中心节点的一阶邻居,每个中心节点最多采样 10 个邻居;
第二层:再为这些一阶邻居采样它们的邻居,也就是中心节点的二阶邻居,每个节点最多采样 10 个邻居。
所以对于一个中心节点来说,理论上最多会涉及:
1 + 10 + 10 × 10 = 111 个节点
其中:
1 表示中心节点自己;
10 表示一阶邻居;
10 × 10 表示二阶邻居。
实际采样节点数通常会小于这个数,因为不同节点之间可能有重复邻居。
12. 为什么两层 GraphSAGE 需要采样两跳邻居
如果模型是两层 GraphSAGE:
self.conv1 = SAGEConv(in_channels, hidden_channels)
self.conv2 = SAGEConv(hidden_channels, out_channels)
那么中心节点最终会进行两次消息传递。
第一层让节点聚合一阶邻居的信息。
第二层让中心节点继续聚合一阶邻居更新后的表示,而这些一阶邻居的表示中已经包含了二阶邻居的信息。
因此,两层 GraphSAGE 最终可以让中心节点获得二阶邻居的信息。
所以:
一层 GraphSAGE 通常对应 num_neighbors=[10]
两层 GraphSAGE 通常对应 num_neighbors=[10, 10]
三层 GraphSAGE 通常对应 num_neighbors=[10, 10, 10]
13. NeighborLoader 每次返回什么
当我们写:
for batch in train_loader:
print(batch)
break
NeighborLoader 每次会返回一个 batch。
这个 batch 本质上也是一个 PyG 的 Data 对象,表示一个采样出来的小图。
13.1 batch.x
batch.x
表示当前采样子图中所有节点的特征。
注意,它不只是中心节点的特征,还包括采样到的邻居节点特征。
因此它的形状可能是:
batch.x.shape = [采样子图节点数, 特征维度]
例如:
batch.x.shape = [843, 1433]
表示当前小图中一共有 843 个节点,每个节点有 1433 维特征。
13.2 batch.edge_index
batch.edge_index
表示当前采样子图中的边结构。
它不是原始整张图的完整边结构,而是当前采样小图中的边结构。
它的形状仍然是:
batch.edge_index.shape = [2, num_edges]
例如:
batch.edge_index.shape = [2, 3265]
表示当前小图中有 3265 条边。
13.3 batch.y
batch.y
表示当前采样子图中所有节点的标签。
注意,它也包含中心节点和邻居节点的标签。
但是在训练时,通常只对中心节点计算 loss,邻居节点只是用于提供消息传播的信息。
13.4 batch.batch_size
batch.batch_size
表示当前 batch 中中心节点的数量。
如果设置:
batch_size=128
那么多数情况下:
batch.batch_size == 128
但是最后一个 batch 可能不足 128。
13.5 batch.n_id
batch.n_id
表示当前子图中的节点在原始整张图中的真实编号。
因为 NeighborLoader 会对采样出来的小图重新编号。
例如,原图中的节点编号可能是:
[0, 5, 9, 20, 33, 100]
但是在当前采样子图中,它们会被重新编号成:
[0, 1, 2, 3, 4, 5]
batch.n_id 就用来记录当前子图节点编号和原图节点编号之间的对应关系。
14. 为什么 loss 只计算前 batch_size 个节点
使用 NeighborLoader 训练时,常见写法是:
for batch in train_loader:
batch = batch.to(device)
out = model(batch.x, batch.edge_index)
loss = F.cross_entropy(
out[:batch.batch_size],
batch.y[:batch.batch_size]
)
这里最重要的是:
out[:batch.batch_size]
因为 NeighborLoader 返回的 batch 中:
前 batch.batch_size 个节点是本次真正要训练的中心节点;
后面的节点是被采样进来的邻居节点。
也就是说:
batch.x 中的节点可以分成两部分:
前 batch.batch_size 个节点:中心节点,需要计算 loss;
后面的节点:邻居节点,只用于消息传播,不计算 loss。
所以训练时只对中心节点的预测结果计算损失。
如果对整个 out 计算 loss,就会把邻居节点也当成训练目标,这通常不是我们想要的。
15. 完整训练代码示例
下面是一个使用 Cora 数据集训练 GraphSAGE 的简单例子:
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import SAGEConv
from torch_geometric.loader import NeighborLoader
dataset = Planetoid(root='data/Cora', name='Cora')
data = dataset[0]
class GraphSAGE(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = SAGEConv(in_channels, hidden_channels)
self.conv2 = SAGEConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = F.relu(x)
x = self.conv2(x, edge_index)
return x
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GraphSAGE(
in_channels=dataset.num_features,
hidden_channels=64,
out_channels=dataset.num_classes
).to(device)
data = data.to(device)
train_loader = NeighborLoader(
data,
num_neighbors=[10, 10],
batch_size=128,
input_nodes=data.train_mask,
shuffle=True
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
for epoch in range(1, 101):
model.train()
total_loss = 0
for batch in train_loader:
batch = batch.to(device)
optimizer.zero_grad()
out = model(batch.x, batch.edge_index)
loss = F.cross_entropy(
out[:batch.batch_size],
batch.y[:batch.batch_size]
)
loss.backward()
optimizer.step()
total_loss += loss.item()
print(f'Epoch: {epoch:03d}, Loss: {total_loss:.4f}')
16. 训练过程理解
假设原图中有 1000 个训练节点,并且设置:
batch_size=128
num_neighbors=[10, 10]
那么一个 epoch 的训练过程大致是:
第一步:从 1000 个训练节点中取出 128 个中心节点。
第二步:为这 128 个中心节点采样一阶邻居和二阶邻居。
第三步:把中心节点和采样到的邻居节点组成一个小图。
第四步:把这个小图送入 GraphSAGE 模型,执行前向传播。
第五步:只对这 128 个中心节点的预测结果计算 loss。
第六步:反向传播,更新模型参数。
第七步:继续取下一批 128 个训练节点,重复上述过程。
直到所有训练节点都被遍历一遍,就完成了一个 epoch。
17. 与普通 DataLoader 的对比
普通图像分类中的 DataLoader:
for images, labels in train_loader:
out = model(images)
loss = criterion(out, labels)
每次返回的是一批图片和对应标签。
GraphSAGE 中的 NeighborLoader:
for batch in train_loader:
out = model(batch.x, batch.edge_index)
loss = criterion(out[:batch.batch_size], batch.y[:batch.batch_size])
每次返回的是一批中心节点和它们采样得到的局部子图。
所以可以这样对比:
普通 DataLoader:
一批样本 = 一批图片
NeighborLoader:
一批样本 = 一批中心节点 + 它们的局部邻居子图
18. 容易混淆的地方
18.1 batch_size 不是子图节点总数
batch_size=128
表示每次选择 128 个中心节点,不表示采样子图中只有 128 个节点。
由于还要采样邻居,最终子图节点数通常会大于 128。
18.2 num_neighbors=[10, 10] 不是总共采样 20 个邻居
num_neighbors=[10, 10]
表示两层邻居采样,不是总共采样 20 个邻居。
可以理解为:
第一跳:每个中心节点最多采样 10 个邻居;
第二跳:每个一阶邻居最多再采样 10 个邻居。
18.3 batch.x 包含中心节点和邻居节点
batch.x 中不只有中心节点,也包含采样到的邻居节点。
模型需要这些邻居节点的信息来完成消息传播。
18.4 loss 只对中心节点计算
虽然 batch.x 中包含很多节点,但是训练目标是当前 batch 的中心节点。
因此 loss 通常写成:
loss = F.cross_entropy(
out[:batch.batch_size],
batch.y[:batch.batch_size]
)
19. 最终总结
GraphSAGE 的核心思想是:
从邻居中采样一部分节点,聚合这些邻居的信息,再和节点自身信息融合,得到新的节点表示。
它相比 GCN 更适合大图,因为它不必每次都聚合所有邻居,而是通过邻居采样控制计算规模。
GraphSAGE 中自身信息和邻居信息通常使用拼接,而不是直接相加。原因是:
拼接可以保留自身信息和邻居信息的区别,
然后让后面的线性层学习如何融合;
而相加会提前把两类信息混在一起,降低表达能力。
NeighborLoader 是 GraphSAGE 训练中非常重要的工具。它的作用是:
选择中心节点 + 采样邻居 + 构造小图 + 分批训练
最核心的训练代码是:
for batch in train_loader:
batch = batch.to(device)
out = model(batch.x, batch.edge_index)
loss = F.cross_entropy(
out[:batch.batch_size],
batch.y[:batch.batch_size]
)
loss.backward()
optimizer.step()
其中:
batch.x
表示采样子图中的节点特征。
batch.edge_index
表示采样子图中的边结构。
batch.batch_size
表示中心节点数量。
out[:batch.batch_size]
表示中心节点的预测结果。
batch.y[:batch.batch_size]
表示中心节点的真实标签。
最后可以把 GraphSAGE 和 NeighborLoader 的关系记成一句话:
GraphSAGE 负责学习"如何聚合邻居信息",NeighborLoader 负责训练时"采样哪些邻居来聚合"。