【学习笔记】GAN生成对抗神经网络原理与实践

最早在2014年Ian J. Goodfellow等人提出的GAN。

文献为:Generative Adversarial Nets

GAN面临的主要挑战有模型训练困难,容易出现生成模型坍塌等问题。因为GAN是采用生成对抗策略来训练的,优化生成模型必然导致判别模型的损失增大。

定义

生成对抗神经网络(Generative Adversarial Nets,GAN)是一种深度学习的框架,它通过相互对抗的过程来完成模型训练的。

环境配置

使用Docker或者通过conda安装能够实现环境隔离的安装,这两种方式最大的优点是可以让多个版本的TensorFlow共存,不会发生冲突,并且这两种安装方式在不同操作系统中的操作方法基本是一致的。

通过conda安装

  1. 安装Anaconda软件

  2. 创建虚拟环境

     conda create -n gan python=3.7
    
  3. 激活虚拟环境

     activate gan
    
  4. 安装TensorFlow

     pip install tensorflow==2.0
    

深层神经网络简介

深层神经网络是指包含多个模型层的神经网络,其特点是每一层神经网络中的神经元都与前一个网络层中的全部神经元相连,因此其也成为全连接层神经网络。同时,由于这种连接方式导致连接非常密集(Dense),因此也成为密集连接神经网络。

模型架构

一个深层神经网络一般会包含多个全连接层,每个全连接层中又包含多个神经元。深层审核网络的结构如图所示,深层神经网络从输入层开始,中间往往包含多个隐藏层,直到最后的输出层。

每个全连接层中都包含多个神经元,除了输入层之外,每个网络层中的每一个神经元都与前一层的所有神经元相连。也就是说,对于任何一个神经元来说,前一层所有神经元的输出都会作为它的输入,同时它的输出又会作为下一个网络层所有神经元的输入。

实现原理

信号前向传播是指将输入变量赋值给输入层神经元,然后使用前向传播算法计算出下一个网络层的神经元取值,逐层向前直到输出层。误差反向传播就是指采用损失函数计算出输出结果与样本中输出变量的差别,即误差,再计算损失函数对参数的偏导数,沿着误差降低的方向调整参数,逐步减小误差。反复执行参数优化,直到误差足够小,最终完成模型训练。

训练过程

深层神经网络训练步骤如图所示。

  1. 构建模型:构建一个能够将输入变量转换为输出变量的深层神经网络,具体来说就是输入/输出层的神经元数量分别与输入变量和输出变量的个数一致。中间隐藏层神经元的数量根据样本数据规模及业务经验设置。
  2. 初始化参数:生成随机数,对模型所需要的参数进行初始化。
  3. 信号前向传播:将输入变量赋值给输入层,然后按照前向传播算法向前逐层计算出各个网络层的输出值,直到计算出输出的输出结果。
  4. 计算误差:使用损失函数计算出模型的输出结果与样本数据中输出变量的差异。
  5. 误差反向传播:采用链式求导方法,计算损失函数对参数的偏导数,即 Δ θ \Delta \theta Δθ。
  6. 优化参数:根据公式 θ n e w = θ o l d − η Δ θ \theta_{new}=\theta_{old}-\eta\Delta \theta θnew=θold−ηΔθ来更新参数,其中 η \eta η代表学习率(Learn Rate)。
  7. 完成模型训练:反复执行步骤3~6,直到误差无穷小。

TensorFlow 2.0开发入门

TensorFlow 2.0内置Keras(tf.keras)

开发流程

  1. 使用tf.data加载数据。
  2. 使用tf.keras或者预置的Estimator构建训练和验证模型。
  3. Eager运行模式。
  4. 分布式训练。

张量

张量是TensorFlow开发中的数据类型,也是唯一的数据类型。

张量(Tensor)是向量和矩阵的泛化,可以有n个维度。

在TensorFlow内部,张量是通过n维数组来实现的。

张量的定义包含三个要素,分识是阶(Rank)、形状(Shape)和数据类型(Data Type)。阶定义了张量的维度,形状定义了张量在各个维度上的长度,数据类型定义了张量中每个元素的数据类型。注意,一个张量所有元素的数据类型都必须相同,在一个张量中只允许一种数据类型。

相关推荐
澜世12 分钟前
2024小迪安全基础入门第七课
网络·笔记·安全·网络安全
weixin_4786897621 分钟前
【二叉树】【2.1遍历二叉树】【刷题笔记】【灵神题单】
笔记
wzx_Eleven26 分钟前
【课堂笔记】隐私计算实训营第四期:“隐语”可信隐私计算开源框架
笔记
guihong0041 小时前
JAVA面试题、八股文学习之JVM篇
java·jvm·学习
CQXXCL1 小时前
MySQL-学习笔记
笔记·学习·mysql
多喝开水少熬夜1 小时前
FedGraph: Federated Graph Learning With Intelligent Sampling论文阅读
学习·论文·联邦学习
Lostgreen2 小时前
分布式查询处理优化之数据分片
大数据·笔记·分布式
hillstream32 小时前
gitlab工作笔记
笔记·gitlab
芯纪元2 小时前
Perl编程语言简介
笔记·perl