Graph U-Net Code【图分类】

1. main.py

python 复制代码
# GNet是需要用到的model
net = GNet(G_data.feat_dim, G_data.num_class, args) # graph, 特征维度,类别数,参数
trainer = Trainer(args, net, G_data) #开始训练数据
# 正式开始训练数据
trainer.train()

2. network.py

python 复制代码
class GNet(nn.Module):
    def __init__(self, in_dim, n_classes, args):
        super(GNet, self).__init__()
        self.n_act = getattr(nn, args.act_n)()# getattr() 是 Python 内置的一个函数,可以用来获取一个对象的属性值或方法
        self.c_act = getattr(nn, args.act_c)()# print('GNet1: in_dim=', in_dim, 'n_class=',n_classes)  # GNet1: in_dim= 82 n_class= 2

        "用的是GCN的框架,输入分别是feat dim、layer dim、network act、drop net(net表示GCN网络本身的参数)"
        self.s_gcn = GCN(in_dim, args.l_dim, self.n_act, args.drop_n)
        self.g_unet = GraphUnet(args.ks, args.l_dim, args.l_dim, args.l_dim, self.n_act, args.drop_n)

        """nn.Linear定义一个神经网络的线性层,方法如下:
           torch.nn.Linear(in_features, # 输入的神经元个数
           out_features, # 输出神经元个数
           bias=True # 是否包含偏置)"""
        self.out_l_1 = nn.Linear(3*args.l_dim*(args.l_num+1), args.h_dim)
        self.out_l_2 = nn.Linear(args.h_dim, n_classes)

        "nn.Dropout(p = 0.3) # 表示每个神经元有0.3的可能性不被激活"
        self.out_drop = nn.Dropout(p=args.drop_c)
        Initializer.weights_init(self)

    def forward(self, gs, hs, labels):
        print('GNet2: gs=',type(gs), len(gs), 'hs=',type(hs), len(hs), 'labels:',type(labels),labels.shape)
        # GNet2: gs= <class 'list'> 32 hs= <class 'list'> 32 labels: <class 'torch.Tensor'> torch.Size([32])
        hs = self.embed(gs, hs)
        print('GNet2: hs=', type(hs), hs.shape)
        logits = self.classify(hs)
        return self.metric(logits, labels)

3. trainer.py

python 复制代码
class Trainer:
    "init初始化,输入分别是arg参数、gcn net、graph Data,将这些装进self里面"
    def __init__(self, args, net, G_data):
        self.args = args
        self.net = net
        self.feat_dim = G_data.feat_dim
        self.fold_idx = G_data.fold_idx
        self.init(args, G_data.train_gs, G_data.test_gs)
        # 若是有显卡,则用显卡跑
        if torch.cuda.is_available():
            self.net.cuda()

    "初始化------开始训练数据"
    def init(self, args, train_gs, test_gs):
        print('#train: %d, #test: %d' % (len(train_gs), len(test_gs)))
        # 分成训练集和测试集,记载数据
        train_data = GraphData(train_gs, self.feat_dim)
        test_data = GraphData(test_gs, self.feat_dim)

        # DataLoader 为pytorch 内部类,此时只需要指定trainset, batch_size, shuffle, num_workers, ...等
        self.train_d = train_data.loader(self.args.batch, True)
        self.test_d = test_data.loader(self.args.batch, False)
        self.optimizer = optim.Adam(
            self.net.parameters(), lr=self.args.lr, amsgrad=True,
            weight_decay=0.0008)
python 复制代码
    def train(self):
        max_acc = 0.0
        train_str = 'Train epoch %d: loss %.5f acc %.5f'
        test_str = 'Test epoch %d: loss %.5f acc %.5f max %.5f'
        line_str = '%d:\t%.5f\n'
        for e_id in range(self.args.num_epochs):
            self.net.train()

            # 从每个epoch开始训练
            loss, acc = self.run_epoch(e_id, self.train_d, self.net, self.optimizer)
            print(train_str % (e_id, loss, acc))

            with torch.no_grad():
                self.net.eval()
                loss, acc = self.run_epoch(e_id, self.test_d, self.net, None)
            max_acc = max(max_acc, acc)
            print(test_str % (e_id, loss, acc, max_acc))

        with open(self.args.acc_file, 'a+') as f:
            f.write(line_str % (self.fold_idx, max_acc))
python 复制代码
    def run_epoch(self, epoch, data, model, optimizer):

        #self.run_epoch(e_id, self.train_d, self.net, self.optimizer)
        losses, accs, n_samples = [], [], 0
        for batch in tqdm(data, desc=str(epoch), unit='b'):
            cur_len, gs, hs, ys = batch
            gs, hs, ys = map(self.to_cuda, [gs, hs, ys])
            loss, acc = model(gs, hs, ys)

            losses.append(loss*cur_len)
            accs.append(acc*cur_len)
            n_samples += cur_len
            if optimizer is not None:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        avg_loss, avg_acc = sum(losses) / n_samples, sum(accs) / n_samples
        return avg_loss.item(), avg_acc.item()

不懂

python 复制代码
class GraphConvolution(Module):
    """
    Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        """为啥要这么做???5555555555555555555555555555"""
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
 
        support = torch.mm(input, self.weight)
        output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output
相关推荐
兵慌码乱21 小时前
面向桌面端的资产管理系统分层架构设计与核心模块实现
python·系统架构·sqlite·pyqt5·数据库设计·桌面应用开发·mvc架构
hboot1 天前
AI工程师第三课 - 机器学习基础
python·scikit-learn·kaggle
顾林海1 天前
Agent入门阶段-编程基础-Python:流程控制
python·agent·ai编程
呱呱复呱呱1 天前
Django CBV 源码解读:一个请求是怎么找到你的 get() 方法的
python·django
曲幽1 天前
刚部署的 LibreTranslate 频频翻车?我掏出了 20 年前的 StarDict 词典,用 FastAPI 搭了个本地词典翻译 API
python·fastapi·web·translate·goldendict·libretranslate·stardict·pystardict
荣码1 天前
用Streamlit给AI应用套个界面,10行代码出Web页面
java·python
武子康1 天前
调查研究-189 Kronos 调研:金融 K 线基础模型,是真突破,还是量化圈的新玩具?
人工智能·深度学习·openai
兵慌码乱2 天前
基于Python+PyQt5+SQLite的药房管理系统实现:事务一致性与界面解耦全流程解析
python·sqlite·信号与槽·pyqt5·数据库设计·桌面应用开发·事务处理
金銀銅鐵2 天前
[Python] 体验用欧几里得算法计算最大公约数的过程
python·数学
FreakStudio2 天前
W55MH32L-EVB 上手测评:硬件 TCP/IP 加持的以太网单片机,MicroPython 零门槛开发
python·单片机·嵌入式·大学生·面向对象·并行计算·电子diy·电子计算机