GAT(Graph Attention Network)学习笔记

1. GAT 是什么

GAT 全称是 Graph Attention Network,图注意力网络。它是一种图神经网络模型,和 GCN、GraphSAGE 一样,GAT 的核心目标也是通过聚合邻居节点的信息来更新当前节点的表示。

但是 GAT 与普通 GNN 最大的区别在于:它不会把所有邻居都同等看待,而是会学习每个邻居对当前节点的重要性,然后按照重要程度进行加权聚合。

可以用一句话理解 GAT:

GAT 是一种在图结构上引入注意力机制的 GNN。图结构决定节点能看谁,注意力机制决定节点更重视谁。


2. 为什么需要 GAT

在图神经网络中,一个节点通常需要通过邻居节点的信息来更新自己的表示。

比如节点 AAA 有三个邻居:

复制代码
B、C、D

如果使用普通平均聚合,那么节点 AAA 更新时可能会近似看作:

复制代码
A 的新特征 = B、C、D 的信息平均结果

这种方式的问题是:不同邻居对当前节点的重要性可能并不相同。

例如在论文引用网络中,一个论文节点可能连接很多论文节点,但这些论文与当前论文的主题相关程度不同。有些邻居论文和当前论文高度相关,有些只是弱相关。如果简单平均所有邻居信息,就可能把重要邻居和不重要邻居混在一起,影响最终节点表示。

因此,GAT 的核心思想就是:

节点在聚合邻居时,不应该一视同仁,而应该让模型自动学习哪些邻居更重要。


3. GAT 的核心思想

GAT 引入了注意力机制。所谓注意力机制,可以简单理解为:

模型自动学习"应该重点关注谁"。

放到图结构中,就是:

一个节点在更新自己时,会先判断每个邻居的重要程度,然后给不同邻居分配不同权重,最后按照这些权重加权聚合邻居信息。

例如节点 AAA 有三个邻居:

复制代码
B、C、D

GAT 可能会学习出如下权重:

复制代码
B 的重要性:0.7
C 的重要性:0.2
D 的重要性:0.1

那么节点 AAA 更新时就是:

复制代码
A 的新特征 = 0.7 × B 的信息 + 0.2 × C 的信息 + 0.1 × D 的信息

这里的 0.7、0.2、0.1 就是注意力权重。


4. 什么是权重

权重可以理解为"占比"或者"重要程度"。

比如:

复制代码
B 的权重是 0.7
C 的权重是 0.2
D 的权重是 0.1

这表示节点 AAA 更新自己时:

复制代码
70% 参考 B
20% 参考 C
10% 参考 D

所以 GAT 和 GCN 的核心区别在于:

复制代码
GCN:邻居权重主要由图结构和节点度数决定
GAT:邻居权重由模型根据节点特征动态学习出来

5. GAT 的整体计算过程

GAT 的一次节点更新可以分成四个步骤:

复制代码
第一步:对节点特征做线性变换
第二步:计算当前节点和邻居节点之间的重要性分数
第三步:使用 softmax 把分数转成注意力权重
第四步:按照注意力权重加权聚合邻居信息

这四步可以概括为:

先变换特征,再计算分数,再归一化权重,最后加权聚合。


6. 第一步:线性变换

假设一个节点原来的特征是:

复制代码
[1, 2, 3]

GAT 不会直接使用这个原始特征,而是先用一个可学习的矩阵 WWW 对它进行变换:

WhiWh_iWhi​

这里的 hih_ihi​ 表示节点 iii 的原始特征,WWW 是模型需要学习的参数。

可以把这一步理解为:

把节点原始特征映射到一个新的特征空间,使其更适合后续计算。

比如原始特征:

复制代码
[1, 2, 3]

经过线性变换后可能变成:

复制代码
[0.5, 1.2]

这一步和神经网络中的全连接层类似,本质上是对特征进行重新组合和变换。


7. 第二步:计算邻居的重要性分数

在 GAT 中,当前节点会分别和每个邻居计算一个重要性分数。

假设节点 AAA 的邻居是:

复制代码
B、C、D

那么 GAT 会计算:

复制代码
A 和 B 的重要性分数
A 和 C 的重要性分数
A 和 D 的重要性分数

一般可以写成:

eij=a(Whi,Whj)e_{ij} = a(Wh_i, Wh_j)eij​=a(Whi​,Whj​)

其中:

复制代码
i 表示当前节点
j 表示邻居节点
e_ij 表示邻居 j 对当前节点 i 的原始重要性分数

例如对于节点 AAA 来说:

复制代码
e_AB = A 和 B 的分数
e_AC = A 和 C 的分数
e_AD = A 和 D 的分数

假设模型计算得到:

复制代码
e_AB = 3.0
e_AC = 1.0
e_AD = 0.5

这说明从原始打分来看,B 对 A 最重要,C 次之,D 最弱。


8. GAT 中的注意力分数是如何计算的

GAT 原论文中常见的注意力计算方式是:

eij=LeakyReLU(aT[Whi∥Whj])e_{ij} = \text{LeakyReLU}(a^T[Wh_i \Vert Wh_j])eij​=LeakyReLU(aT[Whi​∥Whj​])

这个公式看起来比较复杂,可以拆开理解。

首先,WhiWh_iWhi​ 表示当前节点 iii 经过线性变换后的特征。

其次,WhjWh_jWhj​ 表示邻居节点 jjj 经过线性变换后的特征。

然后,[Whi∥Whj][Wh_i \Vert Wh_j][Whi​∥Whj​] 表示把两个特征拼接起来。

例如:

复制代码
Wh_A = [0.5, 1.2]
Wh_B = [1.1, 0.3]

拼接后得到:

复制代码
[Wh_A || Wh_B] = [0.5, 1.2, 1.1, 0.3]

这样做的目的是把当前节点和邻居节点的信息放在一起,让模型判断它们之间的关系。

公式中的 aTa^TaT 可以理解为一个可学习的打分向量,它的作用是根据拼接后的特征计算一个标量分数。

LeakyReLU 是一个激活函数,它的作用是增加非线性表达能力。对于初学阶段,不需要过度纠结它的具体形式,只需要知道它不会改变 GAT 的核心逻辑。

所以这个公式整体可以理解为:

把当前节点和邻居节点的特征拼接起来,然后用一组可学习参数判断这个邻居对当前节点有多重要。


9. 第三步:softmax 将分数转成权重

前面得到的 eije_{ij}eij​ 只是原始分数,还不能直接作为权重使用。因为这些分数不一定在 0 到 1 之间,也不一定加起来等于 1。

所以 GAT 会使用 softmax 把这些分数转成注意力权重:

αij=exp⁡(eij)∑k∈N(i)exp⁡(eik)\alpha_{ij} = \frac{\exp(e_{ij})}{\sum_{k \in N(i)} \exp(e_{ik})}αij​=∑k∈N(i)​exp(eik​)exp(eij​)​

这个公式的作用是:

把一组原始分数转换成一组比例,并且这些比例加起来等于 1。

例如节点 AAA 对三个邻居的原始分数是:

复制代码
B:3.0
C:1.0
D:0.5

经过 softmax 后可能变成:

复制代码
B:0.79
C:0.14
D:0.07

这时就可以解释为:

复制代码
A 更新自己时,79% 参考 B,14% 参考 C,7% 参考 D

10. softmax 是对谁做的

这一点非常重要。

GAT 不是对整张图所有边一起做 softmax,而是:

对每个节点自己的邻居集合内部做 softmax。

例如节点 AAA 的邻居是:

复制代码
复制代码
B、C、D

那么 softmax 只在下面这些分数之间进行:

复制代码
A-B
A-C
A-D

如果节点 EEE 的邻居是:

复制代码
F、G

那么节点 EEE 会单独对:

复制代码
E-F
E-G

做 softmax。

所以,每个节点都有自己的一组注意力权重。


11. 第四步:加权聚合邻居信息

有了注意力权重以后,节点就可以根据权重聚合邻居信息。

假设节点 AAA 的注意力权重是:

复制代码
B:0.79
C:0.14
D:0.07

那么节点 AAA 的新特征可以表示为:

hAnew=0.79WhB+0.14WhC+0.07WhDh_A^{new} = 0.79Wh_B + 0.14Wh_C + 0.07Wh_DhAnew​=0.79WhB​+0.14WhC​+0.07WhD​

也就是:

复制代码
A 的新特征 = B 的信息 × 0.79 + C 的信息 × 0.14 + D 的信息 × 0.07

这就是 GAT 的聚合过程。

因此,GAT 的完整逻辑可以总结为:

复制代码
先计算邻居重要性
再把重要性转成权重
最后按照权重加权聚合邻居特征

12. GAT 是否会使用自己的特征

通常会。

在实际使用 GAT 时,一般会给图加自环,也就是让每个节点自己也连接自己。

如果节点 AAA 原来有邻居:

复制代码
B、C、D

加上自环后,节点 AAA 的聚合对象变成:

复制代码
A、B、C、D

这样节点 AAA 更新时不仅会看邻居,也会保留自己的信息。

更新形式可以理解为:

复制代码
A 的新特征 =
注意力权重 A × A 自己的信息
+ 注意力权重 B × B 的信息
+ 注意力权重 C × C 的信息
+ 注意力权重 D × D 的信息

这样做的目的是避免节点在聚合邻居时丢失自身原有特征。


13. 一个完整的小例子

假设有一个节点 0,它有两个邻居节点 1 和 2,并且加上自环。

所以节点 0 的聚合对象是:

复制代码
0、1、2

经过线性变换后,它们的特征是:

复制代码
Wh_0 = [1, 1]
Wh_1 = [2, 0]
Wh_2 = [0, 3]

GAT 会计算节点 0 对这三个节点的注意力分数:

复制代码
节点 0 看自己:e_00 = 1.2
节点 0 看节点 1:e_01 = 2.5
节点 0 看节点 2:e_02 = 0.3

然后对这些分数做 softmax,得到:

复制代码
α_00 = 0.20
α_01 = 0.65
α_02 = 0.15

这表示:

复制代码
节点 0 更新时:
20% 保留自己的信息
65% 吸收节点 1 的信息
15% 吸收节点 2 的信息

最后更新:

h0new=0.20Wh0+0.65Wh1+0.15Wh2h_0^{new} = 0.20Wh_0 + 0.65Wh_1 + 0.15Wh_2h0new​=0.20Wh0​+0.65Wh1​+0.15Wh2​

代入具体数值:

复制代码
0.20 × [1, 1] = [0.20, 0.20]
0.65 × [2, 0] = [1.30, 0.00]
0.15 × [0, 3] = [0.00, 0.45]

相加得到:

复制代码
h_0_new = [1.50, 0.65]

这就是节点 0 更新后的新特征。

从这个例子可以看到,节点 1 的权重最大,所以节点 0 的新表示受节点 1 的影响最大。


14. GAT 和 GCN 的区别

GCN 和 GAT 都是通过聚合邻居信息来更新节点表示,但它们的聚合方式不同。

GCN 的聚合权重主要由图结构决定。它通过归一化邻接矩阵对邻居特征进行加权求和,权重通常和节点的度有关。也就是说,GCN 中某个邻居对当前节点的影响大小,主要由图的连接结构决定,而不是由节点特征动态决定。

GAT 则不同。GAT 会根据当前节点和邻居节点的特征计算注意力权重。也就是说,GAT 会动态判断:

复制代码
这个邻居对当前节点是否重要?
这个邻居的信息应该占多大比例?

因此可以总结为:

复制代码
GCN:结构驱动的固定权重聚合
GAT:特征驱动的自适应权重聚合

但是需要注意,GAT 并不是完全脱离图结构。GAT 仍然只在有边连接的节点之间计算注意力。

所以更准确地说:

图结构决定节点能接收谁的信息,注意力机制决定这些信息的重要程度。


15. GAT 和 GraphSAGE 的区别

GraphSAGE 的核心特点是邻居采样和聚合。

在大规模图中,一个节点可能有成百上千个邻居。如果每次都聚合所有邻居,计算量会很大。GraphSAGE 可以从邻居中采样一部分节点,然后使用 mean、pooling 或 LSTM 等方式聚合邻居信息。

例如一个节点有 1000 个邻居,GraphSAGE 可以只采样 10 个邻居进行聚合。

如果使用 mean 聚合,那么 GraphSAGE 可以理解为:

复制代码
先采样邻居,再对采样到的邻居求平均

而 GAT 的核心不是采样,而是注意力权重。

GAT 更关心:

复制代码
这些邻居里面,谁更重要?

因此三者可以这样对比:

复制代码
GCN:
根据归一化邻接矩阵聚合邻居信息。

GraphSAGE:
采样一部分邻居,然后使用聚合函数聚合邻居信息。

GAT:
计算每个邻居的重要性权重,然后按照权重加权聚合邻居信息。

更形象地说:

复制代码
GCN:大家按固定规则发言。
GraphSAGE:我只抽几个人来听。
GAT:我听允许听的人,但会重点听更重要的人。

16. 多头注意力

GAT 中还有一个重要概念:多头注意力

所谓多头注意力,就是不只使用一套注意力机制,而是同时使用多套注意力机制。

例如节点 AAA 有三个邻居:

复制代码
B、C、D

第一个注意力头可能学到:

复制代码
B:0.7
C:0.2
D:0.1

第二个注意力头可能学到:

复制代码
B:0.2
C:0.6
D:0.2

第三个注意力头可能学到:

复制代码
B:0.3
C:0.3
D:0.4

为什么要使用多个头?

因为判断邻居是否重要可能有不同角度。一个注意力头可能更关注某种特征关系,另一个注意力头可能关注另一种特征关系。多个注意力头共同作用,可以增强模型表达能力,也可以提高训练稳定性。

可以简单理解为:

单头注意力是从一个角度判断邻居重要性,多头注意力是从多个角度判断邻居重要性。


17. 多头注意力的输出维度

在 PyTorch Geometric 中,常用 GATConv 实现 GAT。

例如:

复制代码
GATConv(in_channels, hidden_channels, heads=8)

这里 heads=8 表示有 8 个注意力头。

如果:

复制代码
hidden_channels = 8
heads = 8

并且默认 concat=True,那么输出维度不是 8,而是:

复制代码
8 × 8 = 64

原因是每个注意力头都会输出一个 8 维特征,8 个头拼接起来就是 64 维。

所以第一层 GAT 后面经常会写:

复制代码
self.conv2 = GATConv(hidden_channels * heads, out_channels)

如果设置:

复制代码
concat=False

那么多个注意力头的结果会被平均,输出维度仍然是:

复制代码
hidden_channels

因此可以总结为:

复制代码
concat=True:
多个头的结果拼接,输出维度 = hidden_channels × heads

concat=False:
多个头的结果平均,输出维度 = hidden_channels

18. PyTorch Geometric 中的 GATConv

一个简单的两层 GAT 网络可以写成:

复制代码
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv


class GAT(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()

        self.conv1 = GATConv(
            in_channels=in_channels,
            out_channels=hidden_channels,
            heads=8,
            dropout=0.6
        )

        self.conv2 = GATConv(
            in_channels=hidden_channels * 8,
            out_channels=out_channels,
            heads=1,
            concat=False,
            dropout=0.6
        )

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)

        x = self.conv2(x, edge_index)
        return x

其中:

复制代码
x = self.conv1(x, edge_index)

内部大致完成了以下操作:

复制代码
1. 对所有节点特征 x 做线性变换
2. 根据 edge_index 找到边连接的节点
3. 对每条边计算注意力分数
4. 对每个节点的邻居分数做 softmax
5. 按照注意力权重聚合邻居特征
6. 输出每个节点的新特征

所以 GATConv 已经把注意力计算、softmax、邻居聚合等过程封装好了。


19. edge_index 在 GAT 中的作用

在 PyTorch Geometric 中,GATConv 的输入通常是:

复制代码
x, edge_index

其中,x 是节点特征矩阵,每一行表示一个节点的特征。

edge_index 是边结构,通常形状为:

复制代码
[2, num_edges]

例如:

复制代码
edge_index = torch.tensor([
    [0, 1, 2],
    [1, 2, 0]
])

表示三条边:

复制代码
0 -> 1
1 -> 2
2 -> 0

在 GAT 中,edge_index 决定了每个节点可以从哪些节点接收信息。

也就是说:

edge_index 决定消息从哪里传到哪里,attention 决定这些消息的重要程度。

GAT 并不是让每个节点关注全图所有节点,而是通常只在图结构规定的邻居范围内计算注意力。


20. GAT 的核心总结

GAT 的核心流程可以总结为:

复制代码
对于每个节点 i:

1. 找到它的邻居节点
2. 对节点特征进行线性变换
3. 计算当前节点和每个邻居之间的重要性分数
4. 对这些分数做 softmax,得到注意力权重
5. 按照注意力权重加权聚合邻居特征
6. 得到节点 i 的新表示

一句话总结:

GAT 是一种在图神经网络中引入注意力机制的方法。它通过学习不同邻居的重要性权重,使节点在聚合邻居信息时能够更加灵活和有选择性。


21. GCN、GraphSAGE、GAT 对比总结

模型 核心思想 聚合方式 重点
GCN 使用归一化邻接矩阵传播信息 根据图结构固定加权聚合 图卷积基础思想
GraphSAGE 采样邻居并聚合 采样后使用 mean、pooling、LSTM 等聚合 大规模图、归纳学习
GAT 学习邻居注意力权重 根据注意力权重加权聚合 区分不同邻居的重要性

可以进一步理解为:

复制代码
GCN:
我根据邻接矩阵和节点度数,把邻居信息按固定规则加起来。

GraphSAGE:
我先采样一部分邻居,再用聚合函数把邻居信息合起来。

GAT:
我先判断每个邻居对我有多重要,再按照重要性加权聚合邻居信息。

最终记忆版:

GCN 是固定权重聚合,GraphSAGE 是采样邻居后聚合,GAT 是学习邻居权重后聚合。

相关推荐
拾薪1 小时前
CodeGraph安装使用
人工智能·ai·codegraph
Tutankaaa1 小时前
学校知识竞赛怎么组织?从班级到年级的进阶方案
经验分享·学习·算法·职场和发展
TMT星球1 小时前
汉王科技发布录写本M6,定义“国民级AI数字文具”新物种
人工智能·科技
wechat_Neal1 小时前
AI基础_LLM推理过程
人工智能
covco1 小时前
端云协同架构下:AI 原生矩阵系统端侧推理与离线生产技术实践
人工智能·矩阵·架构
qcx231 小时前
混合检索+重排序:当前 RAG 精度提升最成熟的工程路径
算法·ai·llm·agent·rag·agentic
隔窗听雨眠1 小时前
读懂AI自动化的两种范式
运维·人工智能·自动化
Komorebi_99991 小时前
Agent 第二课:ReAct 框架 思考与行动机制(
人工智能·agent
洛水水1 小时前
【力扣100题】42.杨辉三角
算法·leetcode·职场和发展