GAN初识

1. 生成对抗网络GAN简介

1.1 生成器

G(Z)接受随机噪声Z作为输入生成仿品,并训练自己去欺骗判别器D,让D以为G(Z)产生的任何数据都是真实的。

1.2 判别器

D(Y)可以基于真品和仿品来判断仿造品的仿真程度,通常值越靠近0表示越真实(靠近1表示仿造)。其目标是:使每个真实数据分布中的图像的D(Y)值最大化,并使真实数据分布之外的图像D(Y)值最小化。

1.3 训练原理

生成器和判别器进行一个相对立的博弈,因此名为对抗性训练。一般以交替方式训练G和D,其目标函数表示为损失函数,并用梯度下降算法进行优化。D和G都通过反向传播来调整生成器的参数,这样G就能够学习如何在越来越多的情况下欺骗D,最后G将学习如何生成可以以假乱真的仿造图片。

2. 实现DCGAN网络

小知识:

(1)LeakyReLU允许单元未激活时有个小的梯度,在许多情况下,它能提高GAN的性能。

(2)BatchNormalization() 批归一化通过将每个单元的输入归一化为0均值、单位方差,来帮助稳定学习的技术。在许多情况下加快了训练,减少了初始化不良的问题 ,并且能更普遍地产生更准确的结果。

2.1 定义生成器G

python 复制代码
from keras.models import Sequential
from keras.layers.convolutional import Conv2D, MaxPooling2D
from keras.layers.core import Activation, Flatten, Dense, Dropout
from keras.datasets import cifar10
from keras.utils import np_utils
from keras.optimizers import RMSprop

def generator_model():
    model = Sequential()
    model.add(Dense(input_dim=100, output_dim=1024))
    model.add(Activation('tanh'))
    model.add(Dense(128*7*7))
    model.add(BatchNormalization())
    model.add(Activation('tanh'))
    model.add(Reshape((128, 7, 7), input_shape=(128*7*7,)))
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(64, 5, 5, border_mode='same'))
    model.add(Activation('tanh'))
    model.add(UpSampling2D(size=(2, 2)))
    model.add(Conv2D(1, 5, 5, border_mode='same'))
    model.add(Activation('tanh'))
    return model
    # 注意:这里的卷积网络没有任何池化操作

2.2 定义判别器D

python 复制代码
def discriminator_model():
    model = Sequential()
    model.add(Dense(input_dim=100, output_dim=1024))
    model.add(Conv2D(64, 5, 5, border_mode='same', input_shape=(1, 28, 28)))
    model.add(Activation('tanh'))
    model.add(MaxPooling2D(pool_size(2, 2)))
    model.add(Conv2D(128, 5, 5))
    model.add(Activation('tanh'))
    model.add(MaxPooling2D(pool_size(2, 2)))
    model.add(Flatten())
    model.add(Dense(1024))
    model.add(Activation('tanh'))
    model.add(Dense(1))
    model.add(Activation('sigmoid'))
    return model

2.3 组合

python 复制代码
# 一个对抗模型,其 G和 D基于相同的模型 M
adversarial_model = AdversarialModel(base_model=M,
                                     player_params=[generator.trainable_weights, discriminator.trainable_weights],
                                     player_names=['generator', 'discriminator'])
# G和D基于两个不同的模型
adversarial_model = AdversarialModel(player_models=[gan_g, gan_d],
                                     player_params=[generator.trainable_weights, discriminator.trainable_weights],
                                     player_names=['generator', 'discriminator'])
相关推荐
kebijuelun37 分钟前
百度文心 4.5 大模型详解:ERNIE 4.5 Technical Report
人工智能·深度学习·百度·语言模型·自然语言处理·aigc
吹风看太阳2 小时前
机器学习16-总体架构
人工智能·机器学习
微学AI3 小时前
遥感影像岩性分类:基于CNN与CNN-EL集成学习的深度学习方法
深度学习·分类·cnn
IT古董3 小时前
【第三章:神经网络原理详解与Pytorch入门】02.深度学习框架PyTorch入门-(5)PyTorch 实战——使用 RNN 进行人名分类
pytorch·深度学习·神经网络
AI生存日记4 小时前
AI 行业早报:微软发布诊断工具,上海聚焦四大应用场景
人工智能·microsoft·机器学习·open ai大模型
视觉语言导航5 小时前
ICCV-2025 | 复杂场景的精准可控生成新突破!基于场景图的可控 3D 户外场景生成
人工智能·深度学习·具身智能
AI街潜水的八角7 小时前
深度学习图像分类数据集—濒危动物识别分类
人工智能·深度学习
安思派Anspire8 小时前
LangGraph + MCP + Ollama:构建强大代理 AI 的关键(一)
前端·深度学习·架构
FF-Studio8 小时前
大语言模型(LLM)课程学习(Curriculum Learning)、数据课程(data curriculum)指南:从原理到实践
人工智能·python·深度学习·神经网络·机器学习·语言模型·自然语言处理
狗头大军之江苏分军8 小时前
疑似华为盘古AI大模型翻车造假风波【实时记录篇】
人工智能·机器学习·程序员