【PYG】Planetoid中边存储的格式,为什么打印前十条边用edge_index[:, :10]

edge_index 是 PyTorch Geometric 中常用的表示图边的张量。它通常是一个形状为 [2, num_edges] 的二维张量,其中 num_edges 表示图中边的数量。每一列表示一条边,包含两个节点的索引。

  • 实际上这是COO存储格式,官方文档里也有写,还有一种是邻接矩阵的存储格式,两种方式是可以互相转换的 https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html

edge_index[:, :10] 表示取出 edge_index 张量的前 10 列,即前 10 条边的节点索引。

在 Python 中,使用切片语法 [:,] 是一种方便的方式来选择多维数组或张量的特定部分(另外一部分Python语法知识)

程序输出结果

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
First 10 edges (edge_index[:, :10]):
tensor([[   0,    0,    0,    1,    1,    1,    2,    2,    2,    2],
        [ 633, 1862, 2582,    2,  652,  654,    1,  332, 1454, 1666]])
edge_index 0: tensor([  0, 633])
edge_index 1: tensor([   0, 1862])
edge_index 2: tensor([   0, 2582])
edge_index 3: tensor([1, 2])
edge_index 4: tensor([  1, 652])
edge_index 5: tensor([  1, 654])
edge_index 6: tensor([2, 1])
edge_index 7: tensor([  2, 332])
edge_index 8: tensor([   2, 1454])
edge_index 9: tensor([   2, 1666])

示例代码

假设你已经使用 PyTorch Geometric 加载了 Cora 数据集,并且 edge_index 已经被定义,以下代码展示如何查看前 10 条边的信息:

python 复制代码
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# 加载 Cora 数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())

# 获取数据集中的第一个图
data = dataset[0]

# 打印数据集的基本信息
print(data)

# 获取边索引
edge_index = data.edge_index

# 打印前 10 条边的节点索引
print("First 10 edges (edge_index[:, :10]):")
print(edge_index[:, :10])

for i in range(10):
    print(f"edge_index {i}: {edge_index[:,i]}")

示例输出

假设 edge_index 的前 10 列如下所示:

plaintext 复制代码
tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
        [  1,   2,   0,   4,   5,   3,   7,   6,   5,   8]])

这表示:

  1. 第一条边是从节点 0 到节点 1。
  2. 第二条边是从节点 1 到节点 2。
  3. 第三条边是从节点 2 到节点 0。
  4. 第四条边是从节点 3 到节点 4。
  5. 第五条边是从节点 4 到节点 5。
  6. 第六条边是从节点 5 到节点 3。
  7. 第七条边是从节点 6 到节点 7。
  8. 第八条边是从节点 7 到节点 6。
  9. 第九条边是从节点 8 到节点 5。
  10. 第十条边是从节点 9 到节点 8。

解释

  • edge_index 的形状为 [2, num_edges],其中 num_edges 表示边的数量。
  • edge_index[:, :10] 表示取出前 10 条边的节点索引。
  • 输出的张量第一行表示每条边的起始节点,第二行表示每条边的结束节点。

通过这种方式,你可以方便地查看和理解数据集中边的表示方式。

相关推荐
HERODING7724 天前
【论文精读】RELIEF: Reinforcement Learning Empowered Graph Feature Prompt Tuning
prompt·图论·gnn·图prompt
Cyril_KI1 个月前
PyTorch搭建GNN(GCN、GraphSAGE和GAT)实现多节点、单节点内多变量输入多变量输出时空预测
pytorch·时间序列预测·gnn·时空预测
不会&编程2 个月前
论文阅读:A Generalization of Transformer Networks to Graphs
论文阅读·深度学习·transformer·gnn
shuaixio2 个月前
【VectorNet】vectornet网络学习笔记
gnn·自注意力机制·mlp·vectornet·子图构建·全局图构建
只是有点小怂4 个月前
【TORCH】torch.normal()中的size参数
gnn
只是有点小怂4 个月前
【PYG】Cora数据集简介
gnn
只是有点小怂4 个月前
【PYG】简单分析Planetoid()中存储Cora数据集边的数量
gnn
盼小辉丶5 个月前
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
深度学习·图神经网络·gnn
人工智能培训咨询叶梓6 个月前
秒懂图神经网络(GNN)
人工智能·深度学习·神经网络·机器学习·语言模型·gnn·人工智能培训