【机器学习】生成对抗网络 – Generative Adversarial Networks(GAN)的基本概念和代码示例

引言

生成对抗网络(GANs)是深度学习中的一种创新模型,自2014年由Ian Goodfellow等人首次提出以来,已成为深度学习领域中最活跃的研究方向之一

文章目录

  • 引言
  • 一、GANs的基本概念
    • [1.1 GANs的基本原理和结构](#1.1 GANs的基本原理和结构)
      • [1.1.1 生成器](#1.1.1 生成器)
      • [1.1.2 判别器](#1.1.2 判别器)
      • [1.1.3 大白话版本](#1.1.3 大白话版本)
    • [1.2 GANs的训练过程](#1.2 GANs的训练过程)
    • [1.3 GAN的优缺点](#1.3 GAN的优缺点)
      • [1.3.1 优点](#1.3.1 优点)
      • [1.3.2 缺陷](#1.3.2 缺陷)
    • [1.4 GANs的变体](#1.4 GANs的变体)
      • [1.4.1 Conditional GAN (cGAN)](#1.4.1 Conditional GAN (cGAN))
      • [1.4.2 Deep Convolutional GAN (DCGAN)](#1.4.2 Deep Convolutional GAN (DCGAN))
      • [1.4.3 Wasserstein GAN (WGAN)](#1.4.3 Wasserstein GAN (WGAN))
      • [1.4.4 StyleGAN](#1.4.4 StyleGAN)
    • [1.5 GANs的应用领域](#1.5 GANs的应用领域)
    • [1.6 总结](#1.6 总结)
    • [1.7 GAN代码实例](#1.7 GAN代码实例)
      • [1.7.1 代码](#1.7.1 代码)
      • [1.7.2 代码解释](#1.7.2 代码解释)

一、GANs的基本概念

1.1 GANs的基本原理和结构

1.1.1 生成器

生成器的任务是生成新的数据样本。它从随机噪声中学习生成与真实数据相似的数据。在训练过程中,生成器和判别器进行对抗训练,生成器不断优化生成的数据样本,以达到欺骗判别器的目的

1.1.2 判别器

判别器的任务则是判断输入数据是否真实,即尝试区分生成的数据和真实数据,判别器则努力提高区分真实与生成数据的能力。这种竞争推动了模型的不断进化,使得生成的数据逐渐接近真实数据分布

1.1.3 大白话版本

  1. 假设一个城市治安混乱,很快,这个城市里就会出现无数的小偷。在这些小偷中,有的可能是盗窃高手,有的可能毫无技术可言。假如这个城市开始整饬其治安,突然开展一场打击犯罪的"运动",警察们开始恢复城市中的巡逻,很快,一批"学艺不精"的小偷就被捉住了。之所以捉住的是那些没有技术含量的小偷,是因为警察们的技术也不行了,在捉住一批低端小偷后,城市的治安水平变得怎样倒还不好说,但很明显,城市里小偷们的平均水平已经大大提高了

  2. 警察严打导致小偷水平提升

    警察们开始继续训练自己的破案技术,开始抓住那些越来越狡猾的小偷。随着这些职业惯犯们的落网,警察们也练就了特别的本事,他们能很快能从一群人中发现可疑人员,于是上前盘查,并最终逮捕嫌犯;小偷们的日子也不好过了,因为警察们的水平大大提高,如果还想以前那样表现得鬼鬼祟祟,那么很快就会被警察捉住

  3. 经常提升技能,更多小偷被抓

    为了避免被捕,小偷们努力表现得不那么"可疑",而魔高一尺、道高一丈,警察也在不断提高自己的水平,争取将小偷和无辜的普通群众区分开。随着警察和小偷之间的这种"交流"与"切磋",小偷们都变得非常谨慎,他们有着极高的偷窃技巧,表现得跟普通群众一模一样,而警察们都练就了"火眼金睛",一旦发现可疑人员,就能马上发现并及时控制

  4. 最终,我们同时得到了最强的小偷和最强的警察

1.2 GANs的训练过程

  1. 初始化生成器和判别器的权重
  2. 在一个批次的数据上训练判别器,使其能够区分真实数据和生成数据
  3. 训练生成器,使其生成的假数据能够欺骗判别器,提高判别器的错误率
  4. 重复步骤2和3,直到达到预设的训练轮数或满足一定的性能指标

1.3 GAN的优缺点

1.3.1 优点

  1. 能更好建模数据分布(图像更锐利、清晰)
  2. 理论上,GANs 能训练任何一种生成器网络。其他的框架需要生成器网络有一些特定的函数形式,比如输出层是高斯的
  3. 无需利用马尔科夫链反复采样,无需在学习过程中进行推断,没有复杂的变分下界,避开近似计算棘手的概率的难题

1.3.2 缺陷

  1. 难训练,不稳定。生成器和判别器之间需要很好的同步,但是在实际训练中很容易D收敛,G发散。D/G 的训练需要精心的设计
  2. 模式缺失(Mode Collapse)问题。GANs的学习过程可能出现模式缺失,生成器开始退化,总是生成同样的样本点,无法继续学习

1.4 GANs的变体

为了改善其训练稳定性、提高生成质量、扩展应用范围等,研究人员提出了许多GAN的变体

1.4.1 Conditional GAN (cGAN)

引入条件变量,生成特定类别的样本

1.4.2 Deep Convolutional GAN (DCGAN)

使用卷积层和反卷积层,提高图像生成的质量和稳定性

1.4.3 Wasserstein GAN (WGAN)

改变损失函数,使用Wasserstein距离,改善训练稳定性和模式覆盖率

1.4.4 StyleGAN

引入风格分离的概念,控制生成图像的局部属性

1.5 GANs的应用领域

GANs的应用非常广泛,包括但不限于:

  • 图像生成:风格迁移、超分辨率、人脸生成等
  • 数据增强:生成额外的样本以增强训练集
  • 医学图像分析:生成医学图像以辅助诊断
  • 声音合成:生成或修改语音信号
  • 化学分子设计:设计新的分子结构,加速药物研发和材料设计过程

1.6 总结

在实践中,训练GANs需要注意权重的初始化、优化器的选择等因素,以确保训练的稳定性和效果

1.7 GAN代码实例

1.7.1 代码

python 复制代码
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/8/15 16:34
# @Software: PyCharm
# @Author  : xialiwei
# @Email   : xxxxlw198031@163.com
import time

import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import LeakyReLU
import matplotlib.pyplot as plt

# 加载MNIST数据集
(x_train, _), (_, _) = mnist.load_data()

# 数据预处理
x_train = (x_train.astype(np.float32) - 127.5) / 127.5
x_train = np.expand_dims(x_train, axis=3)


# 生成器模型
def build_generator(z_dim):
    model = Sequential()
    model.add(Dense(128, input_dim=z_dim))
    model.add(LeakyReLU(alpha=0.01))
    model.add(Dense(28 * 28 * 1, activation='tanh'))
    model.add(Reshape((28, 28, 1)))
    return model


# 判别器模型
def build_discriminator(img_shape):
    model = Sequential()
    model.add(Flatten(input_shape=img_shape))
    model.add(Dense(128))
    model.add(LeakyReLU(alpha=0.01))
    model.add(Dense(1, activation='sigmoid'))
    return model


# GAN模型
def build_gan(generator, discriminator):
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    return model


# 设置参数
z_dim = 100
img_shape = (28, 28, 1)

# 创建生成器和判别器
generator = build_generator(z_dim)
discriminator = build_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])

# 对生成器进行编译,但在训练GAN模型时不训练判别器部分
discriminator.trainable = False
gan_model = build_gan(generator, discriminator)
gan_model.compile(loss='binary_crossentropy', optimizer=Adam())

# 训练GAN
epochs = 100
batch_size = 256
sample_interval = 50

# 训练过程中真样本标签为1,假样本标签为0
real = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

for epoch in range(epochs):
    # 训练判别器
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_imgs = x_train[idx]

    z = np.random.normal(0, 1, (batch_size, z_dim))
    fake_imgs = generator.predict(z)

    d_loss_real = discriminator.train_on_batch(real_imgs, real)
    d_loss_fake = discriminator.train_on_batch(fake_imgs, fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # 训练生成器
    z = np.random.normal(0, 1, (batch_size, z_dim))
    g_loss = gan_model.train_on_batch(z, real)

    if epoch % sample_interval == 0:
        print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}] [G loss: {g_loss}]")
        # 这里可以添加代码来保存或显示生成的图片样本
        # 显示生成的图片样本
        z = np.random.normal(0, 1, (25, z_dim))  # 生成25个样本
        gen_imgs = generator.predict(z)

        # 将生成的图片数据转换为0-1范围
        gen_imgs = 0.5 * gen_imgs + 0.5

        # 绘制生成的图片
        fig, axs = plt.subplots(5, 5)
        cnt = 0
        for i in range(5):
            for j in range(5):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        plt.show()  # 显示图形
        plt.pause(3)  # 暂停3秒
        plt.close(fig)  # 关闭图形,防止阻塞
# 保存模型
generator.save('generator.h5')
discriminator.save('discriminator.h5')

输出结果:

1.7.2 代码解释

这段代码是一个生成对抗网络(GAN)的完整实现,用于生成类似于MNIST数据集中的手写数字图像

相关推荐
Topstip3 分钟前
Gemini 对话机器人加入开源盲水印技术来检测 AI 生成的内容
人工智能·ai·机器人
SEEONTIME6 分钟前
python-24-一篇文章彻底掌握Python HTTP库Requests
开发语言·python·http·http库requests
Bearnaise6 分钟前
PointMamba: A Simple State Space Model for Point Cloud Analysis——点云论文阅读(10)
论文阅读·笔记·python·深度学习·机器学习·计算机视觉·3d
shymoy12 分钟前
Radix Sorts
数据结构·算法·排序算法
小嗷犬19 分钟前
【论文笔记】VCoder: Versatile Vision Encoders for Multimodal Large Language Models
论文阅读·人工智能·语言模型·大模型·多模态
风影小子20 分钟前
注册登录学生管理系统小项目
算法
黑龙江亿林等保22 分钟前
深入探索哈尔滨二级等保下的负载均衡SLB及其核心算法
运维·算法·负载均衡
Struart_R24 分钟前
LVSM: A LARGE VIEW SYNTHESIS MODEL WITH MINIMAL 3D INDUCTIVE BIAS 论文解读
人工智能·3d·transformer·三维重建
lucy1530275107925 分钟前
【青牛科技】GC5931:工业风扇驱动芯片的卓越替代者
人工智能·科技·单片机·嵌入式硬件·算法·机器学习
哇咔咔哇咔37 分钟前
【科普】conda、virtualenv, venv分别是什么?它们之间有什么区别?
python·conda·virtualenv