[PyTorch][chapter 57][WGAN-GP 代码实现]

前言:

下图为WGAN 的效果图:

绿色为真实数据的分布: 8个高斯分布

红色:生成器由随机噪声生成的数据分布,与真实分布基本一致

WGAN-GP:

1 判别器D: 最后一层去掉sigmoid

2 生成器G 和判别器D: loss不取log

3 损失函数中增加了梯度惩罚项(gradient penalty),取代了 WGAN 的权重截断,因此可以使用 Adam 优化器

Wasserstein GAN

1 判别器D: 最后一层去掉sigmoid

2 生成器G 和判别器D: loss不取log

3 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c
4 不要用基于动量的优化算法(包括 momentum 和 Adam),推荐 RMSProp,SGD 也可以


一 简介

1.1 模型结构

1.2 伪代码


gan.py(main.py 中通过 `from gan import ...` 导入,因此文件名需为 gan.py)

主要变化:

Discriminator 的最后一层去掉了 sigmoid 函数,直接输出实数评分(不再是概率);Generator 的输出层同样没有激活函数

复制代码
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:10:19 2023

@author: chengxf2
"""

import torch
from   torch import nn



# Generator model
h_dim = 400


class Generator(nn.Module):
    """Map 2-D noise vectors to 2-D sample points.

    A plain MLP, 2 -> h_dim -> h_dim -> h_dim -> 2, with ReLU between the
    linear layers and no activation on the output layer.
    """

    def __init__(self):
        super().__init__()
        # z: [batch, 2] -> output: [batch, 2]
        self.net = nn.Sequential(
            nn.Linear(2, h_dim), nn.ReLU(True),
            nn.Linear(h_dim, h_dim), nn.ReLU(True),
            nn.Linear(h_dim, h_dim), nn.ReLU(True),
            nn.Linear(h_dim, 2),
        )

    def forward(self, z):
        """Return generated points for the noise batch ``z``."""
        return self.net(z)
    
# Discriminator (critic) model
class Discriminator(nn.Module):
    """WGAN critic: score each 2-D point with an unbounded real number.

    MLP 2 -> 400 -> 400 -> 400 -> 1 with ReLU between the linear layers and
    no sigmoid on the output, so the score is not a probability.
    """

    def __init__(self):
        super().__init__()
        width = 400
        # x: [batch, 2] -> net output: [batch, 1]
        self.net = nn.Sequential(
            nn.Linear(2, width), nn.ReLU(True),
            nn.Linear(width, width), nn.ReLU(True),
            nn.Linear(width, width), nn.ReLU(True),
            nn.Linear(width, 1),
        )

    def forward(self, x):
        """Return one score per input row, flattened to shape [batch]."""
        scores = self.net(x)
        return scores.view(-1)

main.py

主要变化:

判别器的损失函数中增加了 gradient_penalty(梯度惩罚项),其权重为超参数 lambd

复制代码
# -*- coding: utf-8 -*-
"""
Created on Thu Sep 28 11:28:32 2023

@author: chengxf2
"""


import visdom
from gan  import  Discriminator
from gan  import Generator
import numpy as np
import random
import torch
from   torch import nn, optim
from    matplotlib import pyplot as plt
from torch import autograd


# Hidden width; not referenced elsewhere in this script (gan.py sets its own).
h_dim = 400
# Samples per batch for training data, noise and plotting.
batchsz = 256
# Visdom client for the loss curve and contour windows; created at import time.
viz = visdom.Visdom()
# Prefer the first CUDA device, fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")



def weights_init(net):
    """Initialise a linear layer in place; meant to be passed to ``Module.apply``.

    Weights get Kaiming-normal initialisation and biases are set to zero.
    Modules that are not ``nn.Linear`` are left untouched.

    Args:
        net: a (sub)module visited by ``nn.Module.apply``.
    """
    if isinstance(net, nn.Linear):
        nn.init.kaiming_normal_(net.weight)
        # nn.Linear(..., bias=False) stores bias as None; the original
        # unconditional fill_ crashed on such layers.
        if net.bias is not None:
            nn.init.zeros_(net.bias)

def data_generator(batch_size=None):
    """Yield endless batches sampled from a mixture of 8 Gaussians on a ring.

    The 8 centres are the unit vectors along the axes and diagonals, scaled
    by 2. Each point is a uniformly chosen centre plus Gaussian noise with
    std 0.02. Every batch is then divided by sqrt(2), so the centres lie on
    a circle of radius sqrt(2).

    Args:
        batch_size: points per batch; when None, the module-level
            ``batchsz`` is used.

    Yields:
        np.ndarray of dtype float32 and shape [batch_size, 2]. This function
        is an infinite generator and never returns.
    """
    if batch_size is None:
        batch_size = batchsz
    scale = 2
    a = np.sqrt(2.0)
    centers = [
        (1, 0),
        (-1, 0),
        (0, 1),
        (0, -1),
        (1 / a, 1 / a),
        (1 / a, -1 / a),
        (-1 / a, 1 / a),
        (-1 / a, -1 / a),
    ]
    centers = [(scale * x, scale * y) for x, y in centers]

    while True:
        dataset = []
        for _ in range(batch_size):
            # Keep the RNG call order (numpy noise, then random.choice) so
            # seeded runs produce the same data as before.
            point = np.random.randn(2) * 0.02
            center = random.choice(centers)
            point[0] += center[0]
            point[1] += center[1]
            dataset.append(point)
        # reshape keeps the [n, 2] shape even when batch_size == 0
        dataset = np.array(dataset, dtype=np.float32).reshape(-1, 2)
        dataset /= a
        yield dataset


def generate_image(D, G, xr, epoch):      # xr: a batch of real samples
    """Draw the critic's contour map plus real and generated samples to visdom.

    Args:
        D: critic network, evaluated on a 128x128 grid over [-3, 3]^2.
        G: generator network, fed ``batchsz`` standard-normal noise vectors.
        xr: real samples of shape [b, 2], as a NumPy array or a torch tensor
            on any device.
        epoch: iteration number shown in the plot title.
    """
    N_POINTS = 128
    RANGE = 3
    plt.clf()

    # Build network inputs on the networks' own devices; creating them on the
    # CPU crashed whenever the models had been moved to CUDA.
    d_device = next(D.parameters()).device
    g_device = next(G.parameters()).device

    points = np.zeros((N_POINTS, N_POINTS, 2), dtype='float32')
    points[:, :, 0] = np.linspace(-RANGE, RANGE, N_POINTS)[:, None]
    points[:, :, 1] = np.linspace(-RANGE, RANGE, N_POINTS)[None, :]
    points = points.reshape((-1, 2))             # (16384, 2)
    x = y = np.linspace(-RANGE, RANGE, N_POINTS)
    N = len(x)
    # draw contour
    with torch.no_grad():
        grid = torch.from_numpy(points).to(d_device)   # [16384, 2]
        disc_map = D(grid).cpu().numpy()               # [16384]

    # disc_map is indexed [x, y]; contour expects [y, x], hence the transpose
    plt.contour(x, y, disc_map.reshape((N, N)).transpose())
    plt.colorbar()

    # draw samples
    with torch.no_grad():
        z = torch.randn(batchsz, 2, device=g_device)   # [b, 2]
        samples = G(z).cpu().numpy()                   # [b, 2]
    if torch.is_tensor(xr):
        # matplotlib cannot read device (e.g. CUDA) tensors
        xr = xr.detach().cpu().numpy()
    plt.scatter(xr[:, 0], xr[:, 1], c='green', marker='.')
    plt.scatter(samples[:, 0], samples[:, 1], c='red', marker='+')

    viz.matplot(plt, win='contour', opts=dict(title='p(x):%d'%epoch))
    

def gradient_penalty(D, xr, xf):
    """WGAN-GP gradient penalty on random interpolates of real and fake data.

    For each sample, a weight t ~ U[0, 1) is drawn and the critic is evaluated
    at ``t * xr + (1 - t) * xf``. The penalty is the batch mean of
    ``(||dD/dx||_2 - 1)^2``.

    Args:
        D: critic network mapping a [b, 2] batch to per-sample scores.
        xr: real batch, shape [b, 2].
        xf: fake batch with the same shape as ``xr`` (already detached from G).

    Returns:
        Scalar tensor that keeps its autograd graph (``create_graph=True``),
        so adding it to the critic loss back-propagates into D's parameters.
    """
    # One weight per sample, [b, 1] -> [b, 2]. Size, device and dtype come from
    # xr, so any batch size works (a global batch size broke expand_as).
    t = torch.rand(xr.size(0), 1, device=xr.device, dtype=xr.dtype)
    t = t.expand_as(xr)

    # Interpolated points; mark them as requiring grad so D can be
    # differentiated with respect to its input.
    mid = t * xr + (1 - t) * xf
    mid.requires_grad_()

    pred = D(mid)

    # pred is a vector, so grad_outputs=ones differentiates sum(pred). That
    # equals each sample's own input gradient as long as D scores samples
    # independently (the MLP critic here has no batch-coupling layers).
    # create_graph=True makes the penalty itself differentiable.
    grads = autograd.grad(outputs=pred, inputs=mid,
                          grad_outputs=torch.ones_like(pred),
                          create_graph=True,
                          retain_graph=True)[0]

    gp = torch.pow(grads.norm(2, dim=1) - 1, 2).mean()
    return gp
    
    
    
    
    
    
         
def main(max_iter=1000, n_critic=5, lambd=0.2, seed=10):
    """Train a WGAN-GP on the 8-Gaussian toy data and log progress to visdom.

    Each iteration runs ``n_critic`` critic updates followed by one generator
    update. Every 100 iterations the losses are appended to the visdom 'loss'
    window, the contour/sample plot is redrawn and the losses are printed.

    Args:
        max_iter: number of generator updates.
        n_critic: critic updates per generator update; must be >= 1.
        lambd: weight of the gradient penalty in the critic loss.
        seed: seed for torch, NumPy and Python's ``random``.

    Raises:
        ValueError: if ``n_critic`` < 1 (loss_D would never be computed).
    """
    if n_critic < 1:
        raise ValueError("n_critic must be >= 1")

    torch.manual_seed(seed)
    np.random.seed(seed)
    # data_generator picks centres with random.choice; without this seed the
    # runs were not reproducible despite seeding torch and numpy.
    random.seed(seed)
    data_iter = data_generator()

    G = Generator().to(device)
    D = Discriminator().to(device)
    G.apply(weights_init)
    D.apply(weights_init)
    optim_G = optim.Adam(G.parameters(), lr=5e-4, betas=(0.5, 0.9))
    optim_D = optim.Adam(D.parameters(), lr=5e-4, betas=(0.5, 0.9))

    viz.line([[0, 0]], [0], win='loss', opts=dict(title='loss', legend=['D', 'G']))

    for epoch in range(max_iter):

        # 1: train the critic first, with G frozen
        for _ in range(n_critic):
            # 1.1 real data: max D(xr) == min -D(xr)
            xr = torch.from_numpy(next(data_iter)).to(device)
            lossr = -D(xr).mean()

            # 1.2 fake data: detach so no gradient reaches G
            z = torch.randn(batchsz, 2).to(device)   # [b, 2] noise
            xf = G(z).detach()
            lossf = D(xf).mean()

            # 1.3 gradient penalty on real/fake interpolates
            gp = gradient_penalty(D, xr, xf)

            loss_D = lossr + lossf + lambd * gp
            optim_D.zero_grad()
            loss_D.backward()
            optim_D.step()

        # 2: train the generator: max D(G(z)) == min -D(G(z))
        z = torch.randn(batchsz, 2).to(device)
        loss_G = -D(G(z)).mean()

        optim_G.zero_grad()
        loss_G.backward()
        optim_G.step()

        if epoch % 100 == 0:
            viz.line([[loss_D.item(), loss_G.item()]], [epoch], win='loss', update='append')
            # hand the plotter a CPU NumPy copy; matplotlib cannot read CUDA tensors
            generate_image(D, G, xr.cpu().numpy(), epoch)
            print("\n epoch: %d"%epoch,"\t lossD: %7.4f"%loss_D.item(),"\t lossG: %7.4f"%loss_G.item())


if __name__ == "__main__":
    main()

参考:

课时130 WGAN-GP实战_哔哩哔哩_bilibili

WGAN基本原理及Pytorch实现WGAN-CSDN博客

CSDN

相关推荐
CoovallyAIHub11 分钟前
CVPR 2026 | MixerCSeg:仅2.05 GFLOPs刷新四大裂缝分割基准!解耦Mamba隐式注意力,CNN+Transformer+Mamba三
深度学习·算法·计算机视觉
ssshooter29 分钟前
告别 Chat Completions:深度解析 AI 接口新标准 `/v1/responses`
人工智能·后端·开源
CoovallyAIHub41 分钟前
YOLO26-Pose 深度解读:端到端架构重新设计,姿态估计凭什么跨代领先?
深度学习·算法·计算机视觉
陈广亮44 分钟前
OpenClaw 自动化工作流实战:Cron + Heartbeat 构建 24/7 智能助手
人工智能
大模型落地手艺人1 小时前
OpenClaw/CoPaw易用性突破!Skill UI Generator 让skill可视化
人工智能
CodeDevMaster1 小时前
从零开始:OpenClaw本地 AI 助手部署指南
人工智能·agent·ai编程
CoovallyAIHub1 小时前
化工厂气体泄漏怎么用AI检测?30张图3D重建气体泄漏场景——美国国家实验室NeRF新研究
深度学习·算法·计算机视觉
阿聪谈架构1 小时前
第03章:LCEL 链式调用 —— 让 AI 任务像流水线一样运转
人工智能
chaors1 小时前
从零学RAG0x04向量检索算法初探
人工智能·程序员·ai编程
chaors1 小时前
Langchain入门到精通0x01:结果解析器
人工智能·langchain·ai编程