【PYG】Planetoid中边存储的格式,为什么打印前十条边用edge_index[:, :10]

edge_index 是 PyTorch Geometric 中常用的表示图边的张量。它通常是一个形状为 [2, num_edges] 的二维张量,其中 num_edges 表示图中边的数量。每一列表示一条边,包含两个节点的索引。

  • 实际上这是COO存储格式,官方文档里也有写,还有一种是邻接矩阵的存储格式,两种方式是可以互相转换的 https://pytorch-geometric.readthedocs.io/en/latest/get_started/introduction.html

edge_index[:, :10] 表示取出 edge_index 张量的前 10 列,即前 10 条边的节点索引。

在 Python 中,使用切片语法 [:,] 是一种方便的方式来选择多维数组或张量的特定部分(另外一部分Python语法知识)

程序输出结果

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])
First 10 edges (edge_index[:, :10]):
tensor([[   0,    0,    0,    1,    1,    1,    2,    2,    2,    2],
        [ 633, 1862, 2582,    2,  652,  654,    1,  332, 1454, 1666]])
edge_index 0: tensor([  0, 633])
edge_index 1: tensor([   0, 1862])
edge_index 2: tensor([   0, 2582])
edge_index 3: tensor([1, 2])
edge_index 4: tensor([  1, 652])
edge_index 5: tensor([  1, 654])
edge_index 6: tensor([2, 1])
edge_index 7: tensor([  2, 332])
edge_index 8: tensor([   2, 1454])
edge_index 9: tensor([   2, 1666])

示例代码

假设你已经使用 PyTorch Geometric 加载了 Cora 数据集,并且 edge_index 已经被定义,以下代码展示如何查看前 10 条边的信息:

python 复制代码
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures

# 加载 Cora 数据集
dataset = Planetoid(root='/tmp/Cora', name='Cora', transform=NormalizeFeatures())

# 获取数据集中的第一个图
data = dataset[0]

# 打印数据集的基本信息
print(data)

# 获取边索引
edge_index = data.edge_index

# 打印前 10 条边的节点索引
print("First 10 edges (edge_index[:, :10]):")
print(edge_index[:, :10])

for i in range(10):
    print(f"edge_index {i}: {edge_index[:,i]}")

示例输出

假设 edge_index 的前 10 列如下所示:

plaintext 复制代码
tensor([[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9],
        [  1,   2,   0,   4,   5,   3,   7,   6,   5,   8]])

这表示:

  1. 第一条边是从节点 0 到节点 1。
  2. 第二条边是从节点 1 到节点 2。
  3. 第三条边是从节点 2 到节点 0。
  4. 第四条边是从节点 3 到节点 4。
  5. 第五条边是从节点 4 到节点 5。
  6. 第六条边是从节点 5 到节点 3。
  7. 第七条边是从节点 6 到节点 7。
  8. 第八条边是从节点 7 到节点 6。
  9. 第九条边是从节点 8 到节点 5。
  10. 第十条边是从节点 9 到节点 8。

解释

  • edge_index 的形状为 [2, num_edges],其中 num_edges 表示边的数量。
  • edge_index[:, :10] 表示取出前 10 条边的节点索引。
  • 输出的张量第一行表示每条边的起始节点,第二行表示每条边的结束节点。

通过这种方式,你可以方便地查看和理解数据集中边的表示方式。

相关推荐
Che_Che_10 小时前
Cross-Inlining Binary Function Similarity Detection
人工智能·网络安全·gnn·二进制相似度检测
litble12 天前
图神经网络(GNN)入门笔记(1)——图信号处理与图傅里叶变换
笔记·神经网络·信号处理·图神经网络·gnn·gcn·傅里叶变换
医学小达人13 天前
Python 分子图分类,GNN Model for HIV Molecules Classification,HIV 分子图分类模型;整图分类问题,代码实战
nlp·图神经网络·gnn·图计算·分子图分类·整图分类模型·hiv分子图分类
HERODING771 个月前
【论文精读】RELIEF: Reinforcement Learning Empowered Graph Feature Prompt Tuning
prompt·图论·gnn·图prompt
Cyril_KI1 个月前
PyTorch搭建GNN(GCN、GraphSAGE和GAT)实现多节点、单节点内多变量输入多变量输出时空预测
pytorch·时间序列预测·gnn·时空预测
不会&编程2 个月前
论文阅读:A Generalization of Transformer Networks to Graphs
论文阅读·深度学习·transformer·gnn
shuaixio3 个月前
【VectorNet】vectornet网络学习笔记
gnn·自注意力机制·mlp·vectornet·子图构建·全局图构建
只是有点小怂5 个月前
【TORCH】torch.normal()中的size参数
gnn
只是有点小怂5 个月前
【PYG】Cora数据集简介
gnn
只是有点小怂5 个月前
【PYG】简单分析Planetoid()中存储Cora数据集边的数量
gnn