GNN入门案例——KarateClub结点分类

文章目录

一、任务描述

Karate Club 图任务是一个经典的图结构学习问题,通常用于社交网络分析和社区检测。该数据集是由 Wayne W. Zachary 在1977年收集的,描述了一个美国的空手道俱乐部成员间的社交互动。

任务描述如下:

  • 图结构:该数据集包含34个节点(即空手道俱乐部的成员)和78条边(即成员之间的友谊关系)。每条边表示两个成员之间的连接。
  • 节点特征:每个节点可以具有一些基本特征,如成员的身份信息(例如,成员的性别、年龄等)。在某些情况下,节点也可能包含其他的上下文信息。
  • 社区划分:一个重要的任务是检测社交网络中的社区结构。在Zachary的研究中,发现该俱乐部在1970年代末分裂成两个派系。这两个派系通常被称为"Club A"和"Club B"。
  • 任务目标:
    节点分类:根据现有的边连接关系和节点特征来预测某个节点属于哪个社区。

二、环境配置

完成该项目需要安装一个关键的第三方依赖torch_geometric,官方文档如下:torch_geometric,可以通过如下指令一键安装:

python 复制代码
conda install pyg -c pyg

除此之外还需要安装pytorch,matplotlib,networkx这三个库,前面的pytorch用于搭建GNN网络结构,后面两个库用来可视化数据。

三、加载数据

通过torch_geometric.datasets直接加载KarateClub数据,查看这个图的大小规模。

python 复制代码
from torch_geometric.datasets import KarateClub

dataset = KarateClub()
print(f'Dataset: {dataset}:')   
print("========================================")
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

输出

python 复制代码
Dataset: KarateClub():
========================================
Number of graphs: 1
Number of features: 34
Number of classes: 4

查看数据的具体内容:

python 复制代码
data = dataset[0]  # Get the first graph object.
print(data)
print(type(data))
# Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34]) # 
# <class 'torch_geometric.data.data.Data'>

说明这个图里面有34个节点,每个节点34个特征,156条边,34个标签,34个训练掩码,训练掩码的作用是在计算损失时只计算有标签的节点损失。edge_index就是图的邻接矩阵,但因为如果用全矩阵来表示邻接矩阵过于稀疏,这里用一个2*156的矩阵表示,代表有156条边,两个tensor对应的位置即为连接的边。比如0-1,0-2...

python 复制代码
edge_index = data.edge_index
print(edge_index)
python 复制代码
tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,
          3,  3,  3,  3,  3,  4,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  7,  7,
          7,  7,  8,  8,  8,  8,  8,  9,  9, 10, 10, 10, 11, 12, 12, 13, 13, 13,
         13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 19, 20, 20, 21,
         21, 22, 22, 23, 23, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 27, 27,
         27, 27, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31,
         31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33,
         33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33],
        [ 1,  2,  3,  4,  5,  6,  7,  8, 10, 11, 12, 13, 17, 19, 21, 31,  0,  2,
          3,  7, 13, 17, 19, 21, 30,  0,  1,  3,  7,  8,  9, 13, 27, 28, 32,  0,
          1,  2,  7, 12, 13,  0,  6, 10,  0,  6, 10, 16,  0,  4,  5, 16,  0,  1,
          2,  3,  0,  2, 30, 32, 33,  2, 33,  0,  4,  5,  0,  0,  3,  0,  1,  2,
          3, 33, 32, 33, 32, 33,  5,  6,  0,  1, 32, 33,  0,  1, 33, 32, 33,  0,
          1, 32, 33, 25, 27, 29, 32, 33, 25, 27, 31, 23, 24, 31, 29, 33,  2, 23,
         24, 33,  2, 31, 33, 23, 26, 32, 33,  1,  8, 32, 33,  0, 24, 25, 28, 32,
         33,  2,  8, 14, 15, 18, 20, 22, 23, 29, 30, 31, 33,  8,  9, 13, 14, 15,
         18, 19, 20, 22, 23, 26, 27, 28, 29, 30, 31, 32]])

用network可视化一下这个数据集:

python 复制代码
%matplotlib inline
import torch 
import networkx as nx
import matplotlib.pyplot as plt
def visualize_graph(G,color):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                     node_color=color, cmap="Set2")
    plt.show()
def visualize_embedding(h,color,epoch=None,loss=None):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:,0],h[:,1],s=140,c=color,cmap="Set2")
    if epoch is not None and loss is not None:
        plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    plt.show()
from torch_geometric.utils import to_networkx
G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)

四、定义网络结构

定义一个简单的GNN模型,包含三个GCNconv和一个分类层,激活函数选择tanh,将GCNConv结果和最后分类的结果都做返回。

python 复制代码
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        torch.manual_seed(1234)
        self.conv1 = GCNConv(dataset.num_features, 4)
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, dataset.num_classes)
    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()
        out = self.classifier(h)
        return out, h
model = GCN()
print(model)

先查看一下未训练模型对于数据处理情况:

python 复制代码
model = GCN()
_,h = model(data.x, data.edge_index)
visualize_embedding(h, color=data.y)

五、训练模型

训练过程大同小异了,定义模型,定义优化器,定义损失函数,计算损失,反向传播更新参数,训练400轮,直到训练完成。可视化GCNconv处理之后的特征。

python 复制代码
import time
model = GCN()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train(data):
    optimizer.zero_grad()
    out, h = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss, h
for epoch in range(401):
    loss, h = train(data)
    if epoch % 10 == 0:
        visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)
        time.sleep(0.3)
相关推荐
埃菲尔铁塔_CV算法2 小时前
深度学习神经网络创新点方向
人工智能·深度学习·神经网络
艾思科蓝-何老师【H8053】2 小时前
【ACM出版】第四届信号处理与通信技术国际学术会议(SPCT 2024)
人工智能·信号处理·论文发表·香港中文大学
weixin_452600693 小时前
《青牛科技 GC6125:驱动芯片中的璀璨之星,点亮 IPcamera 和云台控制(替代 BU24025/ROHM)》
人工智能·科技·单片机·嵌入式硬件·新能源充电桩·智能充电枪
学术搬运工3 小时前
【珠海科技学院主办,暨南大学协办 | IEEE出版 | EI检索稳定 】2024年健康大数据与智能医疗国际会议(ICHIH 2024)
大数据·图像处理·人工智能·科技·机器学习·自然语言处理
右恩3 小时前
AI大模型重塑软件开发:流程革新与未来展望
人工智能
图片转成excel表格3 小时前
WPS Office Excel 转 PDF 后图片丢失的解决方法
人工智能·科技·深度学习
ApiHug4 小时前
ApiSmart x Qwen2.5-Coder 开源旗舰编程模型媲美 GPT-4o, ApiSmart 实测!
人工智能·spring boot·spring·ai编程·apihug
哇咔咔哇咔4 小时前
【科普】简述CNN的各种模型
人工智能·神经网络·cnn
李歘歘4 小时前
万字长文解读深度学习——多模态模型CLIP、BLIP、ViLT
人工智能·深度学习