用来生成二维矩阵的dcgan

有大量二维矩阵作为样本,为连续数据。数据具有空间连续性,因此用卷积网络,通过dcgan生成二维矩阵。因为是连续变量,因此损失采用nn.MSELoss()。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from DemDataset import create_netCDF_Dem_trainLoader
import torchvision
from torch.utils.tensorboard import SummaryWriter

batch_size=16
#load data
dataloader = create_netCDF_Dem_trainLoader(batch_size)

# Generator with Conv2D structure
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(100, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        return img

# Discriminator with Conv2D structure
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, img):
        validity = self.model(img)
        return validity

# Initialize GAN components
generator = Generator()
discriminator = Discriminator()


# Define loss function and optimizers
criterion = nn.MSELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

# Training loop
num_epochs = 200
for epoch in range(num_epochs):
    for batch_idx, real_data in enumerate(dataloader):
        real_data = real_data.to(device)

        # Train Discriminator
        optimizer_D.zero_grad()
        real_labels = torch.ones(real_data.size(0), 1).to(device)
        fake_labels = torch.zeros(real_data.size(0), 1).to(device)
        z = torch.randn(real_data.size(0), 100, 1, 1).to(device)
        fake_data = generator(z)
        real_pred = discriminator(real_data)
        fake_pred = discriminator(fake_data.detach())
        d_loss_real = criterion(real_pred, real_labels)
        d_loss_fake = criterion(fake_pred, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        z = torch.randn(real_data.size(0), 100, 1, 1).to(device)
        fake_data = generator(z)
        fake_pred = discriminator(fake_data)
        g_loss = criterion(fake_pred, real_labels)
        g_loss.backward()
        optimizer_G.step()

        # Print progress
        if batch_idx % 100 == 0:
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {batch_idx}/{len(dataloader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
            with torch.no_grad():
                    img_grid_real = torchvision.utils.make_grid(
                        fake_data#, normalize=True,

                    )

                    img_grid_fake = torchvision.utils.make_grid(
                        real_data#, normalize=True
                    )

                    writer_fake.add_image("fake_img", img_grid_fake, global_step=step)
                    writer_real.add_image("real_img", img_grid_real, global_step=step)

                    step += 1

# After training, you can generate a 2D array by sampling from the generator
z = torch.randn(1, 100, 1, 1).to(device)
generated_array = generator(z)
相关推荐
硅谷秋水1 小时前
在相机空间中落地动作:以观察为中心的视觉-语言-行动策略
机器学习·计算机视觉·语言模型·机器人
游戏AI研究所1 小时前
ComfyUI 里的 Prompt 插值器(prompt interpolation / text encoder 插值方式)的含义和作用!
人工智能·游戏·机器学习·stable diffusion·prompt·aigc
Chirp1 小时前
BS-RoFormer,目前音频分离SOTA
人工智能·机器学习
九章云极AladdinEdu1 小时前
Scikit-learn通关秘籍:从鸢尾花分类到房价预测
人工智能·python·机器学习·分类·scikit-learn·gpu算力
一个天蝎座 白勺 程序猿2 小时前
Apache IoTDB(4):深度解析时序数据库 IoTDB 在Kubernetes 集群中的部署与实践指南
数据库·深度学习·kubernetes·apache·时序数据库·iotdb
抠头专注python环境配置2 小时前
Pytorch GPU版本安装保姆级教程
pytorch·python·深度学习·conda
停停的茶3 小时前
决策树(2)
算法·决策树·机器学习
ccLianLian4 小时前
深度学习·GFSS
深度学习
Ronin-Lotus12 小时前
深度学习篇---卷积核的权重
人工智能·深度学习
.银河系.12 小时前
8.18 机器学习-决策树(1)
人工智能·决策树·机器学习