简单易上手的生成对抗网络

模型原理

生成对抗网络 是指一类采用对抗训练方式 进行学习的深度生成模型,包含的判别网络生成网络都可以根据不同的生成任务使用不同的网络结构。

生成器: 通过机器生成数据,最终目的是骗过判别器。
判别器: 判断这张图像是真实的还是机器生成的,目的是找出生成器做的假数据。

构建GAN模型的基本逻辑: 现实问题需求→建立实现功能的GAN框架(编程)→训练GAN(生成网络、对抗网络)→成熟的GAN模型→应用。

GAN训练过程:

生成器生成假数据,然后将生成的假数据和真数据都输入判别器,判别器要判断出哪些是真的哪些是假的。判别器第一次判别出来的肯定有很大的误差,然后我们根据误差来优化判别器。现在判别器水平提高了,生成器生成的数据很难再骗过判别器了,所以我们得反过来优化生成器,之后生成器水平提高了,然后反过来继续训练判别器,判别器水平又提高了,再反过来训练生成器,就这样循环往复,直到达到纳什均衡。

GAN的发展历程

  1. GAN的基本思想起源于2014年,由伊恩·古德费洛等人首次提出。
  2. DCGAN,它在生成器和判别器中都使用了卷积层,取得了更好的图像生成效果。
  3. ConditionalGAN,通过引入条件信息指导生成器生成特定类型的数据。
  4. Wasserstein GAN使用Wasserstein距离作为损失函数,为GAN的训练提供了更稳定的优化方法,提高了生成样本的质量。

代码实现

DCGAN模型:

python 复制代码
generator = Sequential()
generator.add(Dense(7 * 7 * 128, input_shape=[100]))
generator.add(Reshape([7, 7, 128]))
generator.add(BatchNormalization())
generator.add(Conv2DTranspose(64, kernel_size=5, strides=2, padding="same",
                                 activation="relu"))
generator.add(BatchNormalization())
generator.add(Conv2DTranspose(1, kernel_size=5, strides=2, padding="same",
                                 activation="tanh"))
 
discriminator = Sequential()
discriminator.add(Conv2D(64, kernel_size=5, strides=2, padding="same",
                        activation=LeakyReLU(0.3),
                        input_shape=[28, 28, 1]))
discriminator.add(Dropout(0.5))
discriminator.add(Conv2D(128, kernel_size=5, strides=2, padding="same",
                        activation=LeakyReLU(0.3)))
discriminator.add(Dropout(0.5))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation="sigmoid"))

模型训练:

python 复制代码
GAN =Sequential([generator,discriminator])
discriminator.compile(optimizer='adam',loss='binary_crossentropy')
discriminator.trainable = False
 
GAN.compile(optimizer='adam',loss='binary_crossentropy')
 
epochs = 150 
batch_size = 100
noise_shape=100
 
with tf.device('/gpu:0'):
 for epoch in range(epochs):
    print(f"Currently on Epoch {epoch+1}")
    
    for i in range(X_train.shape[0]//batch_size):
        
        if (i+1)%50 == 0:
            print(f"\tCurrently on batch number {i+1} of {X_train.shape[0]//batch_size}")
            
        noise=np.random.normal(size=[batch_size,noise_shape])
       
        gen_image = generator.predict_on_batch(noise)
        
        train_dataset = X_train[i*batch_size:(i+1)*batch_size]
       
        train_label=np.ones(shape=(batch_size,1))
        discriminator.trainable = True
        d_loss_real=discriminator.train_on_batch(train_dataset,train_label)
        
        train_label=np.zeros(shape=(batch_size,1))
        d_loss_fake=discriminator.train_on_batch(gen_image,train_label)
        
        noise=np.random.normal(size=[batch_size,noise_shape])
        train_label=np.ones(shape=(batch_size,1))
        discriminator.trainable = False #while training the generator as combined model,discriminator training should be turned off
        
        d_g_loss_batch =GAN.train_on_batch(noise, train_label)
        
    if epoch % 10 == 0:
        samples = 10
        x_fake = generator.predict(np.random.normal(loc=0, scale=1, size=(samples, 100)))
 
        for k in range(samples):
            plt.subplot(2, 5, k+1)
            plt.imshow(x_fake[k].reshape(28, 28), cmap='gray')
            plt.xticks([])
            plt.yticks([])
 
        plt.tight_layout()
        plt.show()
        
print('Training is complete')

使用np.random.normal生成的噪声被作为输入给发生器:

python 复制代码
noise=np.random.normal(loc=0, scale=1, size=(100,noise_shape))
gen_image = generator.predict(noise)
plt.imshow(noise)
plt.title('DCGAN Noise')
相关推荐
zhurui_xiaozhuzaizai2 分钟前
模型训练-关于token【低概率token, 高熵token】
人工智能·算法·自然语言处理
清醒的兰35 分钟前
OpenCV 图像像素值统计
人工智能·opencv·计算机视觉
彭祥.44 分钟前
YOLO电力物目标检测训练
人工智能·yolo·目标检测
尘浮7281 小时前
60天python训练计划----day50
人工智能·python·深度学习
Listennnn1 小时前
OCR & MLLM & Evaluation
人工智能·ocr
云布道师1 小时前
云服务运行安全创新标杆:阿里云飞天洛神云网络子系统“齐天”再次斩获奖项
网络·人工智能·安全·阿里云·云计算·云布道师
lcw_lance1 小时前
智慧园区综合运营管理平台(SmartPark)和安全EHS平台的分工与协作
大数据·人工智能
ywyy67982 小时前
「全栈技术解析」推客小程序系统开发:从架构设计到裂变增长的完整解决方案
人工智能·小程序·短剧·推客系统·推客小程序·推客系统开发·推客小程序开发
加百力2 小时前
自动驾驶+人形机器人?亚马逊即将测试人形机器人送货
人工智能·机器人·自动驾驶
落沐萧萧2 小时前
本地多语言 AI 字幕组:Whisper 实战教程
人工智能·whisper