TensorFlow2 Python深度学习 - 生成对抗网络(GAN)实例

锋哥原创的TensorFlow2 Python深度学习视频教程:

https://www.bilibili.com/video/BV1X5xVz6E4w/

课程介绍

本课程主要讲解基于TensorFlow2的Python深度学习知识,包括深度学习概述,TensorFlow2框架入门知识,以及卷积神经网络(CNN),循环神经网络(RNN),生成对抗网络(GAN),模型保存与加载等。

TensorFlow2 Python深度学习 - 生成对抗网络(GAN)实例

我们以生成手写数字数据集为示例:

复制代码
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
import keras
from keras import layers, Input
import matplotlib.pyplot as plt
import time
​
# 使用手写字体或单品样本做训练  这里注意的是 我们只需要训练数据,不需要答案和测试数据集。
(train_images, _), (_, _) = keras.datasets.mnist.load_data()
​
# 因为卷积层的需求,增加色深维度
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
# 规范化为-1 - +1
train_images = (train_images - 127.5) / 127.5
​
BUFFER_SIZE = 60000  # 以供60000个样本
BATCH_SIZE = 256  # 256张为一组
# 创建数据集
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
​
​
# 生成器网络
def make_generator_model():  # 根据长度为100的随机数组,生成一张28,28,1的矩阵
    model = tf.keras.Sequential()
    model.add(Input(shape=(100,)))
    # 全联接层,输入纬度为[[100],[n]],  输出为7*7*256 = 12544的节点  use_bias=False不使用偏差
    model.add(layers.Dense(7 * 7 * 256, use_bias=False))
    # BatchNormalization层:该层在每个batch上将前一层的激活值重新规范化,即使得其输出数据的均值接近0,其标准差接近1
    # 该层作用:(1)加速收敛(2)控制过拟合,可以少用或不用Dropout和正则(3)降低网络对初始化权重不敏感(4)允许使用较大的学习率
    model.add(layers.BatchNormalization())
    # ReLU是将所有的负值都设为零,相反,Leaky ReLU是给所有负值赋予一个非零斜率(负数)
    model.add(layers.LeakyReLU())
    # 将平铺的节点转为7*7*256的shape
    model.add(layers.Reshape((7, 7, 256)))
    # 通俗的讲这个解卷积,也就做反卷积,也叫做转置卷积(最贴切),我们就叫做反卷积吧,它的目的就是卷积的反向操作
    # 个人理解,正常的卷积是提取卷积核特征,反卷积就是用卷积核反向修改图像,风格迁移应该也是这么回事,那么问题来了在这个gan中,卷积特征从哪来?
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    # 64, (5, 5), strides=(2, 2), 希望得到64个特征核,步长2,2
    # model.output_shape == (None, 14, 14, 64) 输出的节点数64就是上面的特征核,由于padding='same',所以卷积后无变化,
    # 14,14 是因为步长 2,2  所以7*2
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    return model
​
​
# 判别器网络
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(Input(shape=(28, 28, 1)))
    # 将 28.28.1的图像卷积 输出64个节点
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    # 接着卷积出128个节点
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    # 激活函数 为非0的斜率
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
    # 平铺 并输出一个数字
    model.add(layers.Flatten())
    model.add(layers.Dense(1))
    return model
​
​
generator = make_generator_model()
discriminator = make_discriminator_model()
​
# 交叉熵损失函数
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
​
​
# 辨别模型损失函数
def discriminator_loss(real_output, fake_output):
    # 样本图希望结果趋近1
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    # 自己生成的图希望结果趋近0
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    # 总损失
    total_loss = real_loss + fake_loss
    return total_loss
​
​
# 生成模型的损失函数
def generator_loss(fake_output):
    # 生成模型期望最终的结果越来越接近1,也就是真实样本
    return cross_entropy(tf.ones_like(fake_output), fake_output)
​
​
# 优化器
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
​
EPOCHS = 100  # 训练轮数
noise_dim = 100  # 噪声向量的维度
num_examples_to_generate = 16  # 生成图片数量
​
# 初始化16个种子向量,用于生成4x4的图片  seed shape: 16, 100
seed = tf.random.normal([num_examples_to_generate, noise_dim])
​
​
def train_step(images):  # 更新 模型权重数据的核心方法
    # 随机生成一个批次的种子向量 BATCH_SIZE = 256   noise_dim = 100  ,256个长度为100的噪音响亮
    noise = tf.random.normal([BATCH_SIZE, noise_dim])  # noise shape:[256],[100]
​
    # 查看每一次epoch参数更新  这个GradientTape 是每次梯度更新都会调用的,这个取代了model.fit的训练计算
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        # 生成一个批次的图片
        generated_images = generator(noise, training=True)
​
        # 辨别一个批次的真实样本
        real_output = discriminator(images, training=True)
        # 辨别一个批次的生成图片
        fake_output = discriminator(generated_images, training=True)
​
        # 计算两个损失值
        gen_loss = generator_loss(fake_output)
        disc_loss = discriminator_loss(real_output, fake_output)
​
    # 根据损失值调整模型的权重参量
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
​
    # 计算出的参量应用到模型   梯度修剪,用于改变值, 梯度修剪主要避免训练梯度爆炸和消失问题
    # zIP是个格式转换函数 例如:a = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]; zip(*a) = [(1, 4, 7), (2, 5, 8), (3, 6, 9)]
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
​
​
# 训练
def train(dataset, epochs):
    for epoch in range(epochs + 1):
        start = time.time()
​
        # 训练
        for image_batch in dataset:
            train_step(image_batch)
​
        # 保存图片
        # 每个训练批次生成一张图片作为阶段成功
        print("=======================================")
        generate_and_save_images(generator, epoch + 1, seed)
​
        print('Time for epoch {} is {} sec'.format(epoch + 1, time.time() - start))
​
​
# 生成图片
def generate_and_save_images(model, epoch, test_input):
    # 设置为非训练状态,生成一组图片
    predictions = model(test_input, training=False)
​
    # 4格x4格拼接
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')
​
    # 保存为png
    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.close()
​
​
# 以训练模式运行,进入训练状态
train(train_dataset, EPOCHS)

运行会生成100个训练图片,每个图片有16个数字小图。

越后面的图片,数字辨识度越高。

第1张,基本无法识别。

第16张,稍微有点辨识度:

第70张,基本有辨识度了:

相关推荐
忘忧记4 小时前
excel拆分和合并代码的思路整合和工具打包
python
天才测试猿4 小时前
黑盒测试用例的四种设计方法
自动化测试·软件测试·python·功能测试·测试工具·职场和发展·测试用例
B站_计算机毕业设计之家4 小时前
机器学习:基于大数据的基金数据分析可视化系统 股票数据 金融数据 股价 Django框架 大数据技术(源码) ✅
大数据·python·金融·数据分析·股票·etf·基金
CoovallyAIHub5 小时前
一夜之间,大模型处理长文本的难题被DeepSeek新模型彻底颠覆!
深度学习·算法·计算机视觉
*才华有限公司*5 小时前
《爬虫进阶之路:从模拟浏览器到破解动态加载的实战指南》
开发语言·python
深蓝电商API5 小时前
爬虫+Redis:如何实现分布式去重与任务队列?
redis·分布式·爬虫·python
我是华为OD~HR~栗栗呀5 小时前
华为OD-23届考研-测试面经
java·c++·python·华为od·华为·面试·单元测试
gc_22995 小时前
学习Python中Selenium模块的基本用法(20:安装Selenium IDE)
python·selenium
AI technophile5 小时前
OpenCV计算机视觉实战(27)——深度学习与卷积神经网络
深度学习·opencv·计算机视觉