DataModule, Module和Trainer测试代码

train文件中的代码往往分为dataset.py, module.py, trainer.py。为了测试这三款文件中的代码,我们准备了以下TinyModule。

在x.1中是不带core.py版本,在x.2中是带core.py版本。

x.1.1 dataset.py

dataset.py主要书写Dataset派生类,测试代码如下,

python 复制代码
if __name__=="__main__":
    # test Dataset
    ds = MicroDLDataset("/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128",
                        True,
                        None,)
    dl = DataLoader(ds, batch_size=4, num_workers=2)
    print(next(iter(dl))[0].shape)

x.1.2 module.py

dataset.py主要书写网络结构,我们需要创建简易Dataset和简易Trainer来进行测试,代码如下,

python 复制代码
if __name__=="__main__":
    Net = "your network"
    from torch.utils.data import DataLoader, Dataset
    import torch
    class TinyDataset(Dataset):
        def __init__(self, X, Y):
            # 定义好 image 的路径
            self.X, self.Y = X, Y

        def __getitem__(self, index):
            return self.X[index], self.Y[index]

        def __len__(self):
            return len(self.X)
    class TinyTrainer():
        def fit():
            X_tensor = torch.ones((4,1,32, 256, 256))
            Y_tensor = torch.zeros((4,1,32, 256, 256))
            mydataset = TinyDataset(X_tensor, Y_tensor)
            train_loader = DataLoader(mydataset, batch_size=2, shuffle=True)

            net=Net()
            print(net)
            import torch.nn as nn
            loss_fn = nn.MSELoss()
            optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)

            # 3) Training loop
            for epoch in range(10):
                for i, (X, y) in enumerate(train_loader):
                    # predict = forward pass with our model
                    pred = net(X)
                    loss = loss_fn(pred, y)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    print('epoch={},i={}'.format(epoch,i))
    TinyTrainer().fit()

x.1.3 trainer.py

trainer.py主要进行网络训练,我们需要创建极简网络进行训练,代码如下,

python 复制代码
if __name__=="__main__":
    import torch.nn as nn
    class TinyNet(nn.Module):
        def __init__(self, input=28*28, output=28*28):
            super().__init__()
            # define any number of nn.Modules (or use your current ones)
            self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
            self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
        def  forward(self, x):
            y = self.encoder(x)
            z = self.decoder(y)
            return z
    Net = TinyNet

x.2.1 dataset.py

dataset.py主要书写Dataset派生类和DataModule派生类,测试代码如下,

python 复制代码
if __name__=="__main__":
    # test Dataset
    ds = MicroDLDataset("/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128",
                        True,
                        None,)
    dl = DataLoader(ds, batch_size=4, num_workers=2)
    print(next(iter(dl))[0].shape)
    
    # test DataModule
    root = "/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128"
    dm = MicroDLDM(root=root)
    print(next(iter(dm.train_dataloader()))[0].shape)
相关推荐
这里是小悦同学呀!几秒前
python学习day2
java·python·学习
FL16238631292 小时前
荔枝成熟度分割数据集labelme格式2263张3类别
人工智能·深度学习
我是你们的星光3 小时前
基于深度学习的高效图像失真校正框架总结
人工智能·深度学习·计算机视觉·3d
未来可期叶4 小时前
如何用Python批量解压ZIP文件?快速解决方案
python
张槊哲4 小时前
ROS2架构介绍
python·架构
风逸hhh5 小时前
python打卡day29@浙大疏锦行
开发语言·前端·python
浩皓素5 小时前
深入理解For循环及相关关键字原理:以Python和C语言为例
c语言·python
英英_5 小时前
详细介绍一下Python连接MySQL数据库的完整步骤
数据库·python·mysql
水花花花花花5 小时前
GloVe 模型讲解与实战
python·深度学习·conda·pip
C_VuI5 小时前
如何安装cuda版本的pytorch
人工智能·pytorch·python