昇思25天学习打卡营第5天|GAN图像生成

文章目录

昇思MindSpore应用实践

本系列文章主要用于记录昇思25天学习打卡营的学习心得。

基于MindSpore的生成对抗网络图像生成
1、生成对抗网络简介
零和博弈 vs 极大极小博弈

生成对抗网络Generative adversarial networks (GANs)主要包括生成器网络(Generator)和判别器网络(Discriminator)

这两个网络在GAN的训练过程中相互竞争,形成了一种博弈论中的极大极小博弈(MinMax game)

零和博弈 (Zero-sum game)是博弈论中的一个重要概念,指的是参与者的利益完全相反,即一方的利益的增加意味着另一方的利益的减少,总利益为零。在零和博弈中,参与者之间的利益是完全对立的,因此一个参与者的利益的增加必然导致其他参与者的利益减少。在非合作博弈中,纳什均衡是一种重要的解,纳什均衡代表每个玩家选择的策略都是其在对方策略给定的情况下的最优策略。在零和博弈中,寻找纳什均衡通常涉及找到使每个玩家的预期收益最大化的策略组合。

极大极小博弈(MinMax game)是一种博弈论中的解决方法,用于确定参与者的最佳决策策略,此外为人所熟知用于决策的方法还有强化学习。在极大极小博弈中,每个参与者都试图最大化自己的最小收益。也就是说,每个参与者都采取行动,以确保在对手选择其最优策略时自己的收益最大化。

假设GAN网络训练达到了纳什平衡状态,那么判别器无法准确地判断出输入样本是真样本还是假样本,此时判别器失效,生成器达到了巅峰状态,我们就无需使用判别器并终止训练了,得到的生成器就是我们用来生成数据的预训练模型。

从理论上讲,此博弈游戏的平衡点是 p G ( x ; θ ) = p d a t a ( x ) p_{G}(x;\theta) = p_{data}(x) pG(x;θ)=pdata(x),此时判别器会随机猜测输入是真图像还是假图像。下面我们简要说明生成器和判别器的博弈过程:

  1. 在训练刚开始的时候,生成器和判别器的质量都比较差,生成器会随机生成一个数据分布;
  2. 判别器通过求取梯度和损失函数对网络进行优化,将接近真实数据分布的数据判定为1 ( D ( x ) = 1 D(x)=1 D(x)=1),将接近生成器生成数据分布数据判定为0 (( G ( z ) = 0 G(z)=0 G(z)=0)),即希望 min ⁡ G max ⁡ D V ( G , D ) \underset{G}{\min} \underset{D}{\max}V(G, D) GminDmaxV(G,D);
  3. 生成器通过优化,生成出更加贴近真实数据分布的数据;
  4. 生成器所生成的数据和真实数据达到相同的分布,此时判别器的输出为1/2,如上图中的(d)所示。
GAN的生成对抗损失

min ⁡ G max ⁡ D V ( G , D ) = E x ∼ p data ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ ( 1 − D ( G ( z ) ) ) ] \underset{G}{\min} \underset{D}{\max}V(G, D) = \mathbb{E}{x \sim p{\text{data}}(x)}[\log D(x)] + \mathbb{E}{z \sim p_z(z)}[\log(1 - D(G(z)))] GminDmaxV(G,D)=Ex∼pdata(x)[logD(x)]+Ez∼pz(z)[log(1−D(G(z)))]

GAN网络本身就是在训练一个能达到平衡状态的损失函数,生成对抗损失是GANs中最基本的损失函数。

当生成对抗损失达到纳什均衡时,判别器对真假数据的判别概率都是0.5,即 D ( x ) = 1 − G ( z ) = 0.5 D(x)=1-G(z)=0.5 D(x)=1−G(z)=0.5,

即 l o g ( D ( x ) ) = l o g ( 1 − G ( z ) ) ≈ 0.693 log(D(x))=log(1-G(z))\approx0.693 log(D(x))=log(1−G(z))≈0.693

由于数据x和G(z)不仅是一张图片,再分别取两者的均值 E \mathbb{E} E,相加,就得到了生成对抗损失。

近十年来著名的GAN网络结构

2、基于MindSpore的 Vanilla GAN

生成器部分:生成器 Generator 的功能是将隐码映射到数据空间。通过五层 Dense 全连接层来完成的,每层都与 BatchNorm1d 批归一化层和 ReLU 激活层配对,输出数据会经过 Tanh 函数,使其返回 [-1,1] 的数据范围内,并返回一张28x28的图像作为生成结果。

python 复制代码
from mindspore import nn
import mindspore.ops as ops

img_size = 28  # 训练图像长(宽)28x28

class Generator(nn.Cell):
    def __init__(self, latent_size, auto_prefix=True):
        super(Generator, self).__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 100] -> [N, 128]
        # 输入一个100维的0~1之间的高斯分布,通过第一层线性变换将其映射到128维
        self.model.append(nn.Dense(latent_size, 128))
        self.model.append(nn.ReLU())
        # 通过第二层线性变换将其映射到256维
        # [N, 128] -> [N, 256]
        self.model.append(nn.Dense(128, 256))
        self.model.append(nn.BatchNorm1d(256))
        self.model.append(nn.ReLU())
        # [N, 256] -> [N, 512]
        self.model.append(nn.Dense(256, 512))
        self.model.append(nn.BatchNorm1d(512))
        self.model.append(nn.ReLU())
        # [N, 512] -> [N, 1024]
        self.model.append(nn.Dense(512, 1024))
        self.model.append(nn.BatchNorm1d(1024))
        self.model.append(nn.ReLU())
        # [N, 1024] -> [N, 784]
        # 经过线性变换将其变成784维
        self.model.append(nn.Dense(1024, img_size * img_size))
        # 经过Tanh激活函数是希望生成的假的图片数据分布能够在-1~1之间
        self.model.append(nn.Tanh())

    def construct(self, x):
        img = self.model(x)
        return ops.reshape(img, (-1, 1, 28, 28))

net_g = Generator(latent_size)
net_g.update_parameters_name('generator')

判别器部分:判别器 Discriminator 是一个二分类网络模型,在训练时,判别器接收生成器的生成图像与对应的真实数据相对比,输出判定该图像为真实图的概率。主要通过一系列的 Dense 层和 LeakyReLU 层对其进行处理,最后通过 Sigmoid 激活函数,使其返回 [0, 1] 的数据范围内,得到最终概率。

python 复制代码
 # 判别器
class Discriminator(nn.Cell):
    def __init__(self, auto_prefix=True):
        super().__init__(auto_prefix=auto_prefix)
        self.model = nn.SequentialCell()
        # [N, 784] -> [N, 512]
        self.model.append(nn.Dense(img_size * img_size, 512))  # 输入特征数为784,输出为512
        self.model.append(nn.LeakyReLU())  # 默认斜率为0.2的非线性映射激活函数
        # [N, 512] -> [N, 256]
        self.model.append(nn.Dense(512, 256))  # 进行一个线性映射
        self.model.append(nn.LeakyReLU())
        # [N, 256] -> [N, 1]
        self.model.append(nn.Dense(256, 1))
        self.model.append(nn.Sigmoid())  # 二分类激活函数,将实数映射到[0,1]

    def construct(self, x):
        x_flat = ops.reshape(x, (-1, img_size * img_size))
        return self.model(x_flat)

net_d = Discriminator()
net_d.update_parameters_name('discriminator')
3、基于MindSpore的手写数字图像生成
导入数据
python 复制代码
import numpy as np
import mindspore.dataset as ds

batch_size = 128
latent_size = 100  # 潜在编码的长度

train_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/train')
test_dataset = ds.MnistDataset(dataset_dir='./MNIST_Data/test')

def data_load(dataset):
    dataset1 = ds.GeneratorDataset(dataset, ["image", "label"], shuffle=True, python_multiprocessing=False)
    
    # 数据增强
    mnist_ds = dataset1.map(  # 通过map方法给每张图像映射一个潜在编码
    # 将图像数据转换为 float32 类型
    # 生成一个长度为 latent_size 的服从正态分布的随机向量,并将其转换为 float32 类型
        operations=lambda x: (x.astype("float32"), np.random.normal(size=latent_size).astype("float32")),
        output_columns=["image", "latent_code"])
    mnist_ds = mnist_ds.project(["image", "latent_code"])

    # 批量操作
    mnist_ds = mnist_ds.batch(batch_size, True)

    return mnist_ds

mnist_ds = data_load(train_dataset)

iter_size = mnist_ds.get_dataset_size()
print('Iter size: %d' % iter_size)
数据可视化
python 复制代码
import matplotlib.pyplot as plt

data_iter = next(mnist_ds.create_dict_iterator(output_numpy=True))
figure = plt.figure(figsize=(3, 3))
cols, rows = 5, 5
for idx in range(1, cols * rows + 1):
    image = data_iter['image'][idx]
    figure.add_subplot(rows, cols, idx)
    plt.axis("off")
    plt.imshow(image.squeeze(), cmap="gray")
plt.show()

潜在编码(latent code)的构造:

为了跟踪生成器的学习进度,我们在训练的过程中的每轮迭代结束后,将一组固定的遵循高斯分布的隐码test_noise输入到生成器中,通过这组固定的潜在编码(也叫隐码)所生成的图像效果来评估生成器的生成质量。

python 复制代码
import random
import numpy as np
from mindspore import Tensor
from mindspore.common import dtype

# 利用随机种子创建一批隐码
np.random.seed(2323)
test_noise = Tensor(np.random.normal(size=(25, 100)), dtype.float32)
random.shuffle(test_noise)
模型训练

定义损失函数和优化器:

python 复制代码
lr = 0.0002  # 学习率

# 损失函数
adversarial_loss = nn.BCELoss(reduction='mean')

# 优化器
optimizer_d = nn.Adam(net_d.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g = nn.Adam(net_g.trainable_params(), learning_rate=lr, beta1=0.5, beta2=0.999)
optimizer_g.update_parameters_name('optim_g')
optimizer_d.update_parameters_name('optim_d')

训练分为两个主要部分,也就是要训练两个网络:生成与对抗网络。

第一部分是训练判别器。训练判别器的目的是最大程度地提高判别图像真伪的概率。按照原论文的方法,通过提高其随机梯度来更新判别器,最大化 l o g D ( x ) + l o g ( 1 − D ( G ( z ) ) log D(x) + log(1 - D(G(z)) logD(x)+log(1−D(G(z)) 的值。

第二部分是训练生成器。如论文所述,最小化 l o g ( 1 − D ( G ( z ) ) ) log(1 - D(G(z))) log(1−D(G(z))) 来训练生成器,以产生更好的虚假图像。

在这两个部分中,分别获取训练过程中的损失,并在每轮迭代结束时进行测试,将固定隐码批量推送到生成器中,以直观地跟踪生成器 Generator 的训练效果。

python 复制代码
import os
import time
import matplotlib.pyplot as plt
import mindspore as ms
from mindspore import Tensor, save_checkpoint

total_epoch = 24  # 训练周期数
batch_size = 64  # 用于训练的训练集批量大小

# 加载预训练模型的参数
pred_trained = False
pred_trained_g = './result/checkpoints/Generator99.ckpt'
pred_trained_d = './result/checkpoints/Discriminator99.ckpt'

checkpoints_path = "./result/checkpoints"  # 结果保存路径
image_path = "./result/images"  # 测试结果保存路径

# 生成器计算损失过程
def generator_forward(test_noises):
    fake_data = net_g(test_noises)
    fake_out = net_d(fake_data)
    loss_g = adversarial_loss(fake_out, ops.ones_like(fake_out))
    return loss_g

# 判别器计算损失过程
def discriminator_forward(real_data, test_noises):
    fake_data = net_g(test_noises)
    fake_out = net_d(fake_data)
    real_out = net_d(real_data)
    real_loss = adversarial_loss(real_out, ops.ones_like(real_out))
    fake_loss = adversarial_loss(fake_out, ops.zeros_like(fake_out))
    loss_d = real_loss + fake_loss
    return loss_d

# 梯度方法
grad_g = ms.value_and_grad(generator_forward, None, net_g.trainable_params())
grad_d = ms.value_and_grad(discriminator_forward, None, net_d.trainable_params())

def train_step(real_data, latent_code):
    # 计算判别器损失和梯度
    loss_d, grads_d = grad_d(real_data, latent_code)
    optimizer_d(grads_d)
    loss_g, grads_g = grad_g(latent_code)
    optimizer_g(grads_g)

    return loss_d, loss_g

# 保存生成的test图像
def save_imgs(gen_imgs1, idx):
    for i3 in range(gen_imgs1.shape[0]):
        plt.subplot(5, 5, i3 + 1)
        plt.imshow(gen_imgs1[i3, 0, :, :] / 2 + 0.5, cmap="gray")
        plt.axis("off")
    plt.savefig(image_path + "/test_{}.png".format(idx))

# 设置参数保存路径
os.makedirs(checkpoints_path, exist_ok=True)
# 设置中间过程生成图片保存路径
os.makedirs(image_path, exist_ok=True)

net_g.set_train()
net_d.set_train()

# 储存生成器和判别器loss
losses_g, losses_d = [], []

for epoch in range(total_epoch):
    start = time.time()
    for (iter, data) in enumerate(mnist_ds):
        start1 = time.time()
        image, latent_code = data
        image = (image - 127.5) / 127.5  # [0, 255] -> [-1, 1]
        image = image.reshape(image.shape[0], 1, image.shape[1], image.shape[2])
        d_loss, g_loss = train_step(image, latent_code)
        end1 = time.time()
        if iter % 10 == 0:
            print(f"Epoch:[{int(epoch):>3d}/{int(total_epoch):>3d}], "
                  f"step:[{int(iter):>4d}/{int(iter_size):>4d}], "
                  f"loss_d:{d_loss.asnumpy():>4f} , "
                  f"loss_g:{g_loss.asnumpy():>4f} , "
                  f"time:{(end1 - start1):>3f}s, "
                  f"lr:{lr:>6f}")

    end = time.time()
    print("time of epoch {} is {:.2f}s".format(epoch + 1, end - start))

    losses_d.append(d_loss.asnumpy())
    losses_g.append(g_loss.asnumpy())

    # 每个epoch结束后,使用生成器生成一组图片
    gen_imgs = net_g(test_noise)
    save_imgs(gen_imgs.asnumpy(), epoch)

    # 根据epoch保存模型权重文件
    if epoch % 1 == 0:
        save_checkpoint(net_g, checkpoints_path + "/Generator%d.ckpt" % (epoch))
        save_checkpoint(net_d, checkpoints_path + "/Discriminator%d.ckpt" % (epoch))

import time
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),'Wayn_Fan-sail')

使用cpu进行12个epoch的生成效果如下:

Reference

昇思官方文档-GAN图像生成
昇思大模型平台

相关推荐
volcanical10 分钟前
Judging LLM-as-a-Judge with MT-Bench and Chatbot Arena
人工智能·自然语言处理·机器翻译
大知闲闲哟12 分钟前
深度学习J6周 ResNeXt-50实战解析
人工智能·深度学习
静静AI学堂1 小时前
Yolo11改策略:卷积改进|SAC,提升模型对小目标和遮挡目标的检测性能|即插即用
人工智能·深度学习·目标跟踪
martian6651 小时前
【人工智能离散数学基础】——深入详解数理逻辑:理解基础逻辑概念,支持推理和决策系统
人工智能·数理逻辑·推理·决策系统
Schwertlilien1 小时前
图像处理-Ch7-图像金字塔和其他变换
图像处理·人工智能
凡人的AI工具箱1 小时前
每天40分玩转Django:Django类视图
数据库·人工智能·后端·python·django·sqlite
千天夜1 小时前
深度学习中的残差网络、加权残差连接(WRC)与跨阶段部分连接(CSP)详解
网络·人工智能·深度学习·神经网络·yolo·机器学习
凡人的AI工具箱2 小时前
每天40分玩转Django:实操图片分享社区
数据库·人工智能·后端·python·django
小军军军军军军2 小时前
MLU运行Stable Diffusion WebUI Forge【flux】
人工智能·python·语言模型·stable diffusion
诚威_lol_中大努力中2 小时前
关于VQ-GAN利用滑动窗口生成 高清图像
人工智能·神经网络·生成对抗网络