PyTorch训练深度卷积生成对抗网络DCGAN

文章目录

DCGAN介绍

将CNN和GAN结合起来,把监督学习和无监督学习结合起来。具体解释可以参见 深度卷积对抗生成网络(DCGAN)

DCGAN的生成器结构:

图片来源:https://arxiv.org/abs/1511.06434

代码

model.py

cpp 复制代码
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.disc = nn.Sequential(
            # Input: N x channels_img x 64 x 64
            nn.Conv2d(
                channels_img, features_d, kernel_size=4, stride=2, padding=1
            ), # 32 x 32
            nn.LeakyReLU(0.2),
            self._block(features_d, features_d*2, 4, 2, 1), # 16 x 16
            self._block(features_d*2, features_d*4, 4, 2, 1), # 8 x 8
            self._block(features_d*4, features_d*8, 4, 2, 1), # 4 x 4
            nn.Conv2d(features_d*8, 1, kernel_size=4, stride=2, padding=0), # 1 x 1
            nn.Sigmoid(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        return self.disc(x)
    
class Generator(nn.Module):
    def __init__(self, z_dim, channels_img, features_g):
        super(Generator, self).__init__()
        self.gen = nn.Sequential(
            # Input: N x z_dim x 1 x 1
            self._block(z_dim, features_g*16, 4, 1, 0), # N x f_g*16 x 4 x 4
            self._block(features_g*16, features_g*8, 4, 2, 1), # 8x8
            self._block(features_g*8, features_g*4, 4, 2, 1), # 16x16
            self._block(features_g*4, features_g*2, 4, 2, 1), # 32x32
            nn.ConvTranspose2d(
                features_g*2, channels_img, kernel_size=4, stride=2, padding=1,
            ),
            nn.Tanh(),

        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )
    
    def forward(self, x):
        return self.gen(x)


def initialize_weights(model):
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

def test():
    N, in_channels, H, W = 8, 3, 64, 64
    z_dim = 100
    x = torch.randn((N, in_channels, H, W))
    disc = Discriminator(in_channels, 8)
    initialize_weights(disc)
    assert disc(x).shape == (N, 1, 1, 1)

    gen = Generator(z_dim, in_channels, 8)
    initialize_weights(gen)
    z = torch.randn((N, z_dim, 1, 1))
    assert gen(z).shape == (N, in_channels, H, W)
    print("success")
    
if __name__ == "__main__":
    test()

训练使用的数据集:CelebA dataset (Images Only) 总共1.3GB的图片,使用方法,将其解压到当前目录

图片如下图所示:

train.py

cpp 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from model import Discriminator, Generator, initialize_weights

# Hyperparameters etc.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE = 2e-4  # could also use two lrs, one for gen and one for disc
BATCH_SIZE = 128
IMAGE_SIZE = 64
CHANNELS_IMG = 3 # 1 if MNIST dataset; 3 if celeb dataset
NOISE_DIM = 100
NUM_EPOCHS = 5
FEATURES_DISC = 64
FEATURES_GEN = 64

transforms = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.5 for _ in range(CHANNELS_IMG)], [0.5 for _ in range(CHANNELS_IMG)]
        ),
    ]
)

# If you train on MNIST, remember to set channels_img to 1
# dataset = datasets.MNIST(
#     root="dataset/", train=True, transform=transforms, download=True
# )

# comment mnist above and uncomment below if train on CelebA

# If you train on celeb dataset, remember to set channels_img to 3
dataset = datasets.ImageFolder(root="celeb_dataset", transform=transforms)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
gen = Generator(NOISE_DIM, CHANNELS_IMG, FEATURES_GEN).to(device)
disc = Discriminator(CHANNELS_IMG, FEATURES_DISC).to(device)
initialize_weights(gen)
initialize_weights(disc)

opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
criterion = nn.BCELoss()

fixed_noise = torch.randn(32, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0

gen.train()
disc.train()

for epoch in range(NUM_EPOCHS):
    # Target labels not needed! <3 unsupervised
    for batch_idx, (real, _) in enumerate(dataloader):
        real = real.to(device)
        noise = torch.randn(BATCH_SIZE, NOISE_DIM, 1, 1).to(device)
        fake = gen(noise)

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(dataloader)} \
                  Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

结果

训练5个epoch,部分结果如下:

cpp 复制代码
Epoch [3/5] Batch 1500/1583                   Loss D: 0.4996, loss G: 1.1738
Epoch [4/5] Batch 0/1583                   Loss D: 0.4268, loss G: 1.6633
Epoch [4/5] Batch 100/1583                   Loss D: 0.4841, loss G: 1.7475
Epoch [4/5] Batch 200/1583                   Loss D: 0.5094, loss G: 1.2376
Epoch [4/5] Batch 300/1583                   Loss D: 0.4376, loss G: 2.1271
Epoch [4/5] Batch 400/1583                   Loss D: 0.4173, loss G: 1.4380
Epoch [4/5] Batch 500/1583                   Loss D: 0.5213, loss G: 2.1665
Epoch [4/5] Batch 600/1583                   Loss D: 0.5036, loss G: 2.1079
Epoch [4/5] Batch 700/1583                   Loss D: 0.5158, loss G: 1.0579
Epoch [4/5] Batch 800/1583                   Loss D: 0.5426, loss G: 1.9427
Epoch [4/5] Batch 900/1583                   Loss D: 0.4721, loss G: 1.2659
Epoch [4/5] Batch 1000/1583                   Loss D: 0.5662, loss G: 2.4537
Epoch [4/5] Batch 1100/1583                   Loss D: 0.5604, loss G: 0.8978
Epoch [4/5] Batch 1200/1583                   Loss D: 0.4085, loss G: 2.0747
Epoch [4/5] Batch 1300/1583                   Loss D: 1.1894, loss G: 0.1825
Epoch [4/5] Batch 1400/1583                   Loss D: 0.4518, loss G: 2.1509
Epoch [4/5] Batch 1500/1583                   Loss D: 0.3814, loss G: 1.9391

使用

cpp 复制代码
tensorboard --logdir=logs

打开tensorboard

参考

[1] DCGAN implementation from scratch

[2] https://arxiv.org/abs/1511.06434

相关推荐
埃菲尔铁塔_CV算法19 分钟前
深度学习神经网络创新点方向
人工智能·深度学习·神经网络
图片转成excel表格2 小时前
WPS Office Excel 转 PDF 后图片丢失的解决方法
人工智能·科技·深度学习
李歘歘3 小时前
万字长文解读深度学习——多模态模型CLIP、BLIP、ViLT
人工智能·深度学习
B站计算机毕业设计超人3 小时前
计算机毕业设计Python+大模型农产品价格预测 ARIMA自回归模型 农产品可视化 农产品爬虫 机器学习 深度学习 大数据毕业设计 Django Flask
大数据·爬虫·python·深度学习·机器学习·课程设计·数据可视化
新手小白勇闯新世界3 小时前
深度学习知识点5-马尔可夫链
人工智能·深度学习·计算机视觉
布鲁格若门4 小时前
AMD CPU下pytorch 多GPU运行卡死和死锁解决
人工智能·pytorch·python·nvidia
热爱生活的五柒5 小时前
深度学习:利用随机数据更快地测试一个新的模型在自己数据格式很复杂的时候
人工智能·深度学习
科研实践课堂(公众号名称)5 小时前
基于OpenFOAM和深度学习驱动的流体力学计算与应用
人工智能·深度学习
小锋学长生活大爆炸5 小时前
【教程】Cupy、Numpy、Torch互相转换
pytorch·numpy·cupy