【PYG】Planetoid中边存储的格式,为什么打印前十条边用edge_index[:, :10]

edge_index 是 PyTorch Geometric 中常用的表示图边的张量。它通常是一个形状为 [2, num_edges] 的二维张量,其中 num_edges 表示图中边的数量。每一列表示一条边,包含两个节点的索引。

  • 实际上这是COO存储格式,官方文档里也有写,还有一种是邻接矩阵的存储格式,两种方式是可以互相转换的 https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html

edge_index[:, :10] 表示取出 edge_index 张量的前 10 列,即前 10 条边的节点索引。

在 Python 中,使用切片语法 [:,] 是一种方便的方式来选择多维数组或张量的特定部分(另外一部分Python语法知识)

程序输出结果

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
First 10 edges (edge_index[:, :10]):
tensor([[   0,    0,    0,    1,    1,    1,    2,    2,    2,    2],
        [ 633, 1862, 2582,    2,  652,  654,    1,  332, 1454, 1666]])
edge_index 0: tensor([  0, 633])
edge_index 1: tensor([   0, 1862])
edge_index 2: tensor([   0, 2582])
edge_index 3: tensor([1, 2])
edge_index 4: tensor([  1, 652])
edge_index 5: tensor([  1, 654])
edge_index 6: tensor([2, 1])
edge_index 7: tensor([  2, 332])
edge_index 8: tensor([   2, 1454])
edge_index 9: tensor([   2, 1666])

示例代码

假设你已经使用 PyTorch Geometric 加载了 Cora 数据集,并且 edge_index 已经被定义,以下代码展示如何查看前 10 条边的信息:

python 复制代码
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# 加载 Cora 数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())

# 获取数据集中的第一个图
data = dataset[0]

# 打印数据集的基本信息
print(data)

# 获取边索引
edge_index = data.edge_index

# 打印前 10 条边的节点索引
print("First 10 edges (edge_index[:, :10]):")
print(edge_index[:, :10])

for i in range(10):
    print(f"edge_index {i}: {edge_index[:,i]}")

示例输出

假设 edge_index 的前 10 列如下所示:

plaintext 复制代码
tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
        [  1,   2,   0,   4,   5,   3,   7,   6,   5,   8]])

这表示:

  1. 第一条边是从节点 0 到节点 1。
  2. 第二条边是从节点 1 到节点 2。
  3. 第三条边是从节点 2 到节点 0。
  4. 第四条边是从节点 3 到节点 4。
  5. 第五条边是从节点 4 到节点 5。
  6. 第六条边是从节点 5 到节点 3。
  7. 第七条边是从节点 6 到节点 7。
  8. 第八条边是从节点 7 到节点 6。
  9. 第九条边是从节点 8 到节点 5。
  10. 第十条边是从节点 9 到节点 8。

解释

  • edge_index 的形状为 [2, num_edges],其中 num_edges 表示边的数量。
  • edge_index[:, :10] 表示取出前 10 条边的节点索引。
  • 输出的张量第一行表示每条边的起始节点,第二行表示每条边的结束节点。

通过这种方式,你可以方便地查看和理解数据集中边的表示方式。

相关推荐
只是有点小怂2 天前
【TORCH】torch.normal()中的size参数
gnn
只是有点小怂4 天前
【PYG】Cora数据集简介
gnn
只是有点小怂6 天前
【PYG】简单分析Planetoid()中存储Cora数据集边的数量
gnn
盼小辉丶1 个月前
图神经网络实战(12)——图同构网络(Graph Isomorphism Network, GIN)
深度学习·图神经网络·gnn
人工智能培训咨询叶梓2 个月前
秒懂图神经网络(GNN)
人工智能·深度学习·神经网络·机器学习·语言模型·gnn·人工智能培训
Karen_Yu_2 个月前
【intro】Graph Isomorphism Network(GIN)
人工智能·神经网络·gin·gnn
盼小辉丶2 个月前
图神经网络实战(9)——GraphSAGE详解与实现
pytorch·图神经网络·gnn
盼小辉丶3 个月前
图神经网络实战(6)——使用PyTorch构建图神经网络
pytorch·深度学习·神经网络·图神经网络·gnn
鲸可落5 个月前
【论文阅读】LLM4GCL: CAN LARGE LANGUAGE MODEL EM-POWER GRAPH CONTRASTIVE LEARNING?
论文阅读·人工智能·语言模型·chatgpt·llm·gnn·图对比学习