PyTorch训练深度卷积生成对抗网络DCGAN

文章目录

DCGAN介绍

将CNN和GAN结合起来,把监督学习和无监督学习结合起来。具体解释可以参见 深度卷积对抗生成网络(DCGAN)

DCGAN的生成器结构:

图片来源:https://arxiv.org/abs/1511.06434

代码

model.py

cpp 复制代码
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # Input: N x channels_img x 64 x 64
            nn.Conv2d(
                channels_img, features_d, kernel_size=4, stride=2, padding=1
            ), # 32 x 32
            nn.LeakyReLU(0.2),
            self._block(features_d, features_d*2, 4, 2, 1), # 16 x 16
            self._block(features_d*2, features_d*4, 4, 2, 1), # 8 x 8
            self._block(features_d*4, features_d*8, 4, 2, 1), # 4 x 4
            nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), # 1 x 1
            nn.Sigmoid(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)
    
class Generator(nn.Module):
    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # Input: N x z_dim x 1 x 1
            self._block(z_dim, features_g*16, 4, 1, 0), # N x f_g*16 x 4 x 4
            self._block(features_g*16, features_g*8, 4, 2, 1), # 8x8
            self._block(features_g*8, features_g*4, 4, 2, 1), # 16x16
            self._block(features_g*4, features_g*2, 4, 2, 1), # 32x32
            nn.ConvTranspose2d(
                features_g*2, channels_img, kernel_size=4, stride=2, padding=1,
            ),
            nn.Tanh(),

        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
    
    def forward(self, x):
        return self.gen(x)


def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

def test():
    N, in_channels, H, W = 8, 3, 64, 64
    z_dim = 100
    x = torch.randn((N, in_channels, H, W))
    disc = Discriminator(in_channels, 8)
    initialize_weights(disc)
    assert disc(x).shape == (N, 1, 1, 1)

    gen = Generator(z_dim, in_channels, 8)
    initialize_weights(gen)
    z = torch.randn((N, z_dim, 1, 1))
    assert gen(z).shape == (N, in_channels, H, W)
    print("success")
    
if __name__ == "__main__":
    test()

训练使用的数据集:CelebA dataset (Images Only) 总共1.3GB的图片,使用方法,将其解压到当前目录

图片如下图所示:

train.py

cpp 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights

# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 3 # 1 if MNIST dataset; 3 if celeb dataset
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64

transforms = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

# If you train on MNIST, remember to set channels_img to 1
# dataset = datasets.MNIST(
#     root="dataset/", train=True, transform=transforms, download=True
# )

# comment mnist above and uncomment below if train on CelebA

# If you train on celeb dataset, remember to set channels_img to 3
dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()

fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)
        fake = gen(noise)

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

结果

训练5个epoch,部分结果如下:

cpp 复制代码
Epoch [3/5] Batch 1500/1583                   Loss D: 0.4996, loss G: 1.1738
Epoch [4/5] Batch 0/1583                   Loss D: 0.4268, loss G: 1.6633
Epoch [4/5] Batch 100/1583                   Loss D: 0.4841, loss G: 1.7475
Epoch [4/5] Batch 200/1583                   Loss D: 0.5094, loss G: 1.2376
Epoch [4/5] Batch 300/1583                   Loss D: 0.4376, loss G: 2.1271
Epoch [4/5] Batch 400/1583                   Loss D: 0.4173, loss G: 1.4380
Epoch [4/5] Batch 500/1583                   Loss D: 0.5213, loss G: 2.1665
Epoch [4/5] Batch 600/1583                   Loss D: 0.5036, loss G: 2.1079
Epoch [4/5] Batch 700/1583                   Loss D: 0.5158, loss G: 1.0579
Epoch [4/5] Batch 800/1583                   Loss D: 0.5426, loss G: 1.9427
Epoch [4/5] Batch 900/1583                   Loss D: 0.4721, loss G: 1.2659
Epoch [4/5] Batch 1000/1583                   Loss D: 0.5662, loss G: 2.4537
Epoch [4/5] Batch 1100/1583                   Loss D: 0.5604, loss G: 0.8978
Epoch [4/5] Batch 1200/1583                   Loss D: 0.4085, loss G: 2.0747
Epoch [4/5] Batch 1300/1583                   Loss D: 1.1894, loss G: 0.1825
Epoch [4/5] Batch 1400/1583                   Loss D: 0.4518, loss G: 2.1509
Epoch [4/5] Batch 1500/1583                   Loss D: 0.3814, loss G: 1.9391

使用

cpp 复制代码
tensorboard --logdir=logs

打开tensorboard

参考

1 DCGAN implementation from scratch

2 https://arxiv.org/abs/1511.06434

相关推荐
程序猿追5 天前
那个右下角的小数字怎么“卡”住我打字——我用 HarmonyOS 自己写了一个字数限制输入框
pytorch·华为·harmonyos
xiao5kou4chang6kai45 天前
MATLAB机器学习、深度学习--从数据预处理到模型训练
深度学习·机器学习·matlab·数据预处理
renhongxia15 天前
世界模型作为AGI落地底层底座的作用
人工智能·深度学习·生成对抗网络·自然语言处理·知识图谱·agi
计算机科研狗@OUC5 天前
(cvpr26) AIMDepth: Asymmetric Image-Event Mamba for Monocular Depth Estimation
人工智能·深度学习·计算机视觉
闵孚龙5 天前
《PyTorch 深度修炼》Dataset 和 DataLoader:数据如何喂给模型
人工智能·pytorch·python
β添砖java6 天前
深度学习(22)网络中的网络NiN
人工智能·深度学习
Kobebryant-Manba6 天前
深度学习时候d2l报错和使用问题
人工智能·深度学习
zhangfeng11336 天前
deepspeed zero3 结合 llamafactory 微调 ,save_only_model: true 导致保存时候出错
开发语言·python·深度学习
大模型最新论文速读6 天前
06-16 · LLM 最新论文速览
论文阅读·人工智能·深度学习·机器学习·自然语言处理
宝贝儿好6 天前
【LLM】第二章:HuggingFace入门学习
人工智能·深度学习·神经网络·学习·算法·自然语言处理