【chatgpt】两层gcn提取最后一层节点输出特征,如何自定义简单数据集

文章目录

两层gcn,提取最后一层节点输出特征,10个节点,每个节点8个特征,连接关系随机生成(无全连接层)

如果没有全连接层(fc层),那么输出将是每个节点的特征。这样的话,标签需要对应每个节点的输出特征。在这种情况下,标签的维度应该是节点数乘以输出特征数。

如果我们将输出特征数设为1,并且没有全连接层,那么每个节点的输出将是一个标量。标签的维度将与节点数相同。例如,假设有10个节点,每个节点的输出是一个标量,那么标签的维度应该是 (10,)

以下是一个没有全连接层的双层GCN实现,其中标签数和节点数相同:

python 复制代码
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        print(x.shape)
        x = self.conv1(x, edge_index)
        print(x.shape)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        print(x.shape)
        return x

# 定义图数据
num_nodes = 10
num_node_features = 8
edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
                           [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]], dtype=torch.long)
x = torch.randn((num_nodes, num_node_features))

data = Data(x=x, edge_index=edge_index)

# 定义并训练模型
model = GCN(in_channels=num_node_features, hidden_channels=16, out_channels=1)
print(model)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

# 假设我们有一些训练标签,每个节点一个标签
labels = torch.randn((num_nodes, 1))

# 简单的训练循环
model.train()
for epoch in range(1):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = criterion(out, labels)
    print(f"loss {loss}")
    loss.backward()
    optimizer.step()

# 提取最后一层节点输出特征
model.eval()
with torch.no_grad():
    out = model(data.x, data.edge_index)
    print(out)

在这个示例中:

  1. 模型输出的形状是 (num_nodes, 1),即每个节点有一个输出特征。
  2. 标签的形状也是 (num_nodes, 1),即每个节点有一个对应的标签。
  3. 损失函数使用均方误差损失(MSELoss),并计算每个节点输出与对应标签之间的误差。

这样,我们确保输出与标签的维度一致,从而正确计算损失并进行训练。

如何计算MSE

在没有全连接层的情况下,每个节点都有一个输出特征,标签也是每个节点一个标签。损失函数(如MSELoss)会计算每个节点输出特征与对应标签之间的误差,然后将这些误差进行平均或求和,得到总体损失。

以下是一个更详细的说明:

  • 模型输出 :形状为(num_nodes, 1),表示每个节点的输出特征。
  • 标签 :形状为(num_nodes, 1),表示每个节点的对应标签。

在训练过程中,损失函数会计算每个节点的输出与对应标签之间的误差,并将所有节点的误差进行汇总(通常是求平均),得到一个标量损失值。优化器会根据这个损失值更新模型的参数。

我们使用均方误差损失(MSELoss),其计算公式为:
MSE = 1 N ∑ i = 1 N ( y i − y ^ i ) 2 \text{MSE} = \frac{1}{N} \sum_{i=1}^N (y_i - \hat{y}_i)^2 MSE=N1∑i=1N(yi−y^i)2

其中 N N N是节点数, y i y_i yi 是第 i i i 个节点的标签, y ^ i \hat{y}_i y^i是第 i i i 个节点的预测值。

在这个示例中:

  1. 模型输出 :每个节点的输出特征(形状为 (num_nodes, 1))。
  2. 标签 :每个节点的对应标签(形状为 (num_nodes, 1))。
  3. 损失函数:均方误差损失(MSELoss),计算每个节点输出与对应标签之间的误差,并将这些误差进行平均,得到总体损失。

这个过程确保了每个节点的输出特征与其对应的标签进行比较,从而计算损失并优化模型。

100个样本,并且使用批量大小为32进行训练

第一个版本定义数据集出错,添加super()并修改为__len__和__getitem__

python 复制代码
# 定义数据集类
class MyDataset(Dataset):
    def __init__(self, num_samples, num_nodes, num_node_features):
        self.num_samples = num_samples
        self.num_nodes = num_nodes
        self.num_node_features = num_node_features

    def len(self):
        return self.num_samples

    def get(self, idx):
        edge_index = torch.tensor([[i for i in range(self.num_nodes)],
                                   [(i + 1) % self.num_nodes for i in range(self.num_nodes)]], dtype=torch.long)
        x = torch.randn((self.num_nodes, self.num_node_features))
        y = torch.randn((self.num_nodes, 1))  # 每个节点一个标签
        return Data(x=x, edge_index=edge_index, y=y)

当你遇到 AttributeError: can't set attribute 错误时,通常是因为你试图在一个类的实例上设置一个属性,但该类不允许直接设置属性。在Python中,某些类,特别是那些继承自某些基类(比如 torch.utils.data.Dataset)的类,限制了直接属性赋值的操作,这是为了安全或一致性考虑。

要解决这个问题,你应该在你的 MyDataset 类的 __init__ 方法中正确地初始化属性。以下是如何修复这个问题的方法:

python 复制代码
from torch.utils.data import Dataset
from torch_geometric.data import Data

# 定义数据集类
class MyDataset(Dataset):
    def __init__(self, num_samples, num_nodes, num_node_features):
        super(MyDataset, self).__init__()  # Initialize base class if inheriting
        self.num_samples = num_samples
        self.num_nodes = num_nodes
        self.num_node_features = num_node_features
        # 创建固定的边索引,这里简单使用环形图
        self.edge_index = torch.tensor([[i for i in range(self.num_nodes)],
                                        [(i + 1) % self.num_nodes for i in range(self.num_nodes)]], dtype=torch.long)


    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # 创建随机特征和标签,这里仅作示例
        x = torch.randn((self.num_nodes, self.num_node_features))
        y = torch.randn((self.num_nodes, 1))  # 每个节点一个标签
        
        # 返回一个包含图数据的 Data 对象,保持相同的边索引
        return Data(x=x, edge_index=self.edge_index, y=y)

# 创建数据集和数据加载器
num_samples = 100
num_nodes = 10
num_node_features = 8
batch_size = 32

dataset = MyDataset(num_samples, num_nodes, num_node_features)
print(dataset[0].edge_index)

idx的作用

在PyTorch中,特别是在使用 torch.utils.data.Datasettorch.utils.data.DataLoader 构建数据加载和处理管道时,idx(或者通常命名为 index)代表着数据集中样本的索引。具体来说:

  1. __getitem__ 方法中的 idx

    • 在自定义的数据集类中,通常会实现 __getitem__ 方法。这个方法接收一个参数 idx,它表示你要获取的样本在数据集中的索引。
    • 例如,在一个图像分类任务中,idx 就是每张图像在数据集中的位置。通过这个索引,你可以从数据集中加载并返回对应位置的样本数据。
  2. 作用

    • idx 的作用是定位和访问数据集中特定样本的数据。在训练过程中,DataLoader 会使用 __getitem__ 方法迭代数据集,根据给定的 idx 获取每个样本,然后将它们组织成批量供模型训练。
    • 在使用 DataLoader 加载数据时,idx 通常会被 DataLoader 内部迭代器管理,你无需手动传递它,只需实现好 __getitem__ 方法即可。
  3. 示例

    • 假设你有一个自定义的数据集类 MyDataset,实现了 __getitem__ 方法来根据 idx 加载图像数据。当你使用 DataLoader 加载这个数据集时,DataLoader 会自动处理索引的管理和批量数据的组织,你只需要关注数据集类的实现和模型的训练过程。

总结来说,idx 是用来在数据集中定位和访问特定样本的索引参数,它在自定义数据集类中的作用是非常重要的,能够帮助你有效地管理和处理数据集中的样本数据。

使用super()方法

如果在使用 super() 调用时出现错误,通常是因为类的初始化方法(__init__)中没有正确地调用父类的初始化方法。这可能会导致 Python 报告类的属性无法设置的错误。让我们来看看如何正确使用 super() 并初始化属性。

在你的 MyDataset 类中,确保按照以下方式使用 super() 和正确初始化属性:

  • super() 函数 :在 Python 中,super() 函数用于调用父类的方法。在 MyDataset 类的 __init__ 方法中,super(MyDataset, self).__init__() 调用了 Dataset 类的初始化方法,确保正确初始化了 Dataset 类中的属性和方法。
  • 属性初始化 :在 MyDataset__init__ 方法中,通过 self.num_samplesself.num_nodesself.num_node_features 初始化了数据集的属性。这些属性用于定义数据集的特征和样本数量。
  • 数据加载__getitem__ 方法用于按照给定的 idx 加载数据集中的样本,并返回一个包含图数据的 Data 对象。

通过这样的方式,你可以确保 MyDataset 类正确地继承了 Dataset 类,并正确初始化了属性,避免了 AttributeError 错误的发生。

再次解释一遍定义数据集的代码

在PyTorch中,torch.utils.data.Dataset 是一个抽象基类,要求自定义的数据集类必须实现 __len____getitem__ 方法。这些方法分别用于确定数据集的长度和获取数据集中的一个样本。

  • __init__ 方法 :在 __init__ 方法中,创建了一个固定的边索引 self.edge_index,这里使用简单的环形图示例。这个边索引在数据集初始化时被创建,并在每次调用 __getitem__ 方法时被重复使用,从而确保每个样本的图数据保持相同的连接关系。
  • __len__ 方法 :这个方法返回数据集的长度,即数据集中样本的数量。在这里,它返回了 num_samples,表示数据集中有多少个样本。
  • __getitem__ 方法 :这个方法根据给定的索引 idx 返回数据集中的一个样本。在这里,它返回一个包含随机节点特征、固定边索引和随机节点标签的 Data 对象,确保了图连接关系的不变性。

你可以正确地实现并使用 MyDataset 类来创建多个数据集样本,并确保每个样本的图连接关系保持不变。

另外一个值得注意的错误:定义数据集部分修改之后还是报obj = super().new(cls)TypeError: Can't instantiate abstract class MyDataset with abstract methods get, len错误

from torch.utils.data import Dataset 与 from torch_geometric.data import Data 和 Dataset是不一样的

torch.utils.data.Datasettorch_geometric.data.Dataset 是两个不同的类,分别来自于不同的模块,功能和用途也略有不同。

  1. torch.utils.data.Dataset

    • 这是 PyTorch 提供的一个抽象基类,用于创建自定义数据集。它要求用户继承并实现 __len____getitem__ 方法,以便能够使用 torch.utils.data.DataLoader 进行数据加载和批处理。
    • 主要用途是在通用的机器学习任务中加载和处理数据集,例如图像分类、文本处理等。
  2. torch_geometric.data.Dataset

    • 这是 PyTorch Geometric 提供的一个特定数据集类,用于处理图数据。它继承自 torch.utils.data.Dataset,并额外提供了一些方法和功能,使得可以更方便地处理图数据集。
    • 主要用途是在图神经网络中加载和处理图数据,包括节点特征、边索引等。
  • 功能特点

    • torch.utils.data.Dataset 适用于通用的数据加载和处理,可以处理各种类型的数据集。
    • torch_geometric.data.Dataset 专门用于处理图数据,提供了额外的功能来处理节点和边的特征。
  • 使用场景

    • 如果你处理的是普通的数据集(如图像、文本等),可以使用 torch.utils.data.Dataset 来创建自定义的数据加载器。
    • 如果你处理的是图数据(如节点和边具有特定的连接关系和属性),建议使用 torch_geometric.data.Dataset 来利用其专门针对图数据设计的功能。

如果你想要处理图数据,可以使用 torch_geometric.data.Dataset 的子类,例如 torch_geometric.datasets.Planetoid,用来加载图数据集,例如 Planetoid 数据集:

python 复制代码
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T

dataset = Planetoid(root='/your/data/path', name='Cora', transform=T.NormalizeFeatures())

这里使用了 Planetoid 数据集类,它继承自 torch_geometric.data.Dataset,专门用于加载和处理图数据集,例如 Cora 数据集。

from torch_geometric.data import Data的作用

在 PyTorch Geometric 中,torch_geometric.data.Data 是一个用于表示图数据的核心数据结构之一。它主要用来存储图中的节点特征、边索引以及可选的图级别特征,具有以下作用:

  1. 存储节点特征和边索引

    • Data 对象可以存储节点特征矩阵(通常是一个二维张量)和边索引(通常是一个二维长整型张量)。节点特征矩阵的每一行表示一个节点的特征向量,边索引描述了节点之间的连接关系。
  2. 支持图级别的特征

    • 除了节点特征和边索引外,Data 对象还可以存储图级别的特征,例如全局图特征(如图的标签或属性)。
  3. 作为输入输出的载体

    • 在图神经网络中,Data 对象通常作为输入数据的载体。例如,在进行图分类、节点分类或图生成任务时,模型的输入通常是 Data 对象。
  4. 与其他 PyTorch Geometric 函数和类的兼容性

    • Data 对象与 PyTorch Geometric 中的其他函数和类高度兼容,例如数据转换、数据集加载等。它们共同支持创建、处理和转换图数据。
  5. 用于数据集的表示

    • 在自定义的数据集中,你可以使用 Data 对象来表示每个样本的图数据。通过组织和存储节点特征、边索引和图级别特征,可以更方便地加载和处理复杂的图结构数据集。

下面是一个简单的示例,展示如何使用 Data 对象创建和操作图数据:

python 复制代码
from torch_geometric.data import Data
import torch

# 创建节点特征和边索引
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float)  # 3个节点,每个节点2个特征
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)  # 边索引表示节点之间的连接关系

# 创建一个 Data 对象
data = Data(x=x, edge_index=edge_index)

# 访问和操作 Data 对象中的属性
print(data)
print("Number of nodes:", data.num_nodes)
print("Number of edges:", data.num_edges)
print("Node features shape:", data.x.shape)
print("Edge index shape:", data.edge_index.shape)

在这个示例中,我们首先创建了节点特征矩阵 x 和边索引 edge_index,然后使用它们来实例化一个 Data 对象 data。通过访问 data 对象的属性,可以获取节点数、边数以及节点特征和边索引的形状信息。

总之,torch_geometric.data.Data 在 PyTorch Geometric 中扮演着关键的角色,用于表示和处理图数据,是构建图神经网络模型的重要基础之一。

验证 MyDataset 类生成的样本和批次数据的形状

为了实现一个自定义数据集 MyDataset,可以创建一个包含 100 个样本的数据集,每个样本包含一个形状为 (32, 8) 的节点特征矩阵。需要注意的是,MyDataset 类中的 __getitem__ 方法应该返回每个样本的数据,包括节点特征、边索引等。

我们可以通过打印 MyDataset 中每个样本的数据形状来验证数据的形状。以下是实现和验证的示例代码:

python 复制代码
dataset = MyDataset(num_samples, num_nodes, num_node_features)
print(len(dataset))

# 查看数据集中的前几个样本的形状
for i in range(3):
    data = dataset[i]
    print(f"Sample {i} - Node features shape: {data.x.shape}, Edge index shape: {data.edge_index.shape}, Labels shape: {data.y.shape}")

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 从 DataLoader 中获取一个批次的数据
for batch in dataloader:
    print("Batch node features shape:", batch.x.shape)
    print("Batch edge index shape:", batch.edge_index.shape)
    print("Batch labels shape:", batch.y.shape)
    break  # 仅查看第一个批次的形状

数据格式是展平的 (batch_size * num_nodes, num_features)

使用DenseDataLoader数据格式为(batch_size , num_nodes, num_features)

DenseDataLoader 和 DataLoader 在处理数据的方式上有所不同。DenseDataLoader 是专门用于处理稠密图数据的,而 DataLoader 通常用于处理稀疏图数据。在你的案例中,如果所有图的节点数和边数是固定的,可以使用 DenseDataLoader 进行更高效的批处理。

完整代码

python 复制代码
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

from torch.utils.data import Dataset
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader, DenseDataLoader

# 定义数据集类
class MyDataset(Dataset):
    def __init__(self, num_samples, num_nodes, num_node_features):
        super(MyDataset, self).__init__()  # Initialize base class if inheriting
        self.num_samples = num_samples
        self.num_nodes = num_nodes
        self.num_node_features = num_node_features
        # 创建固定的边索引,这里简单使用环形图
        self.edge_index = torch.tensor([[i for i in range(self.num_nodes)],
                                        [(i + 1) % self.num_nodes for i in range(self.num_nodes)]], dtype=torch.long)


    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # 创建随机特征和标签,这里仅作示例
        x = torch.randn((self.num_nodes, self.num_node_features))
        y = torch.randn((self.num_nodes, 1))  # 每个节点一个标签
        
        # 返回一个包含图数据的 Data 对象,保持相同的边索引
        return Data(x=x, edge_index=self.edge_index, y=y)

# 创建数据集和数据加载器
num_samples = 100
num_nodes = 10
num_node_features = 8
batch_size = 32

dataset = MyDataset(num_samples, num_nodes, num_node_features)
# data_list = [dataset[i] for i in range(num_samples)]
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 从 DataLoader 中获取一个批次的数据
for batch in dataloader:
    print("Batch node features shape:", batch.x.shape)
    print("Batch edge index shape:", batch.edge_index.shape)
    print("Batch labels shape:", batch.y.shape)
    break  # 仅查看第一个批次的形状

# 定义GCN模型
class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def forward(self, x, edge_index):
        print(f"first {x.shape}")
        x = self.conv1(x, edge_index)
        print(f"conv1 {x.shape}")
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        print(f"conv2 {x.shape}")
        return x

model = GCN(in_channels=num_node_features, hidden_channels=16, out_channels=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.MSELoss()

# 训练模型
model.train()
for epoch in range(2):
    for data in dataloader:
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out, data.y)
        loss.backward()
        optimizer.step()

# 评估模型
model.eval()
with torch.no_grad():
    for batch in dataloader:
        out = model(batch.x, batch.edge_index)
        print(out.shape) # 待实现把batch中所有features都拼接起来

写得比较乱,最后densedataloader和dataloader都可以专门来写一篇了

相关推荐
ToToBe9 小时前
L1G3000 提示工程(Prompt Engineering)
chatgpt·prompt
龙的爹23339 小时前
论文 | Legal Prompt Engineering for Multilingual Legal Judgement Prediction
人工智能·语言模型·自然语言处理·chatgpt·prompt
bytebeats11 小时前
我用 Spring AI 集成 OpenAI ChatGPT API 创建了一个 Spring Boot 小程序
spring boot·chatgpt·openai
&永恒的星河&15 小时前
Hunyuan-Large:推动AI技术进步的下一代语言模型
人工智能·语言模型·自然语言处理·chatgpt·moe·llms
我爱学Python!21 小时前
AI Prompt如何帮你提升论文中的逻辑推理部分?
人工智能·程序人生·自然语言处理·chatgpt·llm·prompt·提示词
Jet45051 天前
第100+31步 ChatGPT学习:概率校准 Quantile Calibration
学习·chatgpt·概率校准
开发者每周简报2 天前
ChatGPT o1与GPT-4o、Claude 3.5 Sonnet和Gemini 1.5 Pro的比较
人工智能·gpt·chatgpt
Topstip2 天前
在 Google Chrome 上查找并安装 SearchGPT 扩展
前端·人工智能·chrome·gpt·ai·chatgpt
科研小达人2 天前
Langchain调用模型使用FAISS
python·chatgpt·langchain·faiss
全域观察2 天前
两台手机如何提词呢,一台手机后置高清摄像一台手机前置提词+实时监测状态的解决方案来喽
大数据·人工智能·chatgpt·新媒体运营·程序员创富