用来生成二维矩阵的 DCGAN

有大量二维矩阵作为样本,且均为连续数据。由于数据具有空间连续性,因此采用卷积网络,通过 DCGAN 生成二维矩阵。因为是连续变量,因此损失采用 nn.MSELoss()(判别器直接输出未经 Sigmoid 的实数评分,相当于最小二乘 GAN 的损失形式)。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from DemDataset import create_netCDF_Dem_trainLoader
import torchvision
from torch.utils.tensorboard import SummaryWriter

# Mini-batch size handed to the project-local loader factory.
batch_size = 16
# Build the training loader from the netCDF DEM dataset helper.
dataloader = create_netCDF_Dem_trainLoader(batch_size)

# Generator with Conv2D structure
class Generator(nn.Module):
    """Transposed-convolution generator that upsamples a latent tensor.

    The network is a stack of seven ``ConvTranspose2d`` layers with
    ``kernel_size=4, stride=2, padding=1``; each one doubles the spatial
    size, so a ``(N, latent_dim, 1, 1)`` input yields a
    ``(N, out_channels, 128, 128)`` output. The final ``Tanh`` bounds the
    output to ``[-1, 1]``.

    Args:
        latent_dim: Number of channels of the latent input ``z``.
        out_channels: Number of channels of the generated array.
    """

    def __init__(self, latent_dim: int = 100, out_channels: int = 1) -> None:
        super().__init__()

        # Channel widths of the hidden upsampling stages, in order.
        widths = [latent_dim, 512, 512, 256, 128, 64, 32]

        layers = []
        for in_ch, out_ch in zip(widths, widths[1:]):
            # Upsample x2, then normalise and activate.
            layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1)
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU())

        # Output stage: no BatchNorm, Tanh squashes values into [-1, 1].
        layers.append(
            nn.ConvTranspose2d(widths[-1], out_channels, kernel_size=4, stride=2, padding=1)
        )
        layers.append(nn.Tanh())

        # Kept as a single Sequential named ``model`` so parameter names
        # (``model.<index>.*``) stay the same as before.
        self.model = nn.Sequential(*layers)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Map a latent tensor ``z`` of shape ``(N, latent_dim, H, W)`` to an image."""
        return self.model(z)

# Discriminator with Conv2D structure
class Discriminator(nn.Module):
    """Strided-convolution discriminator that scores a 2D array.

    The network is a stack of seven ``Conv2d`` layers with
    ``kernel_size=4, stride=2, padding=1``; each one halves the spatial
    size, so a ``(N, in_channels, 128, 128)`` input yields a
    ``(N, 1, 1, 1)`` score. No activation is applied to the last layer, so
    the output is an unbounded real-valued score.

    Args:
        in_channels: Number of channels of the input array.
    """

    def __init__(self, in_channels: int = 1) -> None:
        super().__init__()

        # Channel widths of the hidden downsampling stages, in order.
        widths = [in_channels, 32, 64, 128, 256, 512, 512]

        layers = []
        for in_ch, out_ch in zip(widths, widths[1:]):
            # Downsample x2, then apply a leaky activation.
            layers.append(nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.2))

        # Output stage: a single-channel score map with no activation.
        layers.append(nn.Conv2d(widths[-1], 1, kernel_size=4, stride=2, padding=1))

        # Kept as a single Sequential named ``model`` so parameter names
        # (``model.<index>.*``) stay the same as before.
        self.model = nn.Sequential(*layers)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Return the raw validity score map for ``img``."""
        return self.model(img)

# Initialize GAN components
# Select the device first so the networks are placed on it before the
# optimizers are constructed (the order recommended by the torch.optim docs).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize GAN components
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Define loss function and optimizers.
# MSE between the raw discriminator scores and 1/0 targets is the
# least-squares GAN loss form.
criterion = nn.MSELoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Separate TensorBoard runs for real and generated samples.
writer_real = SummaryWriter("logs/real")
writer_fake = SummaryWriter("logs/fake")
# Global step counter used for the image summaries.
step = 0

# Training loop
num_epochs = 200
for epoch in range(num_epochs):
    for batch_idx, real_data in enumerate(dataloader):
        real_data = real_data.to(device)
        n_samples = real_data.size(0)

        # Train Discriminator
        optimizer_D.zero_grad()
        # Sample the noise directly on the target device (no CPU round trip).
        z = torch.randn(n_samples, 100, 1, 1, device=device)
        fake_data = generator(z)
        real_pred = discriminator(real_data)
        # detach() keeps the discriminator update from back-propagating
        # into the generator.
        fake_pred = discriminator(fake_data.detach())
        # Build the targets with the same shape as the predictions so the
        # loss never relies on broadcasting between mismatched shapes.
        d_loss_real = criterion(real_pred, torch.ones_like(real_pred))
        d_loss_fake = criterion(fake_pred, torch.zeros_like(fake_pred))
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_D.step()

        # Train Generator
        optimizer_G.zero_grad()
        z = torch.randn(n_samples, 100, 1, 1, device=device)
        fake_data = generator(z)
        fake_pred = discriminator(fake_data)
        # The generator is rewarded when its samples are scored as real.
        g_loss = criterion(fake_pred, torch.ones_like(fake_pred))
        g_loss.backward()
        optimizer_G.step()

        # Print progress
        if batch_idx % 100 == 0:
            print(f"[Epoch {epoch}/{num_epochs}] [Batch {batch_idx}/{len(dataloader)}] [D loss: {d_loss.item():.4f}] [G loss: {g_loss.item():.4f}]")
            with torch.no_grad():
                # normalize=True rescales each grid to [0, 1]; the generator
                # output is in [-1, 1] (Tanh), which would otherwise be
                # clipped when written as an image.
                img_grid_real = torchvision.utils.make_grid(real_data, normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake_data, normalize=True)

                # Real samples go to the "real" run, generated ones to "fake".
                writer_real.add_image("real_img", img_grid_real, global_step=step)
                writer_fake.add_image("fake_img", img_grid_fake, global_step=step)

            step += 1

# Training is finished: flush pending events and release the log files.
writer_real.close()
writer_fake.close()

# After training, you can generate a 2D array by sampling from the generator.
# eval() makes the BatchNorm layers use their running statistics instead of
# the statistics of this single-sample batch; no_grad() skips building an
# autograd graph that is not needed for inference.
generator.eval()
with torch.no_grad():
    z = torch.randn(1, 100, 1, 1, device=device)
    generated_array = generator(z)
相关推荐
美狐美颜sdk41 分钟前
直播美颜工具架构设计与性能优化实战:美颜SDK集成与实时处理
深度学习·美颜sdk·第三方美颜sdk·视频美颜sdk·美颜api
curemoon1 小时前
理解多元正态分布中指数项的精度矩阵(协方差逆矩阵)
人工智能·算法·矩阵
Fansv5872 小时前
深度学习-6.用于计算机视觉的深度学习
人工智能·深度学习·计算机视觉
deephub3 小时前
LLM高效推理:KV缓存与分页注意力机制深度解析
人工智能·深度学习·语言模型
奋斗的袍子0073 小时前
Spring AI + Ollama 实现调用DeepSeek-R1模型API
人工智能·spring boot·深度学习·spring·springai·deepseek
青衫弦语3 小时前
【论文精读】VLM-AD:通过视觉-语言模型监督实现端到端自动驾驶
人工智能·深度学习·语言模型·自然语言处理·自动驾驶
美狐美颜sdk3 小时前
直播美颜SDK的底层技术解析:图像处理与深度学习的结合
图像处理·人工智能·深度学习·直播美颜sdk·视频美颜sdk·美颜api·滤镜sdk
WHATEVER_LEO3 小时前
【每日论文】Text-guided Sparse Voxel Pruning for Efficient 3D Visual Grounding
人工智能·深度学习·神经网络·算法·机器学习·自然语言处理
Binary Oracle3 小时前
RNN中远距离时间步梯度消失问题及解决办法
人工智能·rnn·深度学习
阿_旭4 小时前
基于YOLO11深度学习的糖尿病视网膜病变检测与诊断系统【python源码+Pyqt5界面+数据集+训练代码】
人工智能·python·深度学习·视网膜病变检测