DataModule, Module和Trainer测试代码

train文件中的代码往往分为dataset.py, module.py, trainer.py。为了测试这三款文件中的代码,我们准备了以下TinyModule。

在x.1中是不带core.py版本,在x.2中是带core.py版本。

x.1.1 dataset.py

dataset.py主要书写Dataset派生类,测试代码如下,

python 复制代码
if __name__=="__main__":
    # test Dataset
    ds = MicroDLDataset("/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128",
                        True,
                        None,)
    dl = DataLoader(ds, batch_size=4, num_workers=2)
    print(next(iter(dl))[0].shape)

x.1.2 module.py

dataset.py主要书写网络结构,我们需要创建简易Dataset和简易Trainer来进行测试,代码如下,

python 复制代码
if __name__=="__main__":
    Net = "your network"
    from torch.utils.data import DataLoader, Dataset
    import torch
    class TinyDataset(Dataset):
        def __init__(self, X, Y):
            # 定义好 image 的路径
            self.X, self.Y = X, Y

        def __getitem__(self, index):
            return self.X[index], self.Y[index]

        def __len__(self):
            return len(self.X)
    class TinyTrainer():
        def fit():
            X_tensor = torch.ones((4,1,32, 256, 256))
            Y_tensor = torch.zeros((4,1,32, 256, 256))
            mydataset = TinyDataset(X_tensor, Y_tensor)
            train_loader = DataLoader(mydataset, batch_size=2, shuffle=True)

            net=Net()
            print(net)
            import torch.nn as nn
            loss_fn = nn.MSELoss()
            optimizer = torch.optim.SGD(net.parameters(), lr=1e-3)

            # 3) Training loop
            for epoch in range(10):
                for i, (X, y) in enumerate(train_loader):
                    # predict = forward pass with our model
                    pred = net(X)
                    loss = loss_fn(pred, y)

                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    print('epoch={},i={}'.format(epoch,i))
    TinyTrainer().fit()

x.1.3 trainer.py

trainer.py主要进行网络训练,我们需要创建极简网络进行训练,代码如下,

python 复制代码
if __name__=="__main__":
    import torch.nn as nn
    class TinyNet(nn.Module):
        def __init__(self, input=28*28, output=28*28):
            super().__init__()
            # define any number of nn.Modules (or use your current ones)
            self.encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
            self.decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))
        def  forward(self, x):
            y = self.encoder(x)
            z = self.decoder(y)
            return z
    Net = TinyNet

x.2.1 dataset.py

dataset.py主要书写Dataset派生类和DataModule派生类,测试代码如下,

python 复制代码
if __name__=="__main__":
    # test Dataset
    ds = MicroDLDataset("/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128",
                        True,
                        None,)
    dl = DataLoader(ds, batch_size=4, num_workers=2)
    print(next(iter(dl))[0].shape)
    
    # test DataModule
    root = "/home/yingmuzhi/BioAI/data/data1_output/phase2dna_microdl_patches/tiles_256-256_step_128-128"
    dm = MicroDLDM(root=root)
    print(next(iter(dm.train_dataloader()))[0].shape)
相关推荐
九章云极AladdinEdu2 小时前
超参数自动化调优指南:Optuna vs. Ray Tune 对比评测
运维·人工智能·深度学习·ai·自动化·gpu算力
酷飞飞3 小时前
Python网络与多任务编程:TCP/UDP实战指南
网络·python·tcp/ip
研梦非凡4 小时前
ICCV 2025|从粗到细:用于高效3D高斯溅射的可学习离散小波变换
人工智能·深度学习·学习·3d
数字化顾问4 小时前
Python:OpenCV 教程——从传统视觉到深度学习:YOLOv8 与 OpenCV DNN 模块协同实现工业缺陷检测
python
学生信的大叔5 小时前
【Python自动化】Ubuntu24.04配置Selenium并测试
python·selenium·自动化
诗句藏于尽头6 小时前
Django模型与数据库表映射的两种方式
数据库·python·django
通街市密人有6 小时前
IDF: Iterative Dynamic Filtering Networks for Generalizable Image Denoising
人工智能·深度学习·计算机视觉
智数研析社7 小时前
9120 部 TMDb 高分电影数据集 | 7 列全维度指标 (评分 / 热度 / 剧情)+API 权威源 | 电影趋势分析 / 推荐系统 / NLP 建模用
大数据·人工智能·python·深度学习·数据分析·数据集·数据清洗
扯淡的闲人7 小时前
多语言编码Agent解决方案(5)-IntelliJ插件实现
开发语言·python
moxiaoran57537 小时前
Flask学习笔记(一)
后端·python·flask