DataModule, Module和Trainer测试代码

train文件中的代码往往分为dataset.py, module.py, trainer.py。为了测试这三款文件中的代码,我们准备了以下TinyModule。

在x.1中是不带core.py版本,在x.2中是带core.py版本。

x.1.1 dataset.py

dataset.py主要书写Dataset派生类,测试代码如下,

python 复制代码
if __name__=="__main__":
    # test Dataset
    ds = MicroDLDataset("/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128",
                        True,
                        None,)
    dl = DataLoader(ds, batch_size=4, num_workers=2)
    print(next(iter(dl))[0].shape)

x.1.2 module.py

dataset.py主要书写网络结构,我们需要创建简易Dataset和简易Trainer来进行测试,代码如下,

python 复制代码
if __name__=="__main__":
    Net = "your network"
    from torch.utils.data import DataLoader, Dataset
    import torch
    class TinyDataset(Dataset):
        def __init__(self, X, Y):
            # 定义好 image 的路径
            self.X, self.Y = X, Y

        def __getitem__(self, index):
            return self.X[index], self.Y[index]

        def __len__(self):
            return len(self.X)
    class TinyTrainer():
        def fit():
            X_tensor = torch.ones((4,1,32, 256, 256))
            Y_tensor = torch.zeros((4,1,32, 256, 256))
            mydataset = TinyDataset(X_tensor, Y_tensor)
            train_loader = DataLoader(mydataset, batch_size=2, shuffle=True)

            net=Net()
            print(net)
            import torch.nn as nn
            loss_fn = nn.MSELoss()
            optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)

            # 3) Training loop
            for epoch in range(10):
                for i, (X, y) in enumerate(train_loader):
                    # predict = forward pass with our model
                    pred = net(X)
                    loss = loss_fn(pred, y)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    print('epoch={},i={}'.format(epoch,i))
    TinyTrainer().fit()

x.1.3 trainer.py

trainer.py主要进行网络训练,我们需要创建极简网络进行训练,代码如下,

python 复制代码
if __name__=="__main__":
    import torch.nn as nn
    class TinyNet(nn.Module):
        def __init__(self, input=28*28, output=28*28):
            super().__init__()
            # define any number of nn.Modules (or use your current ones)
            self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
            self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
        def  forward(self, x):
            y = self.encoder(x)
            z = self.decoder(y)
            return z
    Net = TinyNet

x.2.1 dataset.py

dataset.py主要书写Dataset派生类和DataModule派生类,测试代码如下,

python 复制代码
if __name__=="__main__":
    # test Dataset
    ds = MicroDLDataset("/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128",
                        True,
                        None,)
    dl = DataLoader(ds, batch_size=4, num_workers=2)
    print(next(iter(dl))[0].shape)
    
    # test DataModule
    root = "/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128"
    dm = MicroDLDM(root=root)
    print(next(iter(dm.train_dataloader()))[0].shape)
相关推荐
Yan-英杰12 分钟前
百度搜索和文心智能体接入DeepSeek满血版——AI搜索的新纪元
图像处理·人工智能·python·深度学习·deepseek
weixin_307779131 小时前
Azure上基于OpenAI GPT-4模型验证行政区域数据的设计方案
数据仓库·python·云计算·aws
玩电脑的辣条哥2 小时前
Python如何播放本地音乐并在web页面播放
开发语言·前端·python
taoqick2 小时前
对PosWiseFFN的改进: MoE、PKM、UltraMem
人工智能·pytorch·深度学习
多想和从前一样5 小时前
Django 创建表时 “__str__ ”方法的使用
后端·python·django
charles_vaez5 小时前
开源模型应用落地-LangGraph101-探索 LangGraph 短期记忆
深度学习·语言模型·自然语言处理
WHATEVER_LEO6 小时前
【每日论文】Latent Radiance Fields with 3D-aware 2D Representations
人工智能·深度学习·神经网络·机器学习·计算机视觉·自然语言处理
小喵要摸鱼6 小时前
【Pytorch 库】自定义数据集相关的类
pytorch·python
bdawn6 小时前
深度集成DeepSeek大模型:WebSocket流式聊天实现
python·websocket·openai·api·实时聊天·deepseek大模型·流式输出
Jackson@ML6 小时前
Python数据可视化简介
开发语言·python·数据可视化