1. GAT 是什么
GAT 全称是 Graph Attention Network,图注意力网络。它是一种图神经网络模型,和 GCN、GraphSAGE 一样,GAT 的核心目标也是通过聚合邻居节点的信息来更新当前节点的表示。
但是 GAT 与普通 GNN 最大的区别在于:它不会把所有邻居都同等看待,而是会学习每个邻居对当前节点的重要性,然后按照重要程度进行加权聚合。
可以用一句话理解 GAT:
GAT 是一种在图结构上引入注意力机制的 GNN。图结构决定节点能看谁,注意力机制决定节点更重视谁。
2. 为什么需要 GAT
在图神经网络中,一个节点通常需要通过邻居节点的信息来更新自己的表示。
比如节点 AAA 有三个邻居:
B、C、D
如果使用普通平均聚合,那么节点 AAA 更新时可能会近似看作:
A 的新特征 = B、C、D 的信息平均结果
这种方式的问题是:不同邻居对当前节点的重要性可能并不相同。
例如在论文引用网络中,一个论文节点可能连接很多论文节点,但这些论文与当前论文的主题相关程度不同。有些邻居论文和当前论文高度相关,有些只是弱相关。如果简单平均所有邻居信息,就可能把重要邻居和不重要邻居混在一起,影响最终节点表示。
因此,GAT 的核心思想就是:
节点在聚合邻居时,不应该一视同仁,而应该让模型自动学习哪些邻居更重要。
3. GAT 的核心思想
GAT 引入了注意力机制。所谓注意力机制,可以简单理解为:
模型自动学习"应该重点关注谁"。
放到图结构中,就是:
一个节点在更新自己时,会先判断每个邻居的重要程度,然后给不同邻居分配不同权重,最后按照这些权重加权聚合邻居信息。
例如节点 AAA 有三个邻居:
B、C、D
GAT 可能会学习出如下权重:
B 的重要性:0.7
C 的重要性:0.2
D 的重要性:0.1
那么节点 AAA 更新时就是:
A 的新特征 = 0.7 × B 的信息 + 0.2 × C 的信息 + 0.1 × D 的信息
这里的 0.7、0.2、0.1 就是注意力权重。
4. 什么是权重
权重可以理解为"占比"或者"重要程度"。
比如:
B 的权重是 0.7
C 的权重是 0.2
D 的权重是 0.1
这表示节点 AAA 更新自己时:
70% 参考 B
20% 参考 C
10% 参考 D
所以 GAT 和 GCN 的核心区别在于:
GCN:邻居权重主要由图结构和节点度数决定
GAT:邻居权重由模型根据节点特征动态学习出来
5. GAT 的整体计算过程
GAT 的一次节点更新可以分成四个步骤:
第一步:对节点特征做线性变换
第二步:计算当前节点和邻居节点之间的重要性分数
第三步:使用 softmax 把分数转成注意力权重
第四步:按照注意力权重加权聚合邻居信息
这四步可以概括为:
先变换特征,再计算分数,再归一化权重,最后加权聚合。
6. 第一步:线性变换
假设一个节点原来的特征是:
[1, 2, 3]
GAT 不会直接使用这个原始特征,而是先用一个可学习的矩阵 WWW 对它进行变换:
WhiWh_iWhi
这里的 hih_ihi 表示节点 iii 的原始特征,WWW 是模型需要学习的参数。
可以把这一步理解为:
把节点原始特征映射到一个新的特征空间,使其更适合后续计算。
比如原始特征:
[1, 2, 3]
经过线性变换后可能变成:
[0.5, 1.2]
这一步和神经网络中的全连接层类似,本质上是对特征进行重新组合和变换。
7. 第二步:计算邻居的重要性分数
在 GAT 中,当前节点会分别和每个邻居计算一个重要性分数。
假设节点 AAA 的邻居是:
B、C、D
那么 GAT 会计算:
A 和 B 的重要性分数
A 和 C 的重要性分数
A 和 D 的重要性分数
一般可以写成:
eij=a(Whi,Whj)e_{ij} = a(Wh_i, Wh_j)eij=a(Whi,Whj)
其中:
i 表示当前节点
j 表示邻居节点
e_ij 表示邻居 j 对当前节点 i 的原始重要性分数
例如对于节点 AAA 来说:
e_AB = A 和 B 的分数
e_AC = A 和 C 的分数
e_AD = A 和 D 的分数
假设模型计算得到:
e_AB = 3.0
e_AC = 1.0
e_AD = 0.5
这说明从原始打分来看,B 对 A 最重要,C 次之,D 最弱。
8. GAT 中的注意力分数是如何计算的
GAT 原论文中常见的注意力计算方式是:
eij=LeakyReLU(aT[Whi∥Whj])e_{ij} = \text{LeakyReLU}(a^T[Wh_i \Vert Wh_j])eij=LeakyReLU(aT[Whi∥Whj])
这个公式看起来比较复杂,可以拆开理解。
首先,WhiWh_iWhi 表示当前节点 iii 经过线性变换后的特征。
其次,WhjWh_jWhj 表示邻居节点 jjj 经过线性变换后的特征。
然后,[Whi∥Whj][Wh_i \Vert Wh_j][Whi∥Whj] 表示把两个特征拼接起来。
例如:
Wh_A = [0.5, 1.2]
Wh_B = [1.1, 0.3]
拼接后得到:
[Wh_A || Wh_B] = [0.5, 1.2, 1.1, 0.3]
这样做的目的是把当前节点和邻居节点的信息放在一起,让模型判断它们之间的关系。
公式中的 aTa^TaT 可以理解为一个可学习的打分向量,它的作用是根据拼接后的特征计算一个标量分数。
LeakyReLU 是一个激活函数,它的作用是增加非线性表达能力。对于初学阶段,不需要过度纠结它的具体形式,只需要知道它不会改变 GAT 的核心逻辑。
所以这个公式整体可以理解为:
把当前节点和邻居节点的特征拼接起来,然后用一组可学习参数判断这个邻居对当前节点有多重要。
9. 第三步:softmax 将分数转成权重
前面得到的 eije_{ij}eij 只是原始分数,还不能直接作为权重使用。因为这些分数不一定在 0 到 1 之间,也不一定加起来等于 1。
所以 GAT 会使用 softmax 把这些分数转成注意力权重:
αij=exp(eij)∑k∈N(i)exp(eik)\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in N(i)} \exp(e_{ik})}αij=∑k∈N(i)exp(eik)exp(eij)
这个公式的作用是:
把一组原始分数转换成一组比例,并且这些比例加起来等于 1。
例如节点 AAA 对三个邻居的原始分数是:
B:3.0
C:1.0
D:0.5
经过 softmax 后可能变成:
B:0.79
C:0.14
D:0.07
这时就可以解释为:
A 更新自己时,79% 参考 B,14% 参考 C,7% 参考 D
10. softmax 是对谁做的
这一点非常重要。
GAT 不是对整张图所有边一起做 softmax,而是:
对每个节点自己的邻居集合内部做 softmax。
例如节点 AAA 的邻居是:
B、C、D
那么 softmax 只在下面这些分数之间进行:
A-B
A-C
A-D
如果节点 EEE 的邻居是:
F、G
那么节点 EEE 会单独对:
E-F
E-G
做 softmax。
所以,每个节点都有自己的一组注意力权重。
11. 第四步:加权聚合邻居信息
有了注意力权重以后,节点就可以根据权重聚合邻居信息。
假设节点 AAA 的注意力权重是:
B:0.79
C:0.14
D:0.07
那么节点 AAA 的新特征可以表示为:
hAnew=0.79WhB+0.14WhC+0.07WhDh_A^{new} = 0.79Wh_B + 0.14Wh_C + 0.07Wh_DhAnew=0.79WhB+0.14WhC+0.07WhD
也就是:
A 的新特征 = B 的信息 × 0.79 + C 的信息 × 0.14 + D 的信息 × 0.07
这就是 GAT 的聚合过程。
因此,GAT 的完整逻辑可以总结为:
先计算邻居重要性
再把重要性转成权重
最后按照权重加权聚合邻居特征
12. GAT 是否会使用自己的特征
通常会。
在实际使用 GAT 时,一般会给图加自环,也就是让每个节点自己也连接自己。
如果节点 AAA 原来有邻居:
B、C、D
加上自环后,节点 AAA 的聚合对象变成:
A、B、C、D
这样节点 AAA 更新时不仅会看邻居,也会保留自己的信息。
更新形式可以理解为:
A 的新特征 =
注意力权重 A × A 自己的信息
+ 注意力权重 B × B 的信息
+ 注意力权重 C × C 的信息
+ 注意力权重 D × D 的信息
这样做的目的是避免节点在聚合邻居时丢失自身原有特征。
13. 一个完整的小例子
假设有一个节点 0,它有两个邻居节点 1 和 2,并且加上自环。
所以节点 0 的聚合对象是:
0、1、2
经过线性变换后,它们的特征是:
Wh_0 = [1, 1]
Wh_1 = [2, 0]
Wh_2 = [0, 3]
GAT 会计算节点 0 对这三个节点的注意力分数:
节点 0 看自己:e_00 = 1.2
节点 0 看节点 1:e_01 = 2.5
节点 0 看节点 2:e_02 = 0.3
然后对这些分数做 softmax,得到:
α_00 = 0.20
α_01 = 0.65
α_02 = 0.15
这表示:
节点 0 更新时:
20% 保留自己的信息
65% 吸收节点 1 的信息
15% 吸收节点 2 的信息
最后更新:
h0new=0.20Wh0+0.65Wh1+0.15Wh2h_0^{new} = 0.20Wh_0 + 0.65Wh_1 + 0.15Wh_2h0new=0.20Wh0+0.65Wh1+0.15Wh2
代入具体数值:
0.20 × [1, 1] = [0.20, 0.20]
0.65 × [2, 0] = [1.30, 0.00]
0.15 × [0, 3] = [0.00, 0.45]
相加得到:
h_0_new = [1.50, 0.65]
这就是节点 0 更新后的新特征。
从这个例子可以看到,节点 1 的权重最大,所以节点 0 的新表示受节点 1 的影响最大。
14. GAT 和 GCN 的区别
GCN 和 GAT 都是通过聚合邻居信息来更新节点表示,但它们的聚合方式不同。
GCN 的聚合权重主要由图结构决定。它通过归一化邻接矩阵对邻居特征进行加权求和,权重通常和节点的度有关。也就是说,GCN 中某个邻居对当前节点的影响大小,主要由图的连接结构决定,而不是由节点特征动态决定。
GAT 则不同。GAT 会根据当前节点和邻居节点的特征计算注意力权重。也就是说,GAT 会动态判断:
这个邻居对当前节点是否重要?
这个邻居的信息应该占多大比例?
因此可以总结为:
GCN:结构驱动的固定权重聚合
GAT:特征驱动的自适应权重聚合
但是需要注意,GAT 并不是完全脱离图结构。GAT 仍然只在有边连接的节点之间计算注意力。
所以更准确地说:
图结构决定节点能接收谁的信息,注意力机制决定这些信息的重要程度。
15. GAT 和 GraphSAGE 的区别
GraphSAGE 的核心特点是邻居采样和聚合。
在大规模图中,一个节点可能有成百上千个邻居。如果每次都聚合所有邻居,计算量会很大。GraphSAGE 可以从邻居中采样一部分节点,然后使用 mean、pooling 或 LSTM 等方式聚合邻居信息。
例如一个节点有 1000 个邻居,GraphSAGE 可以只采样 10 个邻居进行聚合。
如果使用 mean 聚合,那么 GraphSAGE 可以理解为:
先采样邻居,再对采样到的邻居求平均
而 GAT 的核心不是采样,而是注意力权重。
GAT 更关心:
这些邻居里面,谁更重要?
因此三者可以这样对比:
GCN:
根据归一化邻接矩阵聚合邻居信息。
GraphSAGE:
采样一部分邻居,然后使用聚合函数聚合邻居信息。
GAT:
计算每个邻居的重要性权重,然后按照权重加权聚合邻居信息。
更形象地说:
GCN:大家按固定规则发言。
GraphSAGE:我只抽几个人来听。
GAT:我听允许听的人,但会重点听更重要的人。
16. 多头注意力
GAT 中还有一个重要概念:多头注意力。
所谓多头注意力,就是不只使用一套注意力机制,而是同时使用多套注意力机制。
例如节点 AAA 有三个邻居:
B、C、D
第一个注意力头可能学到:
B:0.7
C:0.2
D:0.1
第二个注意力头可能学到:
B:0.2
C:0.6
D:0.2
第三个注意力头可能学到:
B:0.3
C:0.3
D:0.4
为什么要使用多个头?
因为判断邻居是否重要可能有不同角度。一个注意力头可能更关注某种特征关系,另一个注意力头可能关注另一种特征关系。多个注意力头共同作用,可以增强模型表达能力,也可以提高训练稳定性。
可以简单理解为:
单头注意力是从一个角度判断邻居重要性,多头注意力是从多个角度判断邻居重要性。
17. 多头注意力的输出维度
在 PyTorch Geometric 中,常用 GATConv 实现 GAT。
例如:
GATConv(in_channels, hidden_channels, heads=8)
这里 heads=8 表示有 8 个注意力头。
如果:
hidden_channels = 8
heads = 8
并且默认 concat=True,那么输出维度不是 8,而是:
8 × 8 = 64
原因是每个注意力头都会输出一个 8 维特征,8 个头拼接起来就是 64 维。
所以第一层 GAT 后面经常会写:
self.conv2 = GATConv(hidden_channels * heads, out_channels)
如果设置:
concat=False
那么多个注意力头的结果会被平均,输出维度仍然是:
hidden_channels
因此可以总结为:
concat=True:
多个头的结果拼接,输出维度 = hidden_channels × heads
concat=False:
多个头的结果平均,输出维度 = hidden_channels
18. PyTorch Geometric 中的 GATConv
一个简单的两层 GAT 网络可以写成:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv
class GAT(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = GATConv(
in_channels=in_channels,
out_channels=hidden_channels,
heads=8,
dropout=0.6
)
self.conv2 = GATConv(
in_channels=hidden_channels * 8,
out_channels=out_channels,
heads=1,
concat=False,
dropout=0.6
)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = F.elu(x)
x = F.dropout(x, p=0.6, training=self.training)
x = self.conv2(x, edge_index)
return x
其中:
x = self.conv1(x, edge_index)
内部大致完成了以下操作:
1. 对所有节点特征 x 做线性变换
2. 根据 edge_index 找到边连接的节点
3. 对每条边计算注意力分数
4. 对每个节点的邻居分数做 softmax
5. 按照注意力权重聚合邻居特征
6. 输出每个节点的新特征
所以 GATConv 已经把注意力计算、softmax、邻居聚合等过程封装好了。
19. edge_index 在 GAT 中的作用
在 PyTorch Geometric 中,GATConv 的输入通常是:
x, edge_index
其中,x 是节点特征矩阵,每一行表示一个节点的特征。
edge_index 是边结构,通常形状为:
[2, num_edges]
例如:
edge_index = torch.tensor([
[0, 1, 2],
[1, 2, 0]
])
表示三条边:
0 -> 1
1 -> 2
2 -> 0
在 GAT 中,edge_index 决定了每个节点可以从哪些节点接收信息。
也就是说:
edge_index 决定消息从哪里传到哪里,attention 决定这些消息的重要程度。
GAT 并不是让每个节点关注全图所有节点,而是通常只在图结构规定的邻居范围内计算注意力。
20. GAT 的核心总结
GAT 的核心流程可以总结为:
对于每个节点 i:
1. 找到它的邻居节点
2. 对节点特征进行线性变换
3. 计算当前节点和每个邻居之间的重要性分数
4. 对这些分数做 softmax,得到注意力权重
5. 按照注意力权重加权聚合邻居特征
6. 得到节点 i 的新表示
一句话总结:
GAT 是一种在图神经网络中引入注意力机制的方法。它通过学习不同邻居的重要性权重,使节点在聚合邻居信息时能够更加灵活和有选择性。
21. GCN、GraphSAGE、GAT 对比总结
| 模型 | 核心思想 | 聚合方式 | 重点 |
|---|---|---|---|
| GCN | 使用归一化邻接矩阵传播信息 | 根据图结构固定加权聚合 | 图卷积基础思想 |
| GraphSAGE | 采样邻居并聚合 | 采样后使用 mean、pooling、LSTM 等聚合 | 大规模图、归纳学习 |
| GAT | 学习邻居注意力权重 | 根据注意力权重加权聚合 | 区分不同邻居的重要性 |
可以进一步理解为:
GCN:
我根据邻接矩阵和节点度数,把邻居信息按固定规则加起来。
GraphSAGE:
我先采样一部分邻居,再用聚合函数把邻居信息合起来。
GAT:
我先判断每个邻居对我有多重要,再按照重要性加权聚合邻居信息。
最终记忆版:
GCN 是固定权重聚合,GraphSAGE 是采样邻居后聚合,GAT 是学习邻居权重后聚合。