GAN(生成对抗网络)

简介:GAN生成对抗网络本质上是一种思想,其依靠神经网络能够拟合任意函数的能力,设计了一种架构来实现数据的生成。
原理:GAN的原理就是最小化生成器Generator的损失,但是在最小化损失的过程中加入了一个约束,这个约束就是使Generator生成的数据满足我们指定数据的分布,GAN的巧妙之处在于使用一个神经网络(鉴别器Discriminator)来自动判断生成的数据是否符合我们所需要的分布。
实现细节:

一:

准备好我们想要让生成器生成的数据类型,比如MINIST手写数字集,包含1-10十个数字,一共60000张图片。生成器的目的就是学习这个数据集的分布。

二,

定义一个生成器,用于判别一张图片是实际的还是生成器生成的,当生成器完美学习得到数据分布之后,鉴别器可能就分不清图片是生成器的还是实际的,这样的话生成器就能生成我们想要的图片了。

生成器的训练过程为:实际数据输出结果1,生成数据输出结果为0,目的是学会区分真假数据,相当于提供一个约束,使生成数据符合指定分布。当鉴别生成器的数据分布时,只需要更新鉴别器的参数权重,不能够通过计算图将生成器的参数进行更新。

三,

定义一个生成器,给定一个输入,他就能生成1-10里面的一个数字的图片。生成器的反向更新是根据鉴别器的损失来确定(被约束进行反向更新)。生成器的网络权重参数是单独的,反向更新时,只需要更新计算图当中属于生成器部分的参数。
下面给出生成1-0-1-0数据格式的代码:

python 复制代码
# %%
import torch
import numpy
import torch.nn as nn
import matplotlib.pyplot as plt

# %%
def gennerate1010():
    return torch.FloatTensor([numpy.random.uniform(0.9,1.1),
                              numpy.random.uniform(0.,.1),
                              numpy.random.uniform(0.9,1.1),
                              numpy.random.uniform(0.0,.1)])

# %%
def genneratexxxx():
    return torch.rand(4)

# %%
class Discrimer(nn.Module):
    def __init__(self) -> None:
        father_obj = super(Discrimer,self)
        father_obj.__init__()
        self.create_model()
        self.counter = 0
        self.progress = []
        
    def create_model(self):
        self.model = nn.Sequential(
            nn.Linear(4,3),
            nn.Sigmoid(),
            nn.Linear(3,1),
            nn.Sigmoid(),           
        )
        self.loss_functon = nn.MSELoss()
        self.optimiser = torch.optim.SGD(self.parameters(),lr=0.01)
        
    def forward(self,x):
        return self.model(x)
    
    def train(self,x,targets):
        outputs = self.forward(x)
        loss = self.loss_functon(outputs,targets)
        self.counter += 1
        if self.counter%10 == 0:
            self.progress.append(loss.item())
        if self.counter%10000 == 0:
            print(self.counter)
            
            
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        
    def plotprogress(self):
        plt.plot(self.progress,marker='*')
        plt.show()

        

# %%
class Gennerater(nn.Module):
    def __init__(self) -> None:
        father_obj = super(Gennerater,self)
        father_obj.__init__()
        self.create_model()
        self.counter = 0
        self.progress = []

        
    def create_model(self):
        self.model = nn.Sequential(
            nn.Linear(1,3),
            nn.Sigmoid(),
            nn.Linear(3,4),
            nn.Sigmoid(),           
        )
        # 这个优化器只能优化生成器部分的参数
        self.optimiser = torch.optim.SGD(self.parameters(),lr=0.01)
        
    def forward(self,x):
        return self.model(x)
    
    def train(self,D,x,targets):
        g_outputs = self.forward(x)
        d_outputs = D.forward(g_outputs)
        # 使用鉴别器的loss函数,但是只更新生成器的参数,生成器的参数需要根据鉴别器的约束进行更新
        loss = D.loss_functon(d_outputs,targets)
        self.counter += 1
        if self.counter%10 == 0:
            self.progress.append(loss.item())
        if self.counter%10000 == 0:
            print(self.counter)
            
            
        self.optimiser.zero_grad()
        loss.backward()
        self.optimiser.step()
        
    def plotprogress(self):
        plt.plot(self.progress,marker='*')
        plt.show()


# %%
D = Discrimer()

# %%
G = Gennerater()


# %%
for id in range(15000):
    # 喂入实际数据给鉴别器
    D.train(gennerate1010(),torch.FloatTensor([1.]))
    # 喂入生成的数据,使用detach从计算图脱离,用于更新鉴别器,而生成器得不到更新
    D.train(G.forward(torch.FloatTensor([0.5]).detach()),torch.FloatTensor([0.0]))
    G.train(D,torch.FloatTensor([0.5]),torch.FloatTensor([1.]))
    

# %%
D.plotprogress()

# %%
G.plotprogress()

# %%
G.forward(torch.FloatTensor([0.5]))

参考:PyTorch生成对抗网络编程

相关推荐
攻城狮_Dream2 分钟前
“探索未来医疗:生成式人工智能在医疗领域的革命性应用“
人工智能·设计·医疗·毕业
Chef_Chen8 分钟前
从0开始机器学习--Day17--神经网络反向传播作业
python·神经网络·机器学习
学习前端的小z31 分钟前
【AIGC】如何通过ChatGPT轻松制作个性化GPTs应用
人工智能·chatgpt·aigc
埃菲尔铁塔_CV算法1 小时前
人工智能图像算法:开启视觉新时代的钥匙
人工智能·算法
EasyCVR1 小时前
EHOME视频平台EasyCVR视频融合平台使用OBS进行RTMP推流,WebRTC播放出现抖动、卡顿如何解决?
人工智能·算法·ffmpeg·音视频·webrtc·监控视频接入
打羽毛球吗️1 小时前
机器学习中的两种主要思路:数据驱动与模型驱动
人工智能·机器学习
好喜欢吃红柚子1 小时前
万字长文解读空间、通道注意力机制机制和超详细代码逐行分析(SE,CBAM,SGE,CA,ECA,TA)
人工智能·pytorch·python·计算机视觉·cnn
小馒头学python1 小时前
机器学习是什么?AIGC又是什么?机器学习与AIGC未来科技的双引擎
人工智能·python·机器学习
神奇夜光杯2 小时前
Python酷库之旅-第三方库Pandas(202)
开发语言·人工智能·python·excel·pandas·标准库及第三方库·学习与成长
正义的彬彬侠2 小时前
《XGBoost算法的原理推导》12-14决策树复杂度的正则化项 公式解析
人工智能·决策树·机器学习·集成学习·boosting·xgboost