【操作记录】pytorch_geometric安装方法

pytorch_geometric安装方法

github地址

主要不要直接pip install安装,会由于依赖无法安装而失败

点击here手动安装依赖

选择对应的pytorch版本,我的是Win10 Python3.8.3+Pytorch1.8.1+CUDA10.2

手动下载四个依赖包本地安装:

主要不要直接:pip install torch_geometric

这样会安装最新的torch_geometric,后面在使用时候会出现Pytorch AttributeError: module 'torch' has no attribute 'sparse_scs'

这里手动指定低版本安装:

pip install torch_geometric==2.0.4

接下来就可以跑一个图卷积神经网络试试:

Python 复制代码
import torch
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.datasets import KarateClub
from torch_geometric.utils import to_networkx


dataset = KarateClub()

def visualize_graph(G, color):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                     node_color=color, cmap="Set2")
    plt.show()


def visualize_embedding(h, color, epoch=None, loss=None):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
    if epoch is not None and loss is not None:
        plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    plt.show()
    
G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)    
python 复制代码
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        torch.manual_seed(1234)
        self.conv1 = GCNConv(dataset.num_features, 4) # 只需定义好输入特征和输出特征即可
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, dataset.num_classes)

    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index) # 输入特征与邻接矩阵(注意格式,上面那种)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()  
        
        # 分类层
        out = self.classifier(h)

        return out, h

model = GCN()
_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')

visualize_embedding(h, color=data.y)
复制代码
Embedding shape: [34, 2]

训练模型

Python 复制代码
import time

model = GCN()
criterion = torch.nn.CrossEntropyLoss()  # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Define optimizer.

def train(data):
    optimizer.zero_grad()  
    out, h = model(data.x, data.edge_index) #h是两维向量,主要是为了咱们画个图 
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # semi-supervised
    loss.backward()  
    optimizer.step()  
    return loss, h

for epoch in range(401):
    loss, h = train(data)
    if epoch % 10 == 0:
        visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)
        time.sleep(0.3)
相关推荐
weixin_4597539418 分钟前
golang如何实现Trace上下文传播_golang Trace上下文传播实现思路
jvm·数据库·python
weixin_4440129327 分钟前
PHP 中逻辑或(--)运算符的正确使用与条件逻辑重构指南
jvm·数据库·python
iAm_Ike6 小时前
Go 中自定义类型与基础类型间的显式类型转换详解
jvm·数据库·python
iuvtsrt6 小时前
Golang怎么实现方法集与接口的匹配_Golang如何理解值类型和指针类型实现接口的区别【详解】
jvm·数据库·python
牧子川7 小时前
009-Transformer-Architecture
人工智能·深度学习·transformer
covco7 小时前
矩阵管理系统指南:拆解星链引擎的架构设计与全链路落地实践
大数据·人工智能·矩阵
沪漂阿龙7 小时前
AI大模型面试题:支持向量机是什么?间隔最大化、软间隔、核函数、LinearSVC 全面拆解
人工智能·算法·支持向量机
lifewange7 小时前
AI编写测试用例工具介绍
人工智能·测试用例
陕西字符7 小时前
2026 西安 豆包获客优化技术深度解析:企来客科技 AI 全域获客系统测评
大数据·人工智能
掘金安东尼7 小时前
GGUF、GPTQ、AWQ、EXL2、MLX、VMLX...运行大模型,为什么会有这么多格式?
人工智能