用来生成二维矩阵的dcgan

有大量二维矩阵作为样本,为连续数据。数据具有空间连续性,因此用卷积网络,通过dcgan生成二维矩阵。因为是连续变量,因此损失采用nn.MSELoss()。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from DemDataset import create_netCDF_Dem_trainLoader
import torchvision
from torch.utils.tensorboard import SummaryWriter

batch_size=16
#load data
dataloader = create_netCDF_Dem_trainLoader(batch_size)

# Generator with Conv2D structure
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(100, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        return img

# Discriminator with Conv2D structure
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, img):
        validity = self.model(img)
        return validity

# Initialize GAN components
generator = Generator()
discriminator = Discriminator()


# Define loss function and optimizers
criterion = nn.MSELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

# Training loop
num_epochs = 200
for epoch in range(num_epochs):
    for batch_idx, real_data in enumerate(dataloader):
        real_data = real_data.to(device)

        # Train Discriminator
        optimizer_D.zero_grad()
        real_labels = torch.ones(real_data.size(0), 1).to(device)
        fake_labels = torch.zeros(real_data.size(0), 1).to(device)
        z = torch.randn(real_data.size(0), 100, 1, 1).to(device)
        fake_data = generator(z)
        real_pred = discriminator(real_data)
        fake_pred = discriminator(fake_data.detach())
        d_loss_real = criterion(real_pred, real_labels)
        d_loss_fake = criterion(fake_pred, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        z = torch.randn(real_data.size(0), 100, 1, 1).to(device)
        fake_data = generator(z)
        fake_pred = discriminator(fake_data)
        g_loss = criterion(fake_pred, real_labels)
        g_loss.backward()
        optimizer_G.step()

        # Print progress
        if batch_idx % 100 == 0:
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {batch_idx}/{len(dataloader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
            with torch.no_grad():
                    img_grid_real = torchvision.utils.make_grid(
                        fake_data#, normalize=True,

                    )

                    img_grid_fake = torchvision.utils.make_grid(
                        real_data#, normalize=True
                    )

                    writer_fake.add_image("fake_img", img_grid_fake, global_step=step)
                    writer_real.add_image("real_img", img_grid_real, global_step=step)

                    step += 1

# After training, you can generate a 2D array by sampling from the generator
z = torch.randn(1, 100, 1, 1).to(device)
generated_array = generator(z)
相关推荐
ARM+FPGA+AI工业主板定制专家1 小时前
基于GPS/PTP/gPTP的自动驾驶数据同步授时方案
人工智能·机器学习·自动驾驶
lisw056 小时前
SolidWorks:现代工程设计与数字制造的核心平台
人工智能·机器学习·青少年编程·软件工程·制造
学Linux的语莫6 小时前
机器学习数据处理
java·算法·机器学习
递归不收敛8 小时前
吴恩达机器学习课程(PyTorch适配)学习笔记:1.3 特征工程与模型优化
pytorch·学习·机器学习
小关会打代码8 小时前
深度学习之YOLO系列YOLOv1
人工智能·深度学习·yolo
一车小面包8 小时前
Transformer Decoder 中序列掩码(Sequence Mask / Look-ahead Mask)
人工智能·深度学习·transformer
B站_计算机毕业设计之家9 小时前
机器学习实战项目:Python+Flask 汽车销量分析可视化系统(requests爬车主之家+可视化 源码+文档)✅
人工智能·python·机器学习·数据分析·flask·汽车·可视化
渡我白衣10 小时前
深度学习入门(一)——从神经元到损失函数,一步步理解前向传播(下)
人工智能·深度学习·神经网络
小虎鲸0012 小时前
PyTorch的安装与使用
人工智能·pytorch·python·深度学习
lucky_syq13 小时前
解锁特征工程:机器学习的秘密武器
人工智能·机器学习