[PyTorch][chapter 57][WGAN-GP 代码实现]

前言:

下图为WGAN 的效果图:

绿色为真实数据的分布: 8个高斯分布

红色: 为随机产生的数据分布,跟真实分布基本一致

WGAN-GP:

1 判别器D: 最后一层去掉sigmoid

2 生成器G 和判别器D: loss不取log

3 损失函数 增加了penalty,使用Adam

Wasserstein GAN

1 判别器D: 最后一层去掉sigmoid

2 生成器G 和判别器D: loss不取log

3 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
4 不要用基于动量的优化算法(包括momentum和Adam) ,推荐RMSProp,SGD也行


一 简介

1.1 模型结构

1.2 伪代码


wgan.py

主要变化:

Generator 中 去掉了之前的logit 函数

复制代码
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:10:19 2023

@author: chengxf2
"""

import torch
from   torch import nn



#生成器模型
h_dim = 400
class Generator(nn.Module):
    
    def __init__(self):
        
        super(Generator,self).__init__()
        # z: [batch,input_features]
       
        self.net = nn.Sequential(
            nn.Linear(2, h_dim),
            nn.ReLU(True),
            nn.Linear( h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, h_dim),
            nn.ReLU(True),
            nn.Linear(h_dim, 2)
            )
        
    def forward(self, z):
        
        output = self.net(z)
        return output
    
#鉴别器模型
class Discriminator(nn.Module):
    
    def __init__(self):
        
        super(Discriminator,self).__init__()
        
        hDim=400
        # x: [batch,input_features]
        self.net = nn.Sequential(
            nn.Linear(2, hDim),
            nn.ReLU(True),
            nn.Linear(hDim, hDim),
            nn.ReLU(True),
            nn.Linear(hDim, hDim),
            nn.ReLU(True),
            nn.Linear(hDim, 1),
            )
        
    def forward(self, x):
        
        #x:[batch,1]
        output = self.net(x)
        
        out = output.view(-1)
        return out

main.py

主要变化:

损失函数中增加了gradient_penalty

复制代码
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:28:32 2023

@author: chengxf2
"""


import visdom
from gan  import  Discriminator
from gan  import Generator
import numpy as np
import random
import torch
from   torch import nn, optim
from    matplotlib import pyplot as plt
from torch import autograd


h_dim =400
batchsz = 256
viz = visdom.Visdom()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")



def weights_init(net):
   if isinstance(net, nn.Linear):
         # net.weight.data.normal_(0.0, 0.02)
         nn.init.kaiming_normal_(net.weight)
         net.bias.data.fill_(0)

def data_generator():
    """
    8- gaussian destribution

    Returns
    -------
    None.

    """
    scale = 2
    a = np.sqrt(2.0)
    centers =[
         (1,0),
         (-1,0),
         (0,1),
         (0,-1),
         (1/a,1/a),
         (1/a,-1/a),
         (-1/a, 1/a),
         (-1/a,-1/a)
        ]
    
    centers = [(scale*x, scale*y) for x,y in centers]
    
    while True:
        
         dataset =[]
         
         for i in range(batchsz):
             
             point = np.random.randn(2)*0.02
             center = random.choice(centers)
             point[0] += center[0]
             point[1] += center[1]
             dataset.append(point)
         dataset = np.array(dataset).astype(np.float32)
         dataset /=a
         #生成器函数是一个特殊的函数,可以返回一个迭代器
         yield dataset


def generate_image(D, G, xr, epoch):      #xr表示真实的sample
    """
    Generates and saves a plot of the true distribution, the generator, and the
    critic.
    """
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))             # (16384, 2)
    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    N = len(x)
    # draw contour
    with torch.no_grad():
        points = torch.Tensor(points)      # [16384, 2]
        disc_map = D(points).cpu().numpy() # [16384]
   
    plt.contour(x, y, disc_map.reshape((N, N)).transpose())
    #plt.clabel(cs, inline=1, fontsize=10)
    plt.colorbar()


    # draw samples
    with torch.no_grad():
        z = torch.randn(batchsz, 2)                 # [b, 2]
        samples = G(z).cpu().numpy()                # [b, 2]
    plt.scatter(xr[:, 0], xr[:, 1], c='green', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='red', marker='+')

    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d'%epoch))
    

def gradient_penalty(D, xr,xf):

    #[b,1]
    t =  torch.rand(batchsz, 1).to(device)       
    #[b,1]=>[b,2]  保证每个sample t 相同
    t =  t.expand_as(xr)
    
    #sample penalty interpoation [b,2]
    mid = t*xr +(1-t)*xf
    mid.requires_grad_()
    
    pred = D(mid) #[256]
   
    '''
    grad_outputs:   如果outputs 是向量,则此参数必须写
    retain_graph:  True 则保留计算图, False则释放计算图
    create_graph: 若要计算高阶导数,则必须选为True
    allow_unused: 允许输入变量不进入计算
    '''
    grads = autograd.grad(outputs= pred, inputs = mid,
                      grad_outputs= torch.ones_like(pred),
                      create_graph=True,
                      retain_graph=True,
                      only_inputs=True)[0]
    
    gp = torch.pow(grads.norm(2, dim=1)-1,2).mean()
    
    return gp
    
    
    
    
    
    
         
def main():
  
    lambd = 0.2 #超参数
    maxIter = 1000
    torch.manual_seed(10)
    np.random.seed(10)
    data_iter  = data_generator()
    
   
    G = Generator().to(device)
    D = Discriminator().to(device)
    G.apply(weights_init)
    D.apply(weights_init)
    optim_G = optim.Adam(G.parameters(),lr =5e-4, betas=(0.5,0.9))
    optim_D = optim.Adam(D.parameters(),lr =5e-4, betas=(0.5,0.9))
    K = 5
 
    

    
   
    viz.line([[0,0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))

    for epoch in range(maxIter):
        
        #1: train Discrimator fistly
        for k in range(K):
            
            #1.1: train on real data
            xr = next(data_iter)
            xr = torch.from_numpy(xr).to(device)
            predr = D(xr)
            
       
            #max(predr) == min(-predr)
            lossr = -predr.mean()
            
            
            #1.2: train on fake data
            z = torch.randn(batchsz,2).to(device) #[b,2] 随机产生的噪声
            xf = G(z).detach() #固定G,不更新G参数 tf.stop_gradient()
            predf =D(xf)
            lossf = predf.mean()
            
            #1.3 gradient_penalty
            gp = gradient_penalty(D, xr,xf.detach())
            
            #aggregate all
            loss_D = lossr + lossf +lambd*gp
            
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()
            #print("\n Discriminator 训练结束 ",loss_D.item())
        
        # 2 train  Generator
        
        #2.1 train on fake data
        z = torch.randn(batchsz, 2).to(device)
        xf = G(z)
        predf =D(xf) #期望最大
        loss_G= -predf.mean()
        
        #optimize
        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()
        
        if epoch %100 ==0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')
            generate_image(D, G, xr, epoch)
            print("\n epoch: %d"%epoch,"\t lossD: %7.4f"%loss_D.item(),"\t lossG: %7.4f"%loss_G.item())
         
        
 

    
    
    

if __name__ == "__main__":
    
    main()

参考:

课时130 WGAN-GP实战_哔哩哔哩_bilibili

WGAN基本原理及Pytorch实现WGAN-CSDN博客

CSDN

相关推荐
杨小扩2 小时前
第4章:实战项目一 打造你的第一个AI知识库问答机器人 (RAG)
人工智能·机器人
whaosoft-1432 小时前
51c~目标检测~合集4
人工智能
雪兽软件2 小时前
2025 年网络安全与人工智能发展趋势
人工智能·安全·web安全
元宇宙时间3 小时前
全球发展币GDEV:从中国出发,走向全球的数字发展合作蓝图
大数据·人工智能·去中心化·区块链
小黄人20253 小时前
自动驾驶安全技术的演进与NVIDIA的创新实践
人工智能·安全·自动驾驶
ZStack开发者社区4 小时前
首批 | 云轴科技ZStack加入施耐德电气技术本地化创新生态
人工智能·科技·云计算
X Y O5 小时前
神经网络初步学习3——数据与损失
人工智能·神经网络·学习
FL16238631295 小时前
如何使用目标检测深度学习框架yolov8训练钢管管道表面缺陷VOC+YOLO格式1159张3类别的检测数据集步骤和流程
深度学习·yolo·目标检测
唯创知音5 小时前
玩具语音方案选型决策OTP vs Flash 的成本功耗与灵活性
人工智能·语音识别
Jamence5 小时前
多模态大语言模型arxiv论文略读(151)
论文阅读·人工智能·语言模型·自然语言处理·论文笔记