什么是图神经网络?

一、概念

图神经网络(Graph Neural Network, GNN)是一类专门用于处理图结构数据 的神经网络。图结构数据广泛存在于各种实际应用中,如社交网络、分子结构、知识图谱等。GNN通过在图的节点和边上进行信息传递和聚合,能够有效地捕捉图结构中的复杂关系和特征

GNN的输入通常是一个图 G=(V,E),其中 V 是节点集合,E 是边集合。每个节点 v∈V 可能有一个特征向量 ​,每条边 (u,v)∈E 可能有一个特征向量​。

二、核心算法

GNN的核心思想是通过迭代地更新节点的表示来捕捉图结构中的信息。每一轮迭代(也称为层)包括以下两个步骤:

  • 消息传递(Message Passing):每个节点从其邻居节点接收信息。
  • 节点更新(Node Update):每个节点根据接收到的信息和自身的特征更新其表示。

假设我们有一个图 G=(V,E),每个节点 v∈V 的特征向量为 ,每条边 (u,v)∈E 的特征向量为 ​。GNN的计算公式可以表示为:

1、消息传递

其中,N(v)表示节点 v 的邻居节点集合,M是消息传递函数,是节点 v 在第 k 层接收到的消息。

2、节点更新

其中,U是节点更新函数,是节点 v 在第 k 层的表示。

三、python实现

这里,我们构建一个create_graph函数来生成一个空手道俱乐部的图(Karate Club Graph),并为每个节点生成一个特征向量(单位矩阵)和标签(根据俱乐部分组)。通过加载 Karate Club 图数据集,我们可以获得一个社交网络图,其中包含 34 个节点和 78 条边。我们为每个节点生成标签(0 或 1),表示节点属于哪个社区(Mr. HiOfficer)。进而基于这份数据进行GNN分类。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt

# 生成一个小的图数据集
def create_graph():
    # 加载 Karate Club 图数据集,这是一个社交网络图,包含 34 个节点和 78 条边。
    G = nx.karate_club_graph()
    features = np.eye(G.number_of_nodes())
    # 为每个节点生成标签(0 或 1),表示节点属于哪个社区(Mr. Hi 或 Officer)。
    labels = np.array([G.nodes[i]['club'] == 'Mr. Hi' for i in range(G.number_of_nodes())], dtype=int)
    return G, features, labels

# 定义原始GNN模型
class GNN(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNN, self).__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x, adj):
        h = F.relu(self.fc1(x))
        # 使用邻接矩阵 adj 聚合邻居节点的信息。
        h = torch.matmul(adj, h)
        h = self.fc2(h)
        return F.log_softmax(h, dim=1)

# 训练和测试函数
def train(model, optimizer, features, labels, adj, train_mask, epochs=10):
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        output = model(features, adj)
        # 计算负对数似然损失
        loss = F.nll_loss(output[train_mask], labels[train_mask])
        loss.backward()
        optimizer.step()
        print(f'Epoch: {epoch + 1}, Loss: {loss.item():.4f}')

def test(model, features, labels, adj, mask):
    model.eval()
    with torch.no_grad():
        output = model(features, adj)
        pred = output[mask].max(1)[1]
        acc = pred.eq(labels[mask]).sum().item() / mask.sum().item()
    return acc

# 主函数
# 创建图数据集
G, features, labels = create_graph()
adj = nx.adjacency_matrix(G).todense()
adj = torch.FloatTensor(adj)
features = torch.FloatTensor(features)
labels = torch.LongTensor(labels)

# 训练和测试掩码,前 30 个节点用于训练
train_mask = torch.BoolTensor([True if i < 30 else False for i in range(len(labels))])
test_mask = ~train_mask

# 初始化模型和优化器
model = GNN(input_dim=features.shape[1], hidden_dim=16, output_dim=2)
optimizer = optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

# 训练模型
train(model, optimizer, features, labels, adj, train_mask)

# 测试模型
train_acc = test(model, features, labels, adj, train_mask)
test_acc = test(model, features, labels, adj, test_mask)
print(f'Train Accuracy: {train_acc:.4f}, Test Accuracy: {test_acc:.4f}')

# 可视化结果
def plot_graph(G, labels, pred=None):
    pos = nx.spring_layout(G)
    plt.figure(figsize=(8, 8))
    nx.draw(G, pos, with_labels=True, node_color=labels, cmap=plt.cm.rainbow, node_size=500, font_color='white')
    if pred is not None:
        nx.draw_networkx_nodes(G, pos, node_color=pred, cmap=plt.cm.rainbow, node_size=200, alpha=0.5)
    plt.show()

plot_graph(G, labels.numpy(), pred=model(features, adj).max(1)[1].numpy())

四、总结

GNN能够直接处理图结构数据。通过端到端的方式进行训练,GNN能够直接从原始图数据中学习特征和表示,这使得它在处理社交网络、分子结构、知识图谱等任务中具有天然的优势。然而,GNN的计算复杂度较高,尤其是在处理大规模图数据时。每一轮迭代都需要进行消息传递和节点更新,这使得GNN的计算量较大,训练和推理速度较慢。在深层GNN中,节点的表示可能会变得过于相似,导致过平滑问题。此外,如果图数据存在噪声或不完整,GNN的性能也会受到影响。

相关推荐
100个铜锣烧5 小时前
高级提示技术:Chain-of-Thought与ReAct——让大模型学会“思考”和“行动”
人工智能·大模型·提示词工程
JackHCC5 小时前
快手OneRetrieval:可编辑生成式电商召回
人工智能·机器学习
hhzz5 小时前
基于监控视频的水位尺自动识别技术方案与实现
python·opencv·yolo·图像识别·cv
yongche_shi5 小时前
ragas官方文档中文版(五十)
开发语言·python·ai·ragas·如何评估和改进 rag 应用
前端之虎陈随易5 小时前
编程语言级别的Skill市场,AI Agent 的未来形态
前端·vue.js·人工智能·typescript·node.js
QiLinkOS5 小时前
第三视觉理解徐玉生与他的商业活动(30)
大数据·c++·人工智能·算法·开源协议
武汉唯众智创6 小时前
当汉字成为心理CT:AI汉字联想投射分析的技术实现与心理评估价值
人工智能·ai心理健康·ai心理评估·本土化心理测评·校园心理健康解决方案·ai心理监测·多模态情绪模型
疯狂打码的少年6 小时前
【操作系统】页面置换算法(OPT/FIFO/LRU)
算法
Longvox6 小时前
Agent为什么会死循环?
人工智能·ai编程
小O的算法实验室6 小时前
2026年CIE,优化客货协同运输:综合地铁系统的列车容量动态分配
算法