DataModule, Module和Trainer测试代码

train文件中的代码往往分为dataset.py, module.py, trainer.py。为了测试这三款文件中的代码,我们准备了以下TinyModule。

在x.1中是不带core.py版本,在x.2中是带core.py版本。

x.1.1 dataset.py

dataset.py主要书写Dataset派生类,测试代码如下,

python 复制代码
if __name__=="__main__":
    # test Dataset
    ds = MicroDLDataset("/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128",
                        True,
                        None,)
    dl = DataLoader(ds, batch_size=4, num_workers=2)
    print(next(iter(dl))[0].shape)

x.1.2 module.py

dataset.py主要书写网络结构,我们需要创建简易Dataset和简易Trainer来进行测试,代码如下,

python 复制代码
if __name__=="__main__":
    Net = "your network"
    from torch.utils.data import DataLoader, Dataset
    import torch
    class TinyDataset(Dataset):
        def __init__(self, X, Y):
            # 定义好 image 的路径
            self.X, self.Y = X, Y

        def __getitem__(self, index):
            return self.X[index], self.Y[index]

        def __len__(self):
            return len(self.X)
    class TinyTrainer():
        def fit():
            X_tensor = torch.ones((4,1,32, 256, 256))
            Y_tensor = torch.zeros((4,1,32, 256, 256))
            mydataset = TinyDataset(X_tensor, Y_tensor)
            train_loader = DataLoader(mydataset, batch_size=2, shuffle=True)

            net=Net()
            print(net)
            import torch.nn as nn
            loss_fn = nn.MSELoss()
            optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)

            # 3) Training loop
            for epoch in range(10):
                for i, (X, y) in enumerate(train_loader):
                    # predict = forward pass with our model
                    pred = net(X)
                    loss = loss_fn(pred, y)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    print('epoch={},i={}'.format(epoch,i))
    TinyTrainer().fit()

x.1.3 trainer.py

trainer.py主要进行网络训练,我们需要创建极简网络进行训练,代码如下,

python 复制代码
if __name__=="__main__":
    import torch.nn as nn
    class TinyNet(nn.Module):
        def __init__(self, input=28*28, output=28*28):
            super().__init__()
            # define any number of nn.Modules (or use your current ones)
            self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
            self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
        def  forward(self, x):
            y = self.encoder(x)
            z = self.decoder(y)
            return z
    Net = TinyNet

x.2.1 dataset.py

dataset.py主要书写Dataset派生类和DataModule派生类,测试代码如下,

python 复制代码
if __name__=="__main__":
    # test Dataset
    ds = MicroDLDataset("/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128",
                        True,
                        None,)
    dl = DataLoader(ds, batch_size=4, num_workers=2)
    print(next(iter(dl))[0].shape)
    
    # test DataModule
    root = "/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128"
    dm = MicroDLDM(root=root)
    print(next(iter(dm.train_dataloader()))[0].shape)
相关推荐
蔗理苦16 分钟前
2025-08-22 Python进阶10——魔术方法
开发语言·python
、水水水水水18 分钟前
RAG学习(五)——查询构建、Text2SQL、查询重构与分发
人工智能·python·深度学习·nlp
jieyu111942 分钟前
Python 实战:内网渗透中的信息收集自动化脚本(2)
python·网络安全·脚本开发
瑶光守护者42 分钟前
【卫星通信】超低码率语音编码ULBC:EnCodec神经音频编解码器架构深度解析
深度学习·音视频·卫星通信·语音编解码·ulbc
码界筑梦坊4 小时前
171-基于Flask的笔记本电脑数据可视化分析系统
python·信息可视化·flask·毕业设计·echarts
Uzuki7 小时前
LLM 指标 | PPL vs. BLEU vs. ROUGE-L vs. METEOR vs. CIDEr
深度学习·机器学习·llm·vlm
hui函数8 小时前
Flask电影投票系统全解析
后端·python·flask
闲人编程9 小时前
Python第三方库IPFS-API使用详解:构建去中心化应用的完整指南
开发语言·python·去中心化·内存·寻址·存储·ipfs
计算机编程小咖10 小时前
《基于大数据的农产品交易数据分析与可视化系统》选题不当,毕业答辩可能直接挂科
java·大数据·hadoop·python·数据挖掘·数据分析·spark
zhangfeng113311 小时前
以下是基于图论的归一化切割(Normalized Cut)图像分割工具的完整实现,结合Tkinter界面设计及Python代码示
开发语言·python·图论