用GAN生成奖杯

数据集链接:https://pan.baidu.com/s/19Uxc2ELiMG3acUtLeSTDTA?pwd=wsyw

提取码:wsyw

我将图片大小设置为 128×128;如果显存或内存不足,可以将 batch_size 调小。epoch 我设置为 2000,但我感觉 1000 其实也够了。代码如下:

python 复制代码
import argparse
from torchvision import datasets, transforms
import torch
import torch.nn as nn
import os
import numpy as np
from torchvision.utils import save_image


def args_parse(argv=None):
    """Define and parse the command-line options of the GAN training script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) lets argparse read ``sys.argv[1:]``, which is what the
            original zero-argument call did. Passing a list makes the
            function usable from other code and from tests.

    Returns:
        argparse.Namespace holding the hyper-parameters declared below
        (``n_epoches``, ``batch_size``, ``lr``, ``n_cpu``, ``latent_dim``,
        ``img_size``, ``channels``, ``sample_interval``, ``b1``, ``b2``,
        ``type``).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--n_epoches", type=int, default=2000, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=256, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--n_cpu", type=int, default=1, help="number of cpu threads to use during batch generation")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument("--img_size", type=int, default=128, help="size of each image dimension")
    parser.add_argument("--channels", type=int, default=3, help="number of image channels")
    parser.add_argument("--sample_interval", type=int, default=50, help="interval between image sampling")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    # b2 is Adam's second-moment decay; the help text used to repeat b1's.
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
    parser.add_argument("--type", type=str, default='DCGAN', help="The type of DCGAN")
    return parser.parse_args(argv)

class Generator(nn.Module):
    """Fully connected generator: latent vector -> image tensor.

    The network is a stack of Linear/LeakyReLU blocks (with batch
    normalisation on all but the first block) widening 128 -> 256 -> 512 ->
    1024, followed by a Linear layer to ``prod(img_shape)`` values and a
    Tanh, so outputs lie in [-1, 1].

    Args:
        latent_dim: Size of the input noise vector.
        img_shape: ``(channels, height, width)`` of the images to produce.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.img_shape = img_shape

        def block(in_feat, out_feat, normalize=True):
            """Return the layers of one Linear [-> BatchNorm] -> LeakyReLU block."""
            layers = [nn.Linear(in_feat, out_feat)]
            if normalize:
                # BatchNorm1d's second positional parameter is ``eps``, so the
                # former ``BatchNorm1d(out_feat, 0.8)`` added 0.8 to the
                # variance instead of configuring the running-stat momentum.
                # NOTE(review): 0.8 is assumed to have been meant as the
                # momentum - confirm the intended value.
                layers.append(nn.BatchNorm1d(out_feat, momentum=0.8))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),
            *block(128, 256),
            *block(256, 512),
            *block(512, 1024),
            nn.Linear(1024, int(np.prod(img_shape))),
            nn.Tanh()
        )

    def forward(self, z):
        """Map noise ``z`` of shape (batch, latent_dim) to (batch, C, H, W) images."""
        flat = self.model(z)
        # Restore the image layout from the flat per-sample vector.
        return flat.view(flat.size(0), self.img_shape[0], self.img_shape[1], self.img_shape[2])


class Discriminator(nn.Module):
    """Fully connected discriminator: image tensor -> probability of being real.

    Each image is flattened and passed through Linear layers of width 512 and
    256 (LeakyReLU, slope 0.2) and a final 1-unit Linear with a Sigmoid.

    Args:
        img_shape: ``(channels, height, width)``; its product is the input
            width of the first Linear layer.
    """

    def __init__(self, img_shape):
        super(Discriminator, self).__init__()
        n_inputs = int(np.prod(img_shape))
        widths = [n_inputs, 512, 256]

        stack = []
        # Hidden layers, created in the same order as before so that
        # parameter names and initialisation order stay unchanged.
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stack.append(nn.Linear(fan_in, fan_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
        # Scoring head.
        stack.append(nn.Linear(widths[-1], 1))
        stack.append(nn.Sigmoid())
        self.model = nn.Sequential(*stack)

    def forward(self, img):
        """Return a (batch, 1) tensor of validity scores for ``img``."""
        batch = img.size(0)
        return self.model(img.view(batch, -1))


class Generator_CNN(nn.Module):
    """Convolutional generator: latent vector -> image tensor.

    A Linear layer projects the noise to a 128-channel feature map at a
    quarter of the target resolution; two ``Upsample(scale_factor=2)`` +
    3x3 convolution stages bring it back up, and a final convolution with
    Tanh produces ``img_shape[0]`` channels in [-1, 1].

    The output is ``4 * (height // 4)`` by ``4 * (width // 4)``, so it only
    equals ``img_shape`` when height and width are multiples of 4.

    Args:
        latent_dim: Size of the input noise vector.
        img_shape: ``(channels, height, width)`` of the images to produce.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        # ``init_size`` keeps its original meaning (height // 4); the width
        # is tracked separately so non-square shapes are generated correctly.
        # For square images both values coincide, as before.
        self.init_size = img_shape[1] // 4
        self.init_width = img_shape[2] // 4
        # latent_dim -> 128 * init_size * init_width features
        self.l1 = nn.Sequential(nn.Linear(latent_dim, 128 * self.init_size * self.init_width))
        # BatchNorm2d's second positional parameter is ``eps``; the former
        # ``BatchNorm2d(n, 0.8)`` therefore set eps=0.8 rather than momentum.
        # NOTE(review): 0.8 is assumed to have been meant as the momentum -
        # confirm the intended value.
        self.conv_blocks = nn.Sequential(
            nn.BatchNorm2d(128),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 128, 3, stride=1, padding=1),
            nn.BatchNorm2d(128, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, 3, stride=1, padding=1),
            nn.BatchNorm2d(64, momentum=0.8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, img_shape[0], 3, stride=1, padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        """Map noise ``z`` of shape (batch, latent_dim) to a batch of images."""
        out = self.l1(z)
        # Reinterpret the flat projection as a 128-channel feature map.
        out = out.view(out.shape[0], 128, self.init_size, self.init_width)
        return self.conv_blocks(out)


class Discriminator_CNN(nn.Module):
    """Convolutional discriminator: image tensor -> probability of being real.

    Four blocks of 3x3 stride-2 convolution, LeakyReLU(0.2) and
    Dropout2d(0.25) (with batch normalisation on all but the first block)
    reduce the image to a 128-channel feature map, which is flattened and
    scored by a Linear + Sigmoid head.

    Args:
        img_shape: ``(channels, height, width)`` of the input images.
    """

    def __init__(self, img_shape):
        super().__init__()

        def discriminator_block(in_filters, out_filters, bn=True):
            """Return the layers of one downsampling block."""
            block = [nn.Conv2d(in_filters, out_filters, 3, 2, 1),
                     nn.LeakyReLU(0.2, inplace=True),
                     nn.Dropout2d(0.25)]
            if bn:
                # BatchNorm2d's second positional parameter is ``eps``; the
                # former ``BatchNorm2d(out_filters, 0.8)`` set eps=0.8.
                # NOTE(review): 0.8 is assumed to have been meant as the
                # momentum - confirm the intended value.
                block.append(nn.BatchNorm2d(out_filters, momentum=0.8))
            return block

        self.model = nn.Sequential(
            *discriminator_block(img_shape[0], 16, bn=False),
            *discriminator_block(16, 32),
            *discriminator_block(32, 64),
            *discriminator_block(64, 128),
        )

        def downsampled(size, n_blocks=4):
            """Spatial size after ``n_blocks`` convs with kernel 3, stride 2, padding 1."""
            for _ in range(n_blocks):
                # floor((size + 2*1 - 3) / 2) + 1 == ceil(size / 2)
                size = (size + 1) // 2
            return size

        # The former ``img_shape[1] // 2 ** 4`` rounded down and used the
        # height for both axes, which mis-sized the linear layer for inputs
        # that are not multiples of 16 or not square. For square multiples
        # of 16 the result is unchanged.
        ds_height = downsampled(img_shape[1])
        ds_width = downsampled(img_shape[2])
        self.adv_layer = nn.Sequential(nn.Linear(128 * ds_height * ds_width, 1), nn.Sigmoid())

    def forward(self, img):
        """Return a (batch, 1) tensor of validity scores for ``img``."""
        features = self.model(img)
        flat = features.view(features.shape[0], -1)
        return self.adv_layer(flat)


def train():
    """Train the selected GAN on the images found under ``./dataset``.

    Hyper-parameters come from ``args_parse()``. ``--type DCGAN`` selects the
    convolutional generator/discriminator pair; any other value selects the
    fully connected pair. For every batch the generator is updated first
    (trying to make the discriminator output "real" for generated images),
    then the discriminator is updated on the real batch and on the same
    generated batch, both with binary cross-entropy.

    Side effects: prints the losses for every batch, creates the directory
    ``images_3_2`` and, every ``--sample_interval`` batches, writes a 5-wide
    grid of up to 25 generated images there. Nothing is returned and no
    model weights are saved.
    """
    opt = args_parse()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    transform = transforms.Compose(
        [
            # Follow --img_size so the data always matches img_shape below;
            # the size used to be hard-coded to 128 here, which broke any
            # run with a different --img_size.
            transforms.Resize((opt.img_size, opt.img_size)),
            transforms.ToTensor(),
            # Shift [0, 1] pixel values to [-1, 1], the range of the
            # generators' Tanh output.
            transforms.Normalize([0.5], [0.5])
        ])
    data = datasets.ImageFolder('./dataset', transform=transform)
    train_loader = torch.utils.data.DataLoader(
        data,
        batch_size=opt.batch_size,
        shuffle=True)
    img_shape = (opt.channels, opt.img_size, opt.img_size)

    # Construct generator and discriminator
    if opt.type == 'DCGAN':
        generator = Generator_CNN(opt.latent_dim, img_shape)
        discriminator = Discriminator_CNN(img_shape)
    else:
        generator = Generator(opt.latent_dim, img_shape)
        discriminator = Discriminator(img_shape)
    adversarial_loss = torch.nn.BCELoss()
    generator.to(device)
    discriminator.to(device)
    adversarial_loss.to(device)

    # Optimizers (the generator runs at 8/9 of the discriminator's rate)
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=(opt.lr * 8 / 9), betas=(opt.b1, opt.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

    # Create the sample directory once instead of on every batch.
    os.makedirs("images_3_2", exist_ok=True)
    n_batches = len(train_loader)

    for epoch in range(opt.n_epoches):
        for i, (imgs, _) in enumerate(train_loader):
            batch_size = imgs.shape[0]
            # Adversarial ground truths: 1 for real, 0 for generated.
            valid = torch.ones(batch_size, 1, device=device)
            fake = torch.zeros(batch_size, 1, device=device)
            real_imgs = imgs.to(device)

            #############    Train Generator    ################
            optimizer_G.zero_grad()
            # Sample standard-normal noise directly on the target device.
            z = torch.randn(batch_size, opt.latent_dim, device=device)
            # Generate a batch of images
            gen_imgs = generator(z)
            # G-Loss: reward images the discriminator labels as real.
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)
            g_loss.backward()
            optimizer_G.step()

            #############  Train Discriminator ################
            # Also clears the gradients that g_loss.backward() left on the
            # discriminator's parameters.
            optimizer_D.zero_grad()
            # D-Loss: detach so no gradient flows back into the generator.
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G Loss: %f]"
                % (epoch, opt.n_epoches, i, n_batches, d_loss.item(), g_loss.item())
            )

            batches_done = epoch * n_batches + i
            if batches_done % opt.sample_interval == 0:
                save_image(gen_imgs.detach()[:25], "images_3_2/%d.png" % (batches_done), nrow=5, normalize=True)


# Script entry point: start training only when this file is run directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()

实验效果如下:

相关推荐
北京搜维尔科技有限公司23 分钟前
搜维尔科技:【应用】Xsens在荷兰车辆管理局人体工程学评估中的应用
人工智能·安全
说私域26 分钟前
基于开源 AI 智能名片 S2B2C 商城小程序的视频号交易小程序优化研究
人工智能·小程序·零售
YRr YRr27 分钟前
深度学习:Transformer Decoder详解
人工智能·深度学习·transformer
知来者逆32 分钟前
研究大语言模型在心理保健智能顾问的有效性和挑战
人工智能·神经网络·机器学习·语言模型·自然语言处理
云起无垠41 分钟前
技术分享 | 大语言模型赋能软件测试:开启智能软件安全新时代
人工智能·安全·语言模型
老艾的AI世界1 小时前
新一代AI换脸更自然,DeepLiveCam下载介绍(可直播)
图像处理·人工智能·深度学习·神经网络·目标检测·机器学习·ai换脸·视频换脸·直播换脸·图片换脸
翔云API1 小时前
PHP静默活体识别API接口应用场景与集成方案
人工智能
浊酒南街1 小时前
吴恩达深度学习笔记:卷积神经网络(Foundations of Convolutional Neural Networks)4.9-4.10
人工智能·深度学习·神经网络·cnn
Tony聊跨境2 小时前
独立站SEO类型及优化:来检查这些方面你有没有落下
网络·人工智能·tcp/ip·ip
懒惰才能让科技进步2 小时前
从零学习大模型(十二)-----基于梯度的重要性剪枝(Gradient-based Pruning)
人工智能·深度学习·学习·算法·chatgpt·transformer·剪枝