生成对抗网络入门案例

前言

生成对抗网络(Generative Adversarial Networks,简称GANs)是一种用于生成新样本的机器学习模型。它由两个主要组件组成:生成器(Generator)和判别器(Discriminator)。生成器尝试生成与训练数据相似的新样本,而判别器则试图区分生成器生成的样本和真实训练数据。

下面是一个简单的对抗生成网络的入门例子,用于生成手写数字图像:

实现过程

1、导入必要的库和模块

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.optimizers import Adam

2、加载MNIST数据集

python 复制代码
(x_train, _), (_, _) = mnist.load_data()
x_train = x_train / 255.0
x_train = np.expand_dims(x_train, axis=3)

3、定义生成器模型

python 复制代码
generator = Sequential()
generator.add(Dense(7*7*128, input_shape=(100,), activation='relu'))
generator.add(Reshape((7, 7, 128)))
generator.add(Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu'))
generator.add(Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', activation='sigmoid'))

4、定义判别器模型

python 复制代码
discriminator = Sequential()
discriminator.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=(28, 28, 1), activation='relu'))
discriminator.add(Conv2D(128, (3, 3), strides=(2, 2), padding='same', activation='relu'))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))

5、编译判别器模型

python 复制代码
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])

6、冻结判别器模型的权重

python 复制代码
discriminator.trainable = False

7、定义GAN模型

python 复制代码
gan = Sequential()
gan.add(generator)
gan.add(discriminator)

8、编译GAN模型

python 复制代码
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))

9、定义训练函数

python 复制代码
def train_gan(epochs, batch_size, sample_interval):
    for epoch in range(epochs):
        # 生成随机噪声作为输入
        noise = np.random.normal(0, 1, (batch_size, 100))
        
        # 生成假样本
        generated_images = generator.predict(noise)
        
        # 从真实样本中随机选择一批样本
        real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]
        
        # 训练判别器
        discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
        discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))
        discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)
        
        # 训练生成器
        noise = np.random.normal(0, 1, (batch_size, 100))
        generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))
        
        # 打印损失
        if epoch % sample_interval == 0:
            print(f"Epoch {epoch}/{epochs}, Discriminator Loss: {discriminator_loss[0]}, Generator Loss: {generator_loss}")
            
            # 保存生成的图像
            save_images(epoch)

10、保存生成的图像

python 复制代码
def save_images(epoch):
    rows, cols = 5, 5
    noise = np.random.normal(0, 1, (rows * cols, 100))
    generated_images = generator.predict(noise)
    generated_images = 0.5 * generated_images + 0.5
    fig, axs = plt.subplots(rows, cols)
    idx = 0
    for i in range(rows):
        for j in range(cols):
            axs[i, j].imshow(generated_images[idx, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            idx += 1
    fig.savefig(f"gan_images/mnist_{epoch}.png")
    plt.close()

11、训练GAN模型

python 复制代码
epochs = 10000
batch_size = 128
sample_interval = 1000

完整代码

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.optimizers import Adam

# 加载MNIST数据集
(x_train, _), (_, _) = mnist.load_data()
x_train = x_train / 255.0
x_train = np.expand_dims(x_train, axis=3)

# 定义生成器模型
generator = Sequential()
generator.add(Dense(7*7*128, input_shape=(100,), activation='relu'))
generator.add(Reshape((7, 7, 128)))
generator.add(Conv2DTranspose(64, (3, 3), strides=(2, 2), padding='same', activation='relu'))
generator.add(Conv2DTranspose(1, (3, 3), strides=(2, 2), padding='same', activation='sigmoid'))

# 定义判别器模型
discriminator = Sequential()
discriminator.add(Conv2D(64, (3, 3), strides=(2, 2), padding='same', input_shape=(28, 28, 1), activation='relu'))
discriminator.add(Conv2D(128, (3, 3), strides=(2, 2), padding='same', activation='relu'))
discriminator.add(Flatten())
discriminator.add(Dense(1, activation='sigmoid'))

# 编译判别器模型
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5), metrics=['accuracy'])

# 冻结判别器模型的权重
discriminator.trainable = False

# 定义GAN模型
gan = Sequential()
gan.add(generator)
gan.add(discriminator)

# 编译GAN模型
gan.compile(loss='binary_crossentropy', optimizer=Adam(learning_rate=0.0002, beta_1=0.5))

# 定义训练函数
def train_gan(epochs, batch_size, sample_interval):
    for epoch in range(epochs):
        # 生成随机噪声作为输入
        noise = np.random.normal(0, 1, (batch_size, 100))
        
        # 生成假样本
        generated_images = generator.predict(noise)
        
        # 从真实样本中随机选择一批样本
        real_images = x_train[np.random.randint(0, x_train.shape[0], batch_size)]
        
        # 训练判别器
        discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
        discriminator_loss_fake = discriminator.train_on_batch(generated_images, np.zeros((batch_size, 1)))
        discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)
        
        # 训练生成器
        noise = np.random.normal(0, 1, (batch_size, 100))
        generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))
        
        # 打印损失
        if epoch % sample_interval == 0:
            print(f"Epoch {epoch}/{epochs}, Discriminator Loss: {discriminator_loss[0]}, Generator Loss: {generator_loss}")
            
            # 保存生成的图像
            save_images(epoch)
            
# 保存生成的图像
def save_images(epoch):
    rows, cols = 5, 5
    noise = np.random.normal(0, 1, (rows * cols, 100))
    generated_images = generator.predict(noise)
    generated_images = 0.5 * generated_images + 0.5
    fig, axs = plt.subplots(rows, cols)
    idx = 0
    for i in range(rows):
        for j in range(cols):
            axs[i, j].imshow(generated_images[idx, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            idx += 1
    fig.savefig(f"gan_images/mnist_{epoch}.png")
    plt.close()
    
# 训练GAN模型
epochs = 10000
batch_size = 128
sample_interval = 1000

train_gan(epochs, batch_size, sample_interval)

这个例子使用了MNIST数据集,生成手写数字图像。生成器和判别器模型使用了卷积神经网络的结构。在训练过程中,生成器试图生成逼真的手写数字图像,而判别器则试图区分真实图像和生成图像。通过反复迭代训练生成器和判别器,GAN模型能够逐渐生成更逼真的手写数字图像。生成的图像会保存在gan_images文件夹中。

相关推荐
ACP广源盛139246256731 小时前
(ACP广源盛)GSV1175---- MIPI/LVDS 转 Type-C/DisplayPort 1.2 转换器产品说明及功能分享
人工智能·音视频
胡耀超1 小时前
隐私计算技术全景:从联邦学习到可信执行环境的实战指南—数据安全——隐私计算 联邦学习 多方安全计算 可信执行环境 差分隐私
人工智能·安全·数据安全·tee·联邦学习·差分隐私·隐私计算
停停的茶3 小时前
深度学习(目标检测)
人工智能·深度学习·目标检测
Y200309163 小时前
基于 CIFAR10 数据集的卷积神经网络(CNN)模型训练与集成学习
人工智能·cnn·集成学习
老兵发新帖3 小时前
主流神经网络快速应用指南
人工智能·深度学习·神经网络
AI量化投资实验室3 小时前
15年122倍,年化43.58%,回撤才20%,Optuna机器学习多目标调参backtrader,附python代码
人工智能·python·机器学习
java_logo4 小时前
vllm-openai Docker 部署手册
运维·人工智能·docker·ai·容器
倔强青铜三4 小时前
苦练Python第67天:光速读取任意行,linecache模块解锁文件处理新姿势
人工智能·python·面试
算家计算4 小时前
重磅突破!全球首个真实物理环境机器人基准测试正式发布,具身智能迎来 “ImageNet 时刻”
人工智能·资讯
新智元4 小时前
苹果 M5「夜袭」高通英特尔!AI 算力狂飙 400%,Pro 三剑客火速上新
人工智能·openai