生成对抗网络(GAN)手写数字生成

文章目录

一、前言

我的环境:

  • 语言环境:Python3.6.5
  • 编译器:jupyter notebook
  • 深度学习环境:TensorFlow2.4.1

往期精彩内容:

来自专栏: 机器学习与深度学习算法推荐

二、前期工作

1. 设置GPU(如果使用的是CPU可以忽略这步)

python 复制代码
import tensorflow as tf

gpus = tf.config.list_physical_devices("GPU")

if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)  #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpus[0]],"GPU")
    
# 打印显卡信息,确认GPU可用
print(gpus)
python 复制代码
from tensorflow.keras import layers, datasets, Sequential, Model, optimizers
from tensorflow.keras.layers import LeakyReLU, UpSampling2D, Conv2D

import matplotlib.pyplot as plt
import numpy             as np
import sys,os,pathlib
py 复制代码
img_shape  = (28, 28, 1)
latent_dim = 200

二、什么是生成对抗网络

1. 简单介绍

生成对抗网络(GAN) 包含生成器和判别器,两个模型通过对抗训练不断学习、进化。

  • 生成器(Generator):生成数据(大部分情况下是图像),目的是"骗过"判别器。
  • 鉴别器(Discriminator):判断这张图像是真实的还是机器生成的,目的是找出生成器生成的"假数据"。

2. 应用领域

GAN 的应用十分广泛,它的应用包括图像合成、风格迁移、照片修复以及照片编辑,数据增强等等。

1)风格迁移

图像风格迁移是将图像A的风格转换到图像B中去,得到新的图像。

2)图像生成

GAN 不但能生成人脸,还能生成其他类型的图片,比如漫画人物。

三、网络结构

简单来讲,就是用生成器生成手写数字图像,用鉴别器鉴别图像的真假。二者相互对抗学习(卷),在对抗学习(卷)的过程中不断完善自己,直至生成器可以生成以假乱真的图片(鉴别器无法判断其真假)。结构图如下:

GAN步骤:

  • 1.生成器(Generator)接收随机数并返回生成图像。
  • 2.将生成的数字图像与实际数据集中的数字图像一起送到鉴别器(Discriminator)。
  • 3.鉴别器(Discriminator)接收真实和假图像并返回概率,0到1之间的数字,1表示真,0表示假。

四、构建生成器

py 复制代码
def build_generator():
    # ======================================= #
    #     生成器,输入一串随机数字生成图片
    # ======================================= #
    model = Sequential([
        layers.Dense(256, input_dim=latent_dim),
        layers.LeakyReLU(alpha=0.2),               # 高级一点的激活函数
        layers.BatchNormalization(momentum=0.8),   # BN 归一化
        
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        
        layers.Dense(1024),
        layers.LeakyReLU(alpha=0.2),
        layers.BatchNormalization(momentum=0.8),
        
        layers.Dense(np.prod(img_shape), activation='tanh'),
        layers.Reshape(img_shape)
    ])

    noise = layers.Input(shape=(latent_dim,))
    img = model(noise)

    return Model(noise, img)

五、构建鉴别器

py 复制代码
def build_discriminator():
    # ===================================== #
    #   鉴别器,对输入的图片进行判别真假
    # ===================================== #
    model = Sequential([
        layers.Flatten(input_shape=img_shape),
        layers.Dense(512),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(256),
        layers.LeakyReLU(alpha=0.2),
        layers.Dense(1, activation='sigmoid')
    ])

    img = layers.Input(shape=img_shape)
    validity = model(img)

    return Model(img, validity)
py 复制代码
# 创建判别器
discriminator = build_discriminator()
# 定义优化器
optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator.compile(loss='binary_crossentropy',
                      optimizer=optimizer,
                      metrics=['accuracy'])

# 创建生成器 
generator = build_generator()
gan_input = layers.Input(shape=(latent_dim,))
img = generator(gan_input)

# 对生成的假图片进行预测
validity = discriminator(img)
combined = Model(gan_input, validity)
combined.compile(loss='binary_crossentropy', optimizer=optimizer)

六、训练模型

1. 保存样例图片

py 复制代码
def sample_images(epoch):
    """
    保存样例图片
    """
    row, col = 4, 4
    noise = np.random.normal(0, 1, (row*col, latent_dim))
    gen_imgs = generator.predict(noise)

    fig, axs = plt.subplots(row, col)
    cnt = 0
    for i in range(row):
        for j in range(col):
            axs[i,j].imshow(gen_imgs[cnt, :,:,0], cmap='gray')
            axs[i,j].axis('off')
            cnt += 1
    fig.savefig("images/%05d.png" % epoch)
    plt.close()

2. 训练模型

train_on_batch:函数接受单批数据,执行反向传播,然后更新模型参数,该批数据的大小可以是任意的,即,它不需要提供明确的批量大小,属于精细化控制训练模型。

py 复制代码
def train(epochs, batch_size=128, sample_interval=50):
    # 加载数据
    (train_images,_), (_,_) = tf.keras.datasets.mnist.load_data()

    # 将图片标准化到 [-1, 1] 区间内   
    train_images = (train_images - 127.5) / 127.5
    # 数据
    train_images = np.expand_dims(train_images, axis=3)

    # 创建标签
    true = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    
    # 进行循环训练
    for epoch in range(epochs): 

        # 随机选择 batch_size 张图片
        idx = np.random.randint(0, train_images.shape[0], batch_size)
        imgs = train_images[idx]      
        
        # 生成噪音
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        # 生成器通过噪音生成图片,gen_imgs的shape为:(128, 28, 28, 1)
        gen_imgs = generator.predict(noise)
        
        # 训练鉴别器 
        d_loss_true = discriminator.train_on_batch(imgs, true)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        # 返回loss值
        d_loss = 0.5 * np.add(d_loss_true, d_loss_fake)

        # 训练生成器
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = combined.train_on_batch(noise, true)
        
        print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

        # 保存样例图片
        if epoch % sample_interval == 0:
            sample_images(epoch)
py 复制代码
train(epochs=30000, batch_size=256, sample_interval=200)

七、生成动图

如果报错:ModuleNotFoundError: No module named 'imageio' 可以使用:pip install imageio 安装 imageio 库。

py 复制代码
import imageio

def compose_gif():
    # 图片地址
    data_dir = "images_old"
    data_dir = pathlib.Path(data_dir)
    paths    = list(data_dir.glob('*'))
    
    gif_images = []
    for path in paths:
        print(path)
        gif_images.append(imageio.imread(path))
    imageio.mimsave("test.gif",gif_images,fps=2)
    
compose_gif()
相关推荐
极小狐11 分钟前
极狐GitLab 如何 cherry-pick 变更?
人工智能·git·机器学习·gitlab
小宋加油啊13 分钟前
深度学习小记(包括pytorch 还有一些神经网络架构)
pytorch·深度学习·神经网络
沛沛老爹16 分钟前
从线性到非线性:简单聊聊神经网络的常见三大激活函数
人工智能·深度学习·神经网络·激活函数·relu·sigmoid·tanh
0x21124 分钟前
[论文阅读]ReAct: Synergizing Reasoning and Acting in Language Models
人工智能·语言模型·自然语言处理
何大春34 分钟前
【视频时刻检索】Text-Video Retrieval via Multi-Modal Hypergraph Networks 论文阅读
论文阅读·深度学习·神经网络·计算机视觉·视觉检测·论文笔记
mucheni36 分钟前
迅为iTOP-RK3576开发板/核心板6TOPS超强算力NPU适用于ARM PC、边缘计算、个人移动互联网设备及其他多媒体产品
arm开发·人工智能·边缘计算
Jamence37 分钟前
多模态大语言模型arxiv论文略读(三十六)
人工智能·语言模型·自然语言处理
猿饵块1 小时前
opencv--图像变换
人工智能·opencv·计算机视觉
LucianaiB1 小时前
【金仓数据库征文】_AI 赋能数据库运维:金仓KES的智能化未来
运维·数据库·人工智能·金仓数据库 2025 征文·数据库平替用金仓
jndingxin1 小时前
OpenCV 图形API(63)图像结构分析和形状描述符------计算图像中非零像素的边界框函数boundingRect()
人工智能·opencv·计算机视觉