GNN入门案例——KarateClub结点分类

文章目录

一、任务描述

Karate Club 图任务是一个经典的图结构学习问题,通常用于社交网络分析和社区检测。该数据集是由 Wayne W. Zachary 在1977年收集的,描述了一个美国的空手道俱乐部成员间的社交互动。

任务描述如下:

  • 图结构:该数据集包含34个节点(即空手道俱乐部的成员)和78条边(即成员之间的友谊关系)。每条边表示两个成员之间的连接。
  • 节点特征:每个节点可以具有一些基本特征,如成员的身份信息(例如,成员的性别、年龄等)。在某些情况下,节点也可能包含其他的上下文信息。
  • 社区划分:一个重要的任务是检测社交网络中的社区结构。在Zachary的研究中,发现该俱乐部在1970年代末分裂成两个派系。这两个派系通常被称为"Club A"和"Club B"。
  • 任务目标:
    节点分类:根据现有的边连接关系和节点特征来预测某个节点属于哪个社区。

二、环境配置

完成该项目需要安装一个关键的第三方依赖torch_geometric,官方文档如下:torch_geometric,可以通过如下指令一键安装:

python 复制代码
conda install pyg -c pyg

除此之外还需要安装pytorch,matplotlib,networkx这三个库,前面的pytorch用于搭建GNN网络结构,后面两个库用来可视化数据。

三、加载数据

通过torch_geometric.datasets直接加载KarateClub数据,查看这个图的大小规模。

python 复制代码
from torch_geometric.datasets import KarateClub

dataset = KarateClub()
print(f'Dataset: {dataset}:')   
print("========================================")
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

输出

python 复制代码
Dataset: KarateClub():
========================================
Number of graphs: 1
Number of features: 34
Number of classes: 4

查看数据的具体内容:

python 复制代码
data = dataset[0]  # Get the first graph object.
print(data)
print(type(data))
# Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34]) # 
# <class 'torch_geometric.data.data.Data'>

说明这个图里面有34个节点,每个节点34个特征,156条边,34个标签,34个训练掩码,训练掩码的作用是在计算损失时只计算有标签的节点损失。edge_index就是图的邻接矩阵,但因为如果用全矩阵来表示邻接矩阵过于稀疏,这里用一个2*156的矩阵表示,代表有156条边,两个tensor对应的位置即为连接的边。比如0-1,0-2...

python 复制代码
edge_index = data.edge_index
print(edge_index)
python 复制代码
tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,
          3,  3,  3,  3,  3,  4,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  7,  7,
          7,  7,  8,  8,  8,  8,  8,  9,  9, 10, 10, 10, 11, 12, 12, 13, 13, 13,
         13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 19, 20, 20, 21,
         21, 22, 22, 23, 23, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 27, 27,
         27, 27, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31,
         31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33,
         33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33],
        [ 1,  2,  3,  4,  5,  6,  7,  8, 10, 11, 12, 13, 17, 19, 21, 31,  0,  2,
          3,  7, 13, 17, 19, 21, 30,  0,  1,  3,  7,  8,  9, 13, 27, 28, 32,  0,
          1,  2,  7, 12, 13,  0,  6, 10,  0,  6, 10, 16,  0,  4,  5, 16,  0,  1,
          2,  3,  0,  2, 30, 32, 33,  2, 33,  0,  4,  5,  0,  0,  3,  0,  1,  2,
          3, 33, 32, 33, 32, 33,  5,  6,  0,  1, 32, 33,  0,  1, 33, 32, 33,  0,
          1, 32, 33, 25, 27, 29, 32, 33, 25, 27, 31, 23, 24, 31, 29, 33,  2, 23,
         24, 33,  2, 31, 33, 23, 26, 32, 33,  1,  8, 32, 33,  0, 24, 25, 28, 32,
         33,  2,  8, 14, 15, 18, 20, 22, 23, 29, 30, 31, 33,  8,  9, 13, 14, 15,
         18, 19, 20, 22, 23, 26, 27, 28, 29, 30, 31, 32]])

用network可视化一下这个数据集:

python 复制代码
%matplotlib inline
import torch 
import networkx as nx
import matplotlib.pyplot as plt
def visualize_graph(G,color):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                     node_color=color, cmap="Set2")
    plt.show()
def visualize_embedding(h,color,epoch=None,loss=None):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:,0],h[:,1],s=140,c=color,cmap="Set2")
    if epoch is not None and loss is not None:
        plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    plt.show()
from torch_geometric.utils import to_networkx
G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)

四、定义网络结构

定义一个简单的GNN模型,包含三个GCNconv和一个分类层,激活函数选择tanh,将GCNConv结果和最后分类的结果都做返回。

python 复制代码
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        torch.manual_seed(1234)
        self.conv1 = GCNConv(dataset.num_features, 4)
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, dataset.num_classes)
    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()
        out = self.classifier(h)
        return out, h
model = GCN()
print(model)

先查看一下未训练模型对于数据处理情况:

python 复制代码
model = GCN()
_,h = model(data.x, data.edge_index)
visualize_embedding(h, color=data.y)

五、训练模型

训练过程大同小异了,定义模型,定义优化器,定义损失函数,计算损失,反向传播更新参数,训练400轮,直到训练完成。可视化GCNconv处理之后的特征。

python 复制代码
import time
model = GCN()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train(data):
    optimizer.zero_grad()
    out, h = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss, h
for epoch in range(401):
    loss, h = train(data)
    if epoch % 10 == 0:
        visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)
        time.sleep(0.3)
相关推荐
思通数科多模态大模型19 分钟前
10大核心应用场景,解锁AI检测系统的智能安全之道
人工智能·深度学习·安全·目标检测·计算机视觉·自然语言处理·数据挖掘
数据岛23 分钟前
数据集论文:面向深度学习的土地利用场景分类与变化检测
人工智能·深度学习
龙的爹23331 小时前
论文翻译 | RECITATION-AUGMENTED LANGUAGE MODELS
人工智能·语言模型·自然语言处理·prompt·gpu算力
白光白光1 小时前
凸函数与深度学习调参
人工智能·深度学习
sp_fyf_20241 小时前
【大语言模型】ACL2024论文-18 MINPROMPT:基于图的最小提示数据增强用于少样本问答
人工智能·深度学习·神经网络·目标检测·机器学习·语言模型·自然语言处理
weixin_543662861 小时前
BERT的中文问答系统33
人工智能·深度学习·bert
爱喝白开水a1 小时前
Sentence-BERT实现文本匹配【分类目标函数】
人工智能·深度学习·机器学习·自然语言处理·分类·bert·大模型微调
Jack黄从零学c++1 小时前
opencv(c++)---自带的卷积运算filter2D以及应用
c++·人工智能·opencv
封步宇AIGC2 小时前
量化交易系统开发-实时行情自动化交易-4.2.3.指数移动平均线实现
人工智能·python·机器学习·数据挖掘
Mr.谢尔比2 小时前
李宏毅机器学习课程知识点摘要(1-5集)
人工智能·pytorch·深度学习·神经网络·算法·机器学习·计算机视觉