简单易上手的生成对抗网络

模型原理

生成对抗网络 是指一类采用对抗训练方式 进行学习的深度生成模型,包含的判别网络生成网络都可以根据不同的生成任务使用不同的网络结构。

生成器: 通过机器生成数据,最终目的是骗过判别器。
判别器: 判断这张图像是真实的还是机器生成的,目的是找出生成器做的假数据。

构建GAN模型的基本逻辑: 现实问题需求→建立实现功能的GAN框架(编程)→训练GAN(生成网络、对抗网络)→成熟的GAN模型→应用。

GAN训练过程:

生成器生成假数据,然后将生成的假数据和真数据都输入判别器,判别器要判断出哪些是真的哪些是假的。判别器第一次判别出来的肯定有很大的误差,然后我们根据误差来优化判别器。现在判别器水平提高了,生成器生成的数据很难再骗过判别器了,所以我们得反过来优化生成器,之后生成器水平提高了,然后反过来继续训练判别器,判别器水平又提高了,再反过来训练生成器,就这样循环往复,直到达到纳什均衡。

GAN的发展历程

  1. GAN的基本思想起源于2014年,由伊恩·古德费洛等人首次提出。
  2. DCGAN,它在生成器和判别器中都使用了卷积层,取得了更好的图像生成效果。
  3. ConditionalGAN,通过引入条件信息指导生成器生成特定类型的数据。
  4. Wasserstein GAN使用Wasserstein距离作为损失函数,为GAN的训练提供了更稳定的优化方法,提高了生成样本的质量。

代码实现

DCGAN模型:

python 复制代码
generator = Sequential()
generator.add(Dense(7 * 7 * 128, input_shape=[100]))
generator.add(Reshape([7, 7, 128]))
generator.add(BatchNormalization())
generator.add(Conv2DTranspose(64, kernel_size=5, strides=2, padding="same",
                                 activation="relu"))
generator.add(BatchNormalization())
generator.add(Conv2DTranspose(1, kernel_size=5, strides=2, padding="same",
                                 activation="tanh"))
 
discriminator = Sequential()
discriminator.add(Conv2D(64, kernel_size=5, strides=2, padding="same",
                        activation=LeakyReLU(0.3),
                        input_shape=[28, 28, 1]))
discriminator.add(Dropout(0.5))
discriminator.add(Conv2D(128, kernel_size=5, strides=2, padding="same",
                        activation=LeakyReLU(0.3)))
discriminator.add(Dropout(0.5))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation="sigmoid"))

模型训练:

python 复制代码
GAN =Sequential([generator,discriminator])
discriminator.compile(optimizer='adam',loss='binary_crossentropy')
discriminator.trainable = False
 
GAN.compile(optimizer='adam',loss='binary_crossentropy')
 
epochs = 150 
batch_size = 100
noise_shape=100
 
with tf.device('/gpu:0'):
 for epoch in range(epochs):
    print(f"Currently on Epoch {epoch+1}")
    
    for i in range(X_train.shape[0]//batch_size):
        
        if (i+1)%50 == 0:
            print(f"\tCurrently on batch number {i+1} of {X_train.shape[0]//batch_size}")
            
        noise=np.random.normal(size=[batch_size,noise_shape])
       
        gen_image = generator.predict_on_batch(noise)
        
        train_dataset = X_train[i*batch_size:(i+1)*batch_size]
       
        train_label=np.ones(shape=(batch_size,1))
        discriminator.trainable = True
        d_loss_real=discriminator.train_on_batch(train_dataset,train_label)
        
        train_label=np.zeros(shape=(batch_size,1))
        d_loss_fake=discriminator.train_on_batch(gen_image,train_label)
        
        noise=np.random.normal(size=[batch_size,noise_shape])
        train_label=np.ones(shape=(batch_size,1))
        discriminator.trainable = False #while training the generator as combined model,discriminator training should be turned off
        
        d_g_loss_batch =GAN.train_on_batch(noise, train_label)
        
    if epoch % 10 == 0:
        samples = 10
        x_fake = generator.predict(np.random.normal(loc=0, scale=1, size=(samples, 100)))
 
        for k in range(samples):
            plt.subplot(2, 5, k+1)
            plt.imshow(x_fake[k].reshape(28, 28), cmap='gray')
            plt.xticks([])
            plt.yticks([])
 
        plt.tight_layout()
        plt.show()
        
print('Training is complete')

使用np.random.normal生成的噪声被作为输入给发生器:

python 复制代码
noise=np.random.normal(loc=0, scale=1, size=(100,noise_shape))
gen_image = generator.predict(noise)
plt.imshow(noise)
plt.title('DCGAN Noise')
相关推荐
JJTX0011 分钟前
入门基础人工智能理论
人工智能·搜索引擎
神经星星34 分钟前
3秒检测准确率超90%,Ainnova Tech研发视网膜病变早筛平台,临床试验方案获FDA指导
数据库·人工智能·llm
东方佑1 小时前
UniVoc:基于二维矩阵映射的多语言词汇表系统
人工智能·算法·矩阵
小关会打代码1 小时前
计算机视觉第一课opencv(二)保姆级教
人工智能·opencv·计算机视觉
dundunmm1 小时前
【每天一个知识点】生物的数字孪生
人工智能·数字孪生·生物信息·单细胞
码码哈哈爱分享2 小时前
Cursor替代品:亚马逊出品,Kiro免费使用Claude Sonnet4.0一款更注重流程感的 AI IDE
人工智能·ai编程
roman_日积跬步-终至千里2 小时前
【深度学习】深度学习的四个核心步骤:从房价预测看机器学习本质
人工智能·深度学习·机器学习
wwww.bo2 小时前
机器学习(1)
人工智能·机器学习
CV实验室2 小时前
CVPR 2025 | 北大团队SLAM3R:单目RGB长视频实时重建,精度效率双杀!
人工智能·计算机视觉·论文·音视频
MARS_AI_2 小时前
云蝠智能 VoiceAgent:重构物流售后场景的智能化引擎
人工智能·自然语言处理·重构·交互·信息与通信