import torch.nn as nn
from torch_geometric.nn import GATConv
import torch.nn.functional as F
class gat_cls(nn.Module):
    """Two-layer graph-attention encoder followed by a linear classification head.

    Pipeline: GATConv -> dropout -> ReLU -> GATConv -> ReLU -> Linear.

    Args:
        in_dim: Size of each input node feature vector.
        hid_dim: Output size of both GATConv layers and input size of the head.
        out_dim: Number of output scores produced per node by the head.
        dropout_size: Dropout probability applied after the first GATConv,
            active only while the module is in training mode.
    """

    def __init__(self, in_dim, hid_dim, out_dim, dropout_size=0.5):
        super().__init__()
        # Registration order is kept fixed on purpose: it determines the order
        # of parameter initialisation (RNG consumption), of parameters(), and
        # of state_dict keys.
        self.conv1 = GATConv(in_dim, hid_dim)   # GATConv library defaults (e.g. heads)
        self.conv2 = GATConv(hid_dim, hid_dim)
        self.fc = nn.Linear(in_features=hid_dim, out_features=out_dim)
        self.relu = nn.ReLU()
        self.dropout_size = dropout_size

    def forward(self, x, edge_index):
        """Return per-node output scores.

        Args:
            x: Node feature tensor; presumably shaped (num_nodes, in_dim) —
                TODO confirm against callers.
            edge_index: Graph connectivity, passed unchanged to both GATConv
                layers (PyG convention is a (2, num_edges) index tensor).

        Returns:
            The raw output of the linear head. No softmax / log-softmax is
            applied here.
        """
        # First attention layer, then dropout (training mode only), then ReLU.
        hidden = self.conv1(x, edge_index)
        hidden = self.relu(
            F.dropout(hidden, p=self.dropout_size, training=self.training)
        )
        # Second attention layer with its own non-linearity; no dropout here.
        hidden = self.relu(self.conv2(hidden, edge_index))
        return self.fc(hidden)