DataModule, Module和Trainer测试代码

train文件中的代码往往分为dataset.py, module.py, trainer.py。为了测试这三款文件中的代码,我们准备了以下TinyModule。

在x.1中是不带core.py版本,在x.2中是带core.py版本。

x.1.1 dataset.py

dataset.py主要书写Dataset派生类,测试代码如下,

python 复制代码
if __name__=="__main__":
    # test Dataset
    ds = MicroDLDataset("/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128",
                        True,
                        None,)
    dl = DataLoader(ds, batch_size=4, num_workers=2)
    print(next(iter(dl))[0].shape)

x.1.2 module.py

dataset.py主要书写网络结构,我们需要创建简易Dataset和简易Trainer来进行测试,代码如下,

python 复制代码
if __name__=="__main__":
    Net = "your network"
    from torch.utils.data import DataLoader, Dataset
    import torch
    class TinyDataset(Dataset):
        def __init__(self, X, Y):
            # 定义好 image 的路径
            self.X, self.Y = X, Y

        def __getitem__(self, index):
            return self.X[index], self.Y[index]

        def __len__(self):
            return len(self.X)
    class TinyTrainer():
        def fit():
            X_tensor = torch.ones((4,1,32, 256, 256))
            Y_tensor = torch.zeros((4,1,32, 256, 256))
            mydataset = TinyDataset(X_tensor, Y_tensor)
            train_loader = DataLoader(mydataset, batch_size=2, shuffle=True)

            net=Net()
            print(net)
            import torch.nn as nn
            loss_fn = nn.MSELoss()
            optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)

            # 3) Training loop
            for epoch in range(10):
                for i, (X, y) in enumerate(train_loader):
                    # predict = forward pass with our model
                    pred = net(X)
                    loss = loss_fn(pred, y)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    print('epoch={},i={}'.format(epoch,i))
    TinyTrainer().fit()

x.1.3 trainer.py

trainer.py主要进行网络训练,我们需要创建极简网络进行训练,代码如下,

python 复制代码
if __name__=="__main__":
    import torch.nn as nn
    class TinyNet(nn.Module):
        def __init__(self, input=28*28, output=28*28):
            super().__init__()
            # define any number of nn.Modules (or use your current ones)
            self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
            self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
        def  forward(self, x):
            y = self.encoder(x)
            z = self.decoder(y)
            return z
    Net = TinyNet

x.2.1 dataset.py

dataset.py主要书写Dataset派生类和DataModule派生类,测试代码如下,

python 复制代码
if __name__=="__main__":
    # test Dataset
    ds = MicroDLDataset("/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128",
                        True,
                        None,)
    dl = DataLoader(ds, batch_size=4, num_workers=2)
    print(next(iter(dl))[0].shape)
    
    # test DataModule
    root = "/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128"
    dm = MicroDLDM(root=root)
    print(next(iter(dm.train_dataloader()))[0].shape)
相关推荐
小兵张健3 分钟前
Java + Spring 到 Python + FastAPI (一)
java·python·spring
CoovallyAIHub32 分钟前
让Qwen-VL的检测能力像YOLO一样强,VLM-FO1如何打通大模型的视觉任督二脉
深度学习·算法·计算机视觉
盼小辉丶38 分钟前
TensorFlow深度学习实战(43)——TensorFlow.js
javascript·深度学习·tensorflow
2401_8414956443 分钟前
【自然语言处理】基于统计基的句子边界检测算法
人工智能·python·算法·机器学习·自然语言处理·统计学习·句子边界检测算法
程序员爱钓鱼44 分钟前
Python编程实战 - Python实用工具与库 - 操作Word:python-docx
后端·python
程序员爱钓鱼1 小时前
Python编程实战 - Python实用工具与库 - 操作PDF:pdfplumber、PyPDF2
后端·python
啾啾啾6661 小时前
连接一个新的服务器时,打开PyCharm时报错:报错内容是服务器磁盘或配额满了
python·pycharm
长不大的蜡笔小新1 小时前
掌握NumPy:ndarray核心特性与创建
开发语言·python·numpy
CoovallyAIHub1 小时前
突破跨模态识别瓶颈!火箭军工程大学提出MFENet:让AI在白天黑夜都能准确识人
深度学习·算法·计算机视觉
luoganttcc1 小时前
已知 空间 三个 A,B C 点 ,求 顺序 经过 A B C 三点 圆弧 轨迹 ,给出 python 代码 并且 画出图像
c语言·开发语言·python