PyTorch 实现 Conditional DCGAN(条件深度卷积生成对抗网络)进行图像到图像转换的示例代码

以下是一个使用 PyTorch 实现 Conditional DCGAN(条件深度卷积生成对抗网络)进行图像到图像转换的示例代码。该代码包含训练和可视化部分,假设输入为图片和 4 个工艺参数,根据这些输入生成相应的图片。

1. 导入必要的库

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset
import numpy as np
import matplotlib.pyplot as plt

# 检查是否有可用的 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

2. 定义数据集类

python 复制代码
class ImagePairDataset(Dataset):
    def __init__(self, image_pairs, params):
        self.image_pairs = image_pairs
        self.params = params

    def __len__(self):
        return len(self.image_pairs)

    def __getitem__(self, idx):
        input_image, target_image = self.image_pairs[idx]
        param = self.params[idx]
        return input_image, target_image, param

3. 定义生成器和判别器

python 复制代码
# 生成器
class Generator(nn.Module):
    def __init__(self, z_dim=4, img_channels=3):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # 输入: [batch_size, z_dim + 4, 1, 1]
            self._block(z_dim + 4, 1024, 4, 1, 0),  # [batch_size, 1024, 4, 4]
            self._block(1024, 512, 4, 2, 1),  # [batch_size, 512, 8, 8]
            self._block(512, 256, 4, 2, 1),  # [batch_size, 256, 16, 16]
            self._block(256, 128, 4, 2, 1),  # [batch_size, 128, 32, 32]
            nn.ConvTranspose2d(128, img_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(True)
        )

    def forward(self, z, params):
        params = params.view(params.size(0), 4, 1, 1)
        x = torch.cat([z, params], dim=1)
        return self.gen(x)

# 判别器
class Discriminator(nn.Module):
    def __init__(self, img_channels=3):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # 输入: [batch_size, img_channels + 4, 64, 64]
            nn.Conv2d(img_channels + 4, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            self._block(64, 128, 4, 2, 1),  # [batch_size, 128, 16, 16]
            self._block(128, 256, 4, 2, 1),  # [batch_size, 256, 8, 8]
            self._block(256, 512, 4, 2, 1),  # [batch_size, 512, 4, 4]
            nn.Conv2d(512, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid()
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, kernel_size, stride, padding, bias=False
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )

    def forward(self, img, params):
        params = params.view(params.size(0), 4, 1, 1).repeat(1, 1, img.size(2), img.size(3))
        x = torch.cat([img, params], dim=1)
        return self.disc(x)

4. 训练代码

python 复制代码
def train_conditional_dcgan(image_pairs, params, batch_size=32, epochs=10, lr=0.0002, z_dim=4):
    dataset = ImagePairDataset(image_pairs, params)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    gen = Generator(z_dim).to(device)
    disc = Discriminator().to(device)

    criterion = nn.BCELoss()
    opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_disc = optim.Adam(disc.parameters(), lr=lr, betas=(0.5, 0.999))

    for epoch in range(epochs):
        for i, (input_images, target_images, param) in enumerate(dataloader):
            input_images = input_images.to(device)
            target_images = target_images.to(device)
            param = param.to(device)

            # 训练判别器
            opt_disc.zero_grad()
            real_labels = torch.ones((target_images.size(0), 1, 1, 1)).to(device)
            fake_labels = torch.zeros((target_images.size(0), 1, 1, 1)).to(device)

            # 计算判别器对真实图像的损失
            real_output = disc(target_images, param)
            d_real_loss = criterion(real_output, real_labels)

            # 生成假图像
            z = torch.randn(target_images.size(0), z_dim, 1, 1).to(device)
            fake_images = gen(z, param)

            # 计算判别器对假图像的损失
            fake_output = disc(fake_images.detach(), param)
            d_fake_loss = criterion(fake_output, fake_labels)

            # 总判别器损失
            d_loss = d_real_loss + d_fake_loss
            d_loss.backward()
            opt_disc.step()

            # 训练生成器
            opt_gen.zero_grad()
            output = disc(fake_images, param)
            g_loss = criterion(output, real_labels)
            g_loss.backward()
            opt_gen.step()

        print(f'Epoch [{epoch+1}/{epochs}] D_loss: {d_loss.item():.4f} G_loss: {g_loss.item():.4f}')

    return gen

5. 可视化代码

python 复制代码
def visualize_generated_images(gen, input_images, params, z_dim=4):
    input_images = input_images.to(device)
    params = params.to(device)
    z = torch.randn(input_images.size(0), z_dim, 1, 1).to(device)
    fake_images = gen(z, params).cpu().detach()

    fig, axes = plt.subplots(1, input_images.size(0), figsize=(15, 3))
    for i in range(input_images.size(0)):
        img = fake_images[i].permute(1, 2, 0).numpy()
        img = (img + 1) / 2  # 从 [-1, 1] 转换到 [0, 1]
        axes[i].imshow(img)
        axes[i].axis('off')
    plt.show()

6. 示例使用

python 复制代码
# 假设 image_pairs 是一个包含图像对的列表,params 是一个包含 4 个工艺参数的列表
image_pairs = []  # 这里需要替换为实际的图像对数据
params = []  # 这里需要替换为实际的工艺参数数据

# 训练模型
gen = train_conditional_dcgan(image_pairs, params)

# 可视化生成的图像
test_input_images, test_target_images, test_params = image_pairs[:5], image_pairs[:5], params[:5]
test_input_images = torch.stack([torch.tensor(img) for img in test_input_images]).float()
test_params = torch.tensor(test_params).float()
visualize_generated_images(gen, test_input_images, test_params)

代码说明

  1. 数据集类ImagePairDataset 用于加载图像对和工艺参数。
  2. 生成器和判别器GeneratorDiscriminator 分别定义了生成器和判别器的网络结构。
  3. 训练代码train_conditional_dcgan 函数用于训练 Conditional DCGAN 模型。
  4. 可视化代码visualize_generated_images 函数用于可视化生成的图像。
  5. 示例使用:最后部分展示了如何使用上述函数进行训练和可视化。

请注意,你需要将 image_pairsparams 替换为实际的数据集。此外,代码中的超参数(如 batch_sizeepochslr 等)可以根据实际情况进行调整。

相关推荐
边缘计算社区34 分钟前
FPGA与边缘AI:计算革命的前沿力量
人工智能·fpga开发
飞哥数智坊42 分钟前
打工人周末充电:15条AI资讯助你领先一小步
人工智能
Tech Synapse1 小时前
基于CARLA与PyTorch的自动驾驶仿真系统全栈开发指南
人工智能·opencv·sqlite
layneyao1 小时前
深度强化学习(DRL)实战:从AlphaGo到自动驾驶
人工智能·机器学习·自动驾驶
海特伟业2 小时前
隧道调频广播覆盖的实现路径:隧道无线广播技术赋能行车安全升级,隧道汽车广播收音系统助力隧道安全管理升级
人工智能
CareyWYR2 小时前
每周AI论文速递(250421-250425)
人工智能
追逐☞2 小时前
机器学习(10)——神经网络
人工智能·神经网络·机器学习
winner88812 小时前
对抗学习:机器学习里的 “零和博弈”,如何实现 “双赢”?
人工智能·机器学习·gan·对抗学习
Elastic 中国社区官方博客2 小时前
使用 LangGraph 和 Elasticsearch 构建强大的 RAG 工作流
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
娃娃略2 小时前
【AI模型学习】双流网络——更强大的网络设计
网络·人工智能·pytorch·python·神经网络·学习