简单易上手的生成对抗网络

模型原理

生成对抗网络 是指一类采用对抗训练方式 进行学习的深度生成模型,包含的判别网络生成网络都可以根据不同的生成任务使用不同的网络结构。

生成器: 通过机器生成数据,最终目的是骗过判别器。
判别器: 判断这张图像是真实的还是机器生成的,目的是找出生成器做的假数据。

构建GAN模型的基本逻辑: 现实问题需求→建立实现功能的GAN框架(编程)→训练GAN(生成网络、对抗网络)→成熟的GAN模型→应用。

GAN训练过程:

生成器生成假数据,然后将生成的假数据和真数据都输入判别器,判别器要判断出哪些是真的哪些是假的。判别器第一次判别出来的肯定有很大的误差,然后我们根据误差来优化判别器。现在判别器水平提高了,生成器生成的数据很难再骗过判别器了,所以我们得反过来优化生成器,之后生成器水平提高了,然后反过来继续训练判别器,判别器水平又提高了,再反过来训练生成器,就这样循环往复,直到达到纳什均衡。

GAN的发展历程

  1. GAN的基本思想起源于2014年,由伊恩·古德费洛等人首次提出。
  2. DCGAN,它在生成器和判别器中都使用了卷积层,取得了更好的图像生成效果。
  3. ConditionalGAN,通过引入条件信息指导生成器生成特定类型的数据。
  4. Wasserstein GAN使用Wasserstein距离作为损失函数,为GAN的训练提供了更稳定的优化方法,提高了生成样本的质量。

代码实现

DCGAN模型:

python 复制代码
generator = Sequential()
generator.add(Dense(7 * 7 * 128, input_shape=[100]))
generator.add(Reshape([7, 7, 128]))
generator.add(BatchNormalization())
generator.add(Conv2DTranspose(64, kernel_size=5, strides=2, padding="same",
                                 activation="relu"))
generator.add(BatchNormalization())
generator.add(Conv2DTranspose(1, kernel_size=5, strides=2, padding="same",
                                 activation="tanh"))
 
discriminator = Sequential()
discriminator.add(Conv2D(64, kernel_size=5, strides=2, padding="same",
                        activation=LeakyReLU(0.3),
                        input_shape=[28, 28, 1]))
discriminator.add(Dropout(0.5))
discriminator.add(Conv2D(128, kernel_size=5, strides=2, padding="same",
                        activation=LeakyReLU(0.3)))
discriminator.add(Dropout(0.5))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation="sigmoid"))

模型训练:

python 复制代码
GAN =Sequential([generator,discriminator])
discriminator.compile(optimizer='adam',loss='binary_crossentropy')
discriminator.trainable = False
 
GAN.compile(optimizer='adam',loss='binary_crossentropy')
 
epochs = 150 
batch_size = 100
noise_shape=100
 
with tf.device('/gpu:0'):
 for epoch in range(epochs):
    print(f"Currently on Epoch {epoch+1}")
    
    for i in range(X_train.shape[0]//batch_size):
        
        if (i+1)%50 == 0:
            print(f"\tCurrently on batch number {i+1} of {X_train.shape[0]//batch_size}")
            
        noise=np.random.normal(size=[batch_size,noise_shape])
       
        gen_image = generator.predict_on_batch(noise)
        
        train_dataset = X_train[i*batch_size:(i+1)*batch_size]
       
        train_label=np.ones(shape=(batch_size,1))
        discriminator.trainable = True
        d_loss_real=discriminator.train_on_batch(train_dataset,train_label)
        
        train_label=np.zeros(shape=(batch_size,1))
        d_loss_fake=discriminator.train_on_batch(gen_image,train_label)
        
        noise=np.random.normal(size=[batch_size,noise_shape])
        train_label=np.ones(shape=(batch_size,1))
        discriminator.trainable = False #while training the generator as combined model,discriminator training should be turned off
        
        d_g_loss_batch =GAN.train_on_batch(noise, train_label)
        
    if epoch % 10 == 0:
        samples = 10
        x_fake = generator.predict(np.random.normal(loc=0, scale=1, size=(samples, 100)))
 
        for k in range(samples):
            plt.subplot(2, 5, k+1)
            plt.imshow(x_fake[k].reshape(28, 28), cmap='gray')
            plt.xticks([])
            plt.yticks([])
 
        plt.tight_layout()
        plt.show()
        
print('Training is complete')

使用np.random.normal生成的噪声被作为输入给发生器:

python 复制代码
noise=np.random.normal(loc=0, scale=1, size=(100,noise_shape))
gen_image = generator.predict(noise)
plt.imshow(noise)
plt.title('DCGAN Noise')
相关推荐
大任视点16 分钟前
可梦AI获首批企业好评,蜜糖网络入驻共启AI短剧工业化
人工智能
高洁0123 分钟前
大模型-详解 Vision Transformer (ViT)
人工智能·python·深度学习·算法·transformer
科技峰行者25 分钟前
亚马逊云科技与OpenAI战略合作深度分析:算力联盟重塑AI产业格局
人工智能
说私域29 分钟前
O2O行业风口下的运营策略与定制开发AI智能名片S2B2C商城小程序的应用研究
人工智能·小程序
慕慕涵雪月光白29 分钟前
在Ubuntu系统上安装英伟达(NVIDIA)RTX 3070 Ti的驱动程序
linux·运维·人工智能·ubuntu
柳鲲鹏33 分钟前
OpenCV:BGR/RGB转I420(颜色失真),再转NV12
人工智能·opencv·计算机视觉
无风听海35 分钟前
神经网络之线性变换
人工智能·深度学习·神经网络
陈果然DeepVersion36 分钟前
Java大厂面试真题:Spring Boot+Kafka+AI智能客服场景全流程解析(九)
java·人工智能·spring boot·微服务·kafka·面试题·rag
aneasystone本尊39 分钟前
重温 Java 21 之外部函数和内存 API
人工智能
IT_陈寒1 小时前
7个Java Stream API的隐藏技巧,让你的代码效率提升50%
前端·人工智能·后端