1. edge_index 与邻接矩阵的关系
在 PyTorch Geometric 中,图结构通常使用 edge_index 表示。它的形状不是 2 × N,而是:
edge_index.shape = [2, E]
其中:
N:节点数量
E:边的数量
也就是说,edge_index 的每一列表示一条边。第一行表示源节点,第二行表示目标节点。
例如:
edge_index = torch.tensor([
[0, 1, 1, 2],
[1, 0, 2, 1]
])
表示下面几条边:
0 -> 1
1 -> 0
1 -> 2
2 -> 1
如果把它转成邻接矩阵,可以理解为:
A =
[
[0, 1, 0],
[1, 0, 1],
[0, 1, 0]
]
从数学推导角度看,GNN 经常使用邻接矩阵 A 表示图结构,例如:
H^(l+1) = σ(A_norm H^(l) W)
这里的 A_norm 可以理解为归一化后的邻接矩阵,它决定每个节点从哪些邻居那里聚合信息。
但是在 PyG 的工程实现中,通常不会真的构造完整的 N × N 邻接矩阵,而是直接使用 edge_index 这种稀疏边表示。
原因是:真实图通常非常稀疏。如果有很多节点,完整邻接矩阵会非常浪费内存。
例如有 10000 个节点,完整邻接矩阵大小是:
10000 × 10000 = 100000000
但是如果每个节点平均只有 10 条边,那么实际边数可能只有:
10000 × 10 = 100000
完整邻接矩阵里绝大多数位置都是 0,因此没有必要全部保存。
所以可以这样理解:
数学公式中的 A ≈ 代码实现中的 edge_index
矩阵乘法 A @ X ≈ 沿着 edge_index 传消息并聚合
一句话总结:
edge_index 是邻接矩阵的稀疏表示;数学上可以把它理解成邻接矩阵 A,但工程实现中框架通常直接利用 edge_index 完成消息传播,不会显式构造完整的 N × N 矩阵。
2. 邻接矩阵方式和 edge_index 方式的区别
如果使用邻接矩阵,消息传播可以写得非常直观:
H = A @ X
其中:
A:邻接矩阵,表示节点之间的连接关系
X:节点特征矩阵
H:聚合邻居信息后的节点特征
这表示:
A 的第 i 行决定第 i 个节点聚合哪些节点的特征。
这种方式非常适合理解 GNN 的数学原理。
但是在大图中,完整邻接矩阵的存储和计算开销都很大。因此,实际框架更常用 edge_index。
edge_index 的思想是:
只保存真实存在的边,只沿着真实存在的边传递消息。
例如:
edge_index = torch.tensor([
[0, 1, 2],
[1, 2, 3]
])
表示:
0 -> 1
1 -> 2
2 -> 3
在消息传播时,可以理解为:
节点0的信息传给节点1
节点1的信息传给节点2
节点2的信息传给节点3
从底层实现上看,可能类似于:
src = edge_index[0]
dst = edge_index[1]
messages = X[src]
H = torch.zeros_like(X)
H.index_add_(0, dst, messages)
这段代码的含义是:
src:所有边的源节点
dst:所有边的目标节点
messages = X[src]:取出源节点要发送的信息
index_add_:把消息加到对应的目标节点上
虽然 edge_index 的底层实现比直接矩阵乘法更复杂,但 PyG 已经把这些细节封装好了。我们使用时只需要写:
x = conv(x, edge_index)
就可以完成消息传播。
所以学习时可以采用这样的理解方式:
学习数学原理时,用 A @ X 理解邻居聚合;
阅读 PyG 代码时,把 conv(x, edge_index) 理解成框架帮我们完成了等价的邻居聚合。
一句话总结:
邻接矩阵适合理解 GNN,edge_index 适合真正训练大规模 GNN。
3. W 参数矩阵的作用
在 GCN 公式中:
H^(l+1) = σ(A_norm H^(l) W)
其中,W 是可学习参数矩阵。它的作用是对节点特征进行线性变换,把每个节点的特征从输入维度映射到输出维度。
假设当前节点表示为:
H^(l).shape = [N, F_in]
其中:
N:节点数量
F_in:当前每个节点的特征维度
如果希望输出特征维度变成 F_out,那么参数矩阵 W 的形状为:
W.shape = [F_in, F_out]
因此:
H^(l) @ W 的形状 = [N, F_out]
也就是说,节点数量 N 不变,但每个节点的特征维度从 F_in 变成了 F_out。
例如:
H.shape = [4, 2]
W.shape = [2, 3]
H @ W.shape = [4, 3]
这表示:
4 个节点不变;
每个节点的特征从 2 维变成 3 维。
4. W 不是简单 reshape
这里要注意:W 改变特征维度,但它不是简单的 reshape。
reshape 只是改变张量的形状或排列方式,不引入可学习参数,也不会进行特征组合。例如:
x = x.reshape(4, 3)
这只是重新组织数据。
而 W 是通过矩阵乘法完成特征变换:
h = x @ W
它会把原来的特征进行加权组合,生成新的特征。
比如一个节点原来的特征是:
x_i = [a, b]
经过 W 后变成 3 维,新特征可能类似于:
h_i = [0.2a + 0.7b, -0.5a + 0.3b, 1.1a - 0.2b]
也就是说,新特征的每一维,都是旧特征各个维度的加权组合。这些权重不是人工指定的,而是在训练过程中通过反向传播自动学习出来的。
所以,更准确的说法是:
W 用来对节点特征做可学习的线性变换,把原始特征维度 F_in 映射到新的特征维度 F_out。
5. W 和 CNN 中卷积核数量的类比
可以把 GCN 中的 W 类比为 CNN 中控制输出通道数的部分。
在 CNN 中:
卷积核个数决定输出通道数。
在 GCN 中:
W 的输出维度决定节点新特征维度。
例如:
self.conv1 = GCNConv(num_features, hidden_channels)
如果:
num_features = 1433
hidden_channels = 16
那么第一层 GCN 可以理解为:
把每个节点从 1433 维原始特征,变成 16 维隐藏特征;
同时根据 edge_index 表示的图结构聚合邻居信息。
6. 本阶段总结
这一阶段主要理解两个问题:
1. edge_index 和邻接矩阵 A 的关系;
2. GCN 中 W 参数矩阵的作用。
可以总结为:
edge_index 是邻接矩阵的稀疏表示。数学上可以用 A @ X 理解邻居聚合,工程上 PyG 会根据 edge_index 自动完成消息传播。
以及:
W 不是简单 reshape,而是可学习的线性变换。它负责把每个节点的特征从输入维度映射到输出维度,并在训练过程中学习更有用的特征组合方式。
最终可以把 GCN 公式理解为:
H^(l+1) = σ(A_norm H^(l) W)
其中:
A_norm:负责根据图结构聚合邻居信息;
H^(l):当前层节点特征;
W:负责对节点特征做可学习的线性变换;
σ:激活函数,引入非线性表达能力。