用来生成二维矩阵的dcgan

有大量二维矩阵作为样本,为连续数据。数据具有空间连续性,因此用卷积网络,通过dcgan生成二维矩阵。因为是连续变量,因此损失采用nn.MSELoss()。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from DemDataset import create_netCDF_Dem_trainLoader
import torchvision
from torch.utils.tensorboard import SummaryWriter

batch_size=16
#load data
dataloader = create_netCDF_Dem_trainLoader(batch_size)

# Generator with Conv2D structure
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(100, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        img = self.model(z)
        return img

# Discriminator with Conv2D structure
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, img):
        validity = self.model(img)
        return validity

# Initialize GAN components
generator = Generator()
discriminator = Discriminator()


# Define loss function and optimizers
criterion = nn.MSELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
generator.to(device)
discriminator.to(device)

writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

# Training loop
num_epochs = 200
for epoch in range(num_epochs):
    for batch_idx, real_data in enumerate(dataloader):
        real_data = real_data.to(device)

        # Train Discriminator
        optimizer_D.zero_grad()
        real_labels = torch.ones(real_data.size(0), 1).to(device)
        fake_labels = torch.zeros(real_data.size(0), 1).to(device)
        z = torch.randn(real_data.size(0), 100, 1, 1).to(device)
        fake_data = generator(z)
        real_pred = discriminator(real_data)
        fake_pred = discriminator(fake_data.detach())
        d_loss_real = criterion(real_pred, real_labels)
        d_loss_fake = criterion(fake_pred, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        z = torch.randn(real_data.size(0), 100, 1, 1).to(device)
        fake_data = generator(z)
        fake_pred = discriminator(fake_data)
        g_loss = criterion(fake_pred, real_labels)
        g_loss.backward()
        optimizer_G.step()

        # Print progress
        if batch_idx % 100 == 0:
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {batch_idx}/{len(dataloader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
            with torch.no_grad():
                    img_grid_real = torchvision.utils.make_grid(
                        fake_data#, normalize=True,

                    )

                    img_grid_fake = torchvision.utils.make_grid(
                        real_data#, normalize=True
                    )

                    writer_fake.add_image("fake_img", img_grid_fake, global_step=step)
                    writer_real.add_image("real_img", img_grid_real, global_step=step)

                    step += 1

# After training, you can generate a 2D array by sampling from the generator
z = torch.randn(1, 100, 1, 1).to(device)
generated_array = generator(z)
相关推荐
硅谷秋水4 分钟前
GAIA-2:用于自动驾驶的可控多视图生成世界模型
人工智能·机器学习·自动驾驶
多巴胺与内啡肽.7 分钟前
深度学习--自然语言处理统计语言与神经语言模型
深度学习·语言模型·自然语言处理
深度之眼31 分钟前
2025时间序列都有哪些创新点可做——总结篇
人工智能·深度学习·机器学习·时间序列
不吃香菜?3 小时前
PyTorch 实现食物图像分类实战:从数据处理到模型训练
人工智能·深度学习
Olafur_zbj3 小时前
【EDA】EDA中聚类(Clustering)和划分(Partitioning)
机器学习·数据挖掘·聚类
Light603 小时前
智启未来:深度解析Python Transformers库及其应用场景
开发语言·python·深度学习·自然语言处理·预训练模型·transformers库 |·|应用场景
数据智能老司机3 小时前
构建具备自主性的人工智能系统——在生成式人工智能系统中构建信任
深度学习·llm·aigc
sduwcgg4 小时前
kaggle配置
人工智能·python·机器学习
谦行5 小时前
工欲善其事,必先利其器—— PyTorch 深度学习基础操作
pytorch·深度学习·ai编程
xwz小王子5 小时前
Nature Communications 面向形状可编程磁性软材料的数据驱动设计方法—基于随机设计探索与神经网络的协同优化框架
深度学习