DenseDataLoader
是专门用于处理稠密图数据的,而 DataLoader
通常用于处理稀疏图数据。两者的主要区别在于它们的输入数据格式和处理方式。DenseDataLoader
适合处理固定大小的邻接矩阵和节点特征矩阵的数据,而 DataLoader
更加灵活,可以处理稀疏表示的图数据。
主要区别
-
DataLoader
:- 适合处理稀疏图数据。
- 通常与
torch_geometric.data.Data
一起使用,其中边索引是稀疏表示的。 - 更加灵活,适合处理各种不同形状和大小的图。
-
DenseDataLoader
:- 适合处理稠密图数据。
- 通常与固定大小的邻接矩阵和节点特征矩阵一起使用。
- 更高效地处理固定大小的图数据。
使用示例
使用 DenseDataLoader
如果你有固定大小的邻接矩阵和节点特征矩阵,可以直接使用 DenseDataLoader
加载数据:
1. 导入必要的库
python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoader
2. 定义数据集类
python
class MyDenseDataset(torch.utils.data.Dataset):
def __init__(self, num_samples, num_nodes, num_node_features):
self.num_samples = num_samples
self.num_nodes = num_nodes
self.num_node_features = num_node_features
self.adj_matrix = self.create_adj_matrix(num_nodes)
def create_adj_matrix(self, num_nodes):
# 创建环形图的邻接矩阵
adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
for i in range(num_nodes):
adj_matrix[i, (i + 1) % num_nodes] = 1
adj_matrix[(i + 1) % num_nodes, i] = 1
return adj_matrix
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# 创建随机特征和标签
x = torch.randn((self.num_nodes, self.num_node_features))
y = torch.randn((self.num_nodes, 1)) # 每个节点一个标签
return Data(x=x, adj=self.adj_matrix, y=y)
3. 创建数据集和封装数据
python
# 参数设置
num_samples = 100 # 样本数
num_nodes = 10 # 每个图中的节点数
num_node_features = 8 # 每个节点的特征数
# 创建数据集
dataset = MyDenseDataset(num_samples, num_nodes, num_node_features)
4. 使用 DenseDataLoader
python
# 使用 DenseDataLoader 加载数据
loader = DenseDataLoader(dataset, batch_size=32, shuffle=True)
# 从 DenseDataLoader 中获取一个批次的数据并查看其形状
for data in loader:
print("Batch node features shape:", data.x.shape) # 期望输出形状为 (32, 10, 8)
print("Batch adjacency matrix shape:", data.adj.shape) # 期望输出形状为 (32, 10, 10)
print("Batch labels shape:", data.y.shape) # 期望输出形状为 (32, 10, 1)
break # 仅查看第一个批次的形状
解释
-
导入库:
- 导入
torch
、torch_geometric.data
中的Data
和torch_geometric.loader
中的DenseDataLoader
。
- 导入
-
定义
MyDenseDataset
类:__init__
方法初始化数据集参数,并创建邻接矩阵。create_adj_matrix
方法创建环形图的邻接矩阵。__len__
方法返回数据集的样本数量。__getitem__
方法生成每个样本的随机节点特征和标签,并返回节点特征矩阵、邻接矩阵和标签。
-
创建数据集:
- 使用
MyDenseDataset
类创建一个包含 100 个样本的数据集,每个样本包含 10 个节点,每个节点有 8 个特征。
- 使用
-
使用
DenseDataLoader
:- 使用
DenseDataLoader
加载dataset
,设置批次大小为 32,并进行随机打乱。 - 在获取一个批次的数据时,检查
x
、adj
和y
的形状,以确保其符合期望的三维形状。
- 使用
通过这个完整的示例代码,你可以生成、封装和加载稠密图数据,并确保每个批次的数据形状保持正确。这种方法适合处理节点数和边数固定的图数据,提高数据加载和处理的效率。
定义数据集类并使用 DenseDataLoader
python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DenseDataLoader # 更新导入路径
class MyDenseDataset(torch.utils.data.Dataset):
def __init__(self, num_samples, num_nodes, num_node_features):
self.num_samples = num_samples
self.num_nodes = num_nodes
self.num_node_features = num_node_features
self.adj_matrix = self.create_adj_matrix(num_nodes)
def create_adj_matrix(self, num_nodes):
# 创建环形图的邻接矩阵
adj_matrix = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
for i in range(num_nodes):
adj_matrix[i, (i + 1) % num_nodes] = 1
adj_matrix[(i + 1) % num_nodes, i] = 1
print(adj_matrix)
return adj_matrix
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
# 创建随机特征和标签
x = torch.randn((self.num_nodes, self.num_node_features))
y = torch.randn((self.num_nodes, 1)) # 每个节点一个标签
return Data(x, self.adj_matrix, y=y)
# 创建数据集
num_samples = 100 # 样本数
num_nodes = 10 # 每个图中的节点数
num_node_features = 8 # 每个节点的特征数
dataset = MyDenseDataset(num_samples, num_nodes, num_node_features)
# 使用 DenseDataLoader 加载数据
loader = DenseDataLoader(dataset, batch_size=32, shuffle=True)
# 从 DenseDataLoader 中获取一个批次的数据并查看其形状
for data in loader:
print("Batch node features shape:", data.x.shape) # 期望输出形状为 (32, 10, 8)
# print("Batch adjacency matrix shape:", data.adj.shape) # 期望输出形状为 (32, 10, 10)
print("Batch labels shape:", data.y.shape) # 期望输出形状为 (32, 10, 1)
break # 仅查看第一个批次的形状
使用 DataLoader
如果你使用的是 DataLoader
,则数据应当是 torch_geometric.data.Data
对象,并将数据封装在列表中:
python
import torch
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader # 更新导入路径
class MyDataset(torch.utils.data.Dataset):
def __init__(self, num_samples, num_nodes, num_node_features):
self.num_samples = num_samples
self.num_nodes = num_nodes
self.num_node_features = num_node_features
def __len__(self):
return self.num_samples
def __getitem__(self, idx):
x = torch.randn(self.num_nodes, self.num_node_features)
edge_index = torch.tensor([[i, (i + 1) % self.num_nodes] for i in range(self.num_nodes)], dtype=torch.long).t().contiguous()
y = torch.randn(self.num_nodes, 1)
return Data(x=x, edge_index=edge_index, y=y)
# 创建数据集
num_samples = 100 # 样本数
num_nodes = 10 # 每个图中的节点数
num_node_features = 8 # 每个节点的特征数
dataset = MyDataset(num_samples, num_nodes, num_node_features)
# 使用 DataLoader 加载数据
loader = DataLoader(dataset, batch_size=32, shuffle=True)
# 迭代加载数据
for batch in loader:
print("Batch node features shape:", batch.x.shape) # 期望输出形状为 (320, 8)
print("Batch edge index shape:", batch.edge_index.shape)
总结
DenseDataLoader
:处理固定大小的邻接矩阵和节点特征矩阵的数据,__getitem__
返回Data(x, adj, y)。DataLoader
:处理torch_geometric.data.Data
对象,__getitem__
返回一个Data
对象。
确保数据格式与使用的加载器相匹配,以避免属性错误和其他兼容性问题。