【机器学习】生成对抗网络 – Generative Adversarial Networks(GAN)的基本概念和代码示例

引言

生成对抗网络(GANs)是深度学习中的一种创新模型,自2014年由Ian Goodfellow等人首次提出以来,已成为深度学习领域中最活跃的研究方向之一

文章目录

  • 引言
  • 一、GANs的基本概念
    • [1.1 GANs的基本原理和结构](#1.1 GANs的基本原理和结构)
      • [1.1.1 生成器](#1.1.1 生成器)
      • [1.1.2 判别器](#1.1.2 判别器)
      • [1.1.3 大白话版本](#1.1.3 大白话版本)
    • [1.2 GANs的训练过程](#1.2 GANs的训练过程)
    • [1.3 GAN的优缺点](#1.3 GAN的优缺点)
      • [1.3.1 优点](#1.3.1 优点)
      • [1.3.2 缺陷](#1.3.2 缺陷)
    • [1.4 GANs的变体](#1.4 GANs的变体)
      • [1.4.1 Conditional GAN (cGAN)](#1.4.1 Conditional GAN (cGAN))
      • [1.4.2 Deep Convolutional GAN (DCGAN)](#1.4.2 Deep Convolutional GAN (DCGAN))
      • [1.4.3 Wasserstein GAN (WGAN)](#1.4.3 Wasserstein GAN (WGAN))
      • [1.4.4 StyleGAN](#1.4.4 StyleGAN)
    • [1.5 GANs的应用领域](#1.5 GANs的应用领域)
    • [1.6 总结](#1.6 总结)
    • [1.7 GAN代码实例](#1.7 GAN代码实例)
      • [1.7.1 代码](#1.7.1 代码)
      • [1.7.2 代码解释](#1.7.2 代码解释)

一、GANs的基本概念

1.1 GANs的基本原理和结构

1.1.1 生成器

生成器的任务是生成新的数据样本。它从随机噪声中学习生成与真实数据相似的数据。在训练过程中,生成器和判别器进行对抗训练,生成器不断优化生成的数据样本,以达到欺骗判别器的目的

1.1.2 判别器

判别器的任务则是判断输入数据是否真实,即尝试区分生成的数据和真实数据,判别器则努力提高区分真实与生成数据的能力。这种竞争推动了模型的不断进化,使得生成的数据逐渐接近真实数据分布

1.1.3 大白话版本

  1. 假设一个城市治安混乱,很快,这个城市里就会出现无数的小偷。在这些小偷中,有的可能是盗窃高手,有的可能毫无技术可言。假如这个城市开始整饬其治安,突然开展一场打击犯罪的"运动",警察们开始恢复城市中的巡逻,很快,一批"学艺不精"的小偷就被捉住了。之所以捉住的是那些没有技术含量的小偷,是因为警察们的技术也不行了,在捉住一批低端小偷后,城市的治安水平变得怎样倒还不好说,但很明显,城市里小偷们的平均水平已经大大提高了

  2. 警察严打导致小偷水平提升

    警察们开始继续训练自己的破案技术,开始抓住那些越来越狡猾的小偷。随着这些职业惯犯们的落网,警察们也练就了特别的本事,他们能很快能从一群人中发现可疑人员,于是上前盘查,并最终逮捕嫌犯;小偷们的日子也不好过了,因为警察们的水平大大提高,如果还想以前那样表现得鬼鬼祟祟,那么很快就会被警察捉住

  3. 经常提升技能,更多小偷被抓

    为了避免被捕,小偷们努力表现得不那么"可疑",而魔高一尺、道高一丈,警察也在不断提高自己的水平,争取将小偷和无辜的普通群众区分开。随着警察和小偷之间的这种"交流"与"切磋",小偷们都变得非常谨慎,他们有着极高的偷窃技巧,表现得跟普通群众一模一样,而警察们都练就了"火眼金睛",一旦发现可疑人员,就能马上发现并及时控制

  4. 最终,我们同时得到了最强的小偷和最强的警察

1.2 GANs的训练过程

  1. 初始化生成器和判别器的权重
  2. 在一个批次的数据上训练判别器,使其能够区分真实数据和生成数据
  3. 训练生成器,使其生成的假数据能够欺骗判别器,提高判别器的错误率
  4. 重复步骤2和3,直到达到预设的训练轮数或满足一定的性能指标

1.3 GAN的优缺点

1.3.1 优点

  1. 能更好建模数据分布(图像更锐利、清晰)
  2. 理论上,GANs 能训练任何一种生成器网络。其他的框架需要生成器网络有一些特定的函数形式,比如输出层是高斯的
  3. 无需利用马尔科夫链反复采样,无需在学习过程中进行推断,没有复杂的变分下界,避开近似计算棘手的概率的难题

1.3.2 缺陷

  1. 难训练,不稳定。生成器和判别器之间需要很好的同步,但是在实际训练中很容易D收敛,G发散。D/G 的训练需要精心的设计
  2. 模式缺失(Mode Collapse)问题。GANs的学习过程可能出现模式缺失,生成器开始退化,总是生成同样的样本点,无法继续学习

1.4 GANs的变体

为了改善其训练稳定性、提高生成质量、扩展应用范围等,研究人员提出了许多GAN的变体

1.4.1 Conditional GAN (cGAN)

引入条件变量,生成特定类别的样本

1.4.2 Deep Convolutional GAN (DCGAN)

使用卷积层和反卷积层,提高图像生成的质量和稳定性

1.4.3 Wasserstein GAN (WGAN)

改变损失函数,使用Wasserstein距离,改善训练稳定性和模式覆盖率

1.4.4 StyleGAN

引入风格分离的概念,控制生成图像的局部属性

1.5 GANs的应用领域

GANs的应用非常广泛,包括但不限于:

  • 图像生成:风格迁移、超分辨率、人脸生成等
  • 数据增强:生成额外的样本以增强训练集
  • 医学图像分析:生成医学图像以辅助诊断
  • 声音合成:生成或修改语音信号
  • 化学分子设计:设计新的分子结构,加速药物研发和材料设计过程

1.6 总结

在实践中,训练GANs需要注意权重的初始化、优化器的选择等因素,以确保训练的稳定性和效果

1.7 GAN代码实例

1.7.1 代码

python 复制代码
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/8/15 16:34
# @Software: PyCharm
# @Author  : xialiwei
# @Email   : xxxxlw198031@163.com
import time

import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Reshape
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import LeakyReLU
import matplotlib.pyplot as plt

# 加载MNIST数据集
(x_train, _), (_, _) = mnist.load_data()

# 数据预处理
x_train = (x_train.astype(np.float32) - 127.5) / 127.5
x_train = np.expand_dims(x_train, axis=3)


# 生成器模型
def build_generator(z_dim):
    model = Sequential()
    model.add(Dense(128, input_dim=z_dim))
    model.add(LeakyReLU(alpha=0.01))
    model.add(Dense(28 * 28 * 1, activation='tanh'))
    model.add(Reshape((28, 28, 1)))
    return model


# 判别器模型
def build_discriminator(img_shape):
    model = Sequential()
    model.add(Flatten(input_shape=img_shape))
    model.add(Dense(128))
    model.add(LeakyReLU(alpha=0.01))
    model.add(Dense(1, activation='sigmoid'))
    return model


# GAN模型
def build_gan(generator, discriminator):
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    return model


# 设置参数
z_dim = 100
img_shape = (28, 28, 1)

# 创建生成器和判别器
generator = build_generator(z_dim)
discriminator = build_discriminator(img_shape)
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy'])

# 对生成器进行编译,但在训练GAN模型时不训练判别器部分
discriminator.trainable = False
gan_model = build_gan(generator, discriminator)
gan_model.compile(loss='binary_crossentropy', optimizer=Adam())

# 训练GAN
epochs = 100
batch_size = 256
sample_interval = 50

# 训练过程中真样本标签为1,假样本标签为0
real = np.ones((batch_size, 1))
fake = np.zeros((batch_size, 1))

for epoch in range(epochs):
    # 训练判别器
    idx = np.random.randint(0, x_train.shape[0], batch_size)
    real_imgs = x_train[idx]

    z = np.random.normal(0, 1, (batch_size, z_dim))
    fake_imgs = generator.predict(z)

    d_loss_real = discriminator.train_on_batch(real_imgs, real)
    d_loss_fake = discriminator.train_on_batch(fake_imgs, fake)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # 训练生成器
    z = np.random.normal(0, 1, (batch_size, z_dim))
    g_loss = gan_model.train_on_batch(z, real)

    if epoch % sample_interval == 0:
        print(f"{epoch} [D loss: {d_loss[0]}, acc.: {100 * d_loss[1]}] [G loss: {g_loss}]")
        # 这里可以添加代码来保存或显示生成的图片样本
        # 显示生成的图片样本
        z = np.random.normal(0, 1, (25, z_dim))  # 生成25个样本
        gen_imgs = generator.predict(z)

        # 将生成的图片数据转换为0-1范围
        gen_imgs = 0.5 * gen_imgs + 0.5

        # 绘制生成的图片
        fig, axs = plt.subplots(5, 5)
        cnt = 0
        for i in range(5):
            for j in range(5):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1
        plt.show()  # 显示图形
        plt.pause(3)  # 暂停3秒
        plt.close(fig)  # 关闭图形,防止阻塞
# 保存模型
generator.save('generator.h5')
discriminator.save('discriminator.h5')

输出结果:

1.7.2 代码解释

这段代码是一个生成对抗网络(GAN)的完整实现,用于生成类似于MNIST数据集中的手写数字图像

相关推荐
DARLING Zero two♡4 分钟前
【优选算法】D&C-Mergesort-Harmonies:分治-归并的算法之谐
java·数据结构·c++·算法·leetcode
gfdgd xi6 分钟前
deepin 终端,但是版本是 deepin 15 的
linux·python·架构·ssh·bash·shell·deepin
禁默7 分钟前
基于金仓KFS工具,破解多数据并存,浙人医改造实战医疗信创
数据库·人工智能·金仓数据库
CoovallyAIHub9 分钟前
万字详解:多目标跟踪(MOT)终极指南
深度学习·算法·计算机视觉
云卓SKYDROID13 分钟前
无人机动力学模块技术要点与难点
人工智能·无人机·材质·高科技·云卓科技
王六岁14 分钟前
🐍 前端开发 0 基础学 Python 入门指南:条件语句篇
前端·python
java1234_小锋14 分钟前
PyTorch2 Python深度学习 - 初识PyTorch2,实现一个简单的线性神经网络
开发语言·python·深度学习·pytorch2
胡萝卜3.015 分钟前
C++面向对象继承全面解析:不能被继承的类、多继承、菱形虚拟继承与设计模式实践
开发语言·c++·人工智能·stl·继承·菱形继承·组合vs继承
Danileaf_Guo26 分钟前
29瓦功耗运行140亿参数模型!Mac mini M4的AI能效革命
人工智能
小草cys28 分钟前
华为910B服务器(搭载昇腾Ascend 910B AI 芯片的AI服务器查看服务器终端信息
服务器·人工智能·华为·昇腾·910b