GAN-Tutorial procedural record

DCGAN

This code defines a Deep Convolutional Generative Adversarial Network (DCGAN) for generating images of cars.

cpp 复制代码
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F

from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image

import numpy as np
import datetime
import os, sys

import glob

from PIL import Image

from matplotlib.pyplot import imshow, imsave
%matplotlib inline

MODEL_NAME = 'DCGAN'
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
IMAGE_DIM = (32, 32, 3)
def get_sample_image(G, n_noise):
    """
        save sample 100 images
    """
    z = torch.randn(10, n_noise).to(DEVICE)
    y_hat = G(z).view(10, 3, 28, 28).permute(0, 2, 3, 1) # (100, 28, 28)
    result = (y_hat.detach().cpu().numpy()+1)/2.
class Discriminator(nn.Module):
    """
        Convolutional Discriminator for MNIST
    """
    def __init__(self, in_channel=1, num_classes=1):
        super(Discriminator, self).__init__()
        self.conv = nn.Sequential(
            # 28 -> 14
            nn.Conv2d(in_channel, 512, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            # 14 -> 7
            nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            # 7 -> 4
            nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            # 
            nn.Conv2d(128, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
        )        
        self.fc = nn.Sequential(
            # reshape input, 128 -> 1
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )
    
    def forward(self, x, y=None):
        y_ = self.conv(x)
        y_ = y_.view(y_.size(0), -1)
        y_ = self.fc(y_)
        return y_
    return result
class Generator(nn.Module):
    """
        Convolutional Generator for MNIST
    """
    def __init__(self, out_channel=1, input_size=100, num_classes=784):
        super(Generator, self).__init__()
        assert IMAGE_DIM[0] % 2**4 == 0, 'Should be divided 16'
        self.init_dim = (IMAGE_DIM[0] // 2**4, IMAGE_DIM[1] // 2**4)
        self.fc = nn.Sequential(
            nn.Linear(input_size, self.init_dim[0]*self.init_dim[1]*512),
            nn.ReLU(),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(512, 512, 3, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            # x2
            nn.ConvTranspose2d(512, 256, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # x2
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # x2
            nn.ConvTranspose2d(128, 128, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # x2
            nn.ConvTranspose2d(128, out_channel, 4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )
        
    def forward(self, x, y=None):
        x = x.view(x.size(0), -1)
        y_ = self.fc(x)
        y_ = y_.view(y_.size(0), 512, self.init_dim[0], self.init_dim[1])
        y_ = self.conv(y_)
        return y_
class CARS(Dataset):
    '''
    CARS Dataset
    You should download this dataset from below url.
    url: https://ai.stanford.edu/~jkrause/cars/car_dataset.html
    '''
    def __init__(self, data_path, transform=None):
        '''
        Args:
            data_path (str): path to dataset
        '''
        self.data_path = data_path
        self.transform = transform
        self.fpaths = sorted(glob.glob(os.path.join(data_path, '*.jpg')))
        gray_lst = [266, 1085, 2176, 3048, 3439, 3469, 3539, 4577, 4848, 5177, 5502, 5713, 6947, 7383, 7693, 7774, 8137, 8144]
        for num in gray_lst:
            self.fpaths.remove(os.path.join(data_path, '{:05d}.jpg'.format(num)))
        
    def __getitem__(self, idx):
        img = self.transform(Image.open(self.fpaths[idx]))
        return img

    def __len__(self):
        return len(self.fpaths)
D = Discriminator(in_channel=IMAGE_DIM[-1]).to(DEVICE)
G = Generator(out_channel=IMAGE_DIM[-1]).to(DEVICE)
# D.load_state_dict('D_dc.pkl')
# G.load_state_dict('G_dc.pkl')
transform = transforms.Compose([transforms.Resize((IMAGE_DIM[0],IMAGE_DIM[1])),
                                transforms.ToTensor(),
                                transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                std=(0.5, 0.5, 0.5))
                               ]
)
dataset = CARS(data_path='/home/yangyangii/git/cars_train', transform=transform)
逐行解释代码
cpp 复制代码
这段代码定义了一个DCGAN的判别器模型和一个获取样本图像的函数。

`MODEL_NAME`是DCGAN模型的名称,`DEVICE`是判断是否使用cuda的设备。`IMAGE_DIM`是图像的维度。

`get_sample_image`函数用于保存100个样本图像。首先使用正态分布生成一个大小为`n_noise`的随机向量`z`,并将其发送到设备上。然后将随机向量通过生成器`G`得到生成的图像`y_hat`,并将其reshape为(10, 3, 28, 28)的形状,并按照维度顺序重新排列为(10, 28, 28, 3)。最后返回将生成的图像转换为numpy数组并进行归一化处理的结果。

`Discriminator`类是一个用于MNIST数据集的卷积判别器模型。该模型包含几个卷积层和全连接层。卷积层部分使用了`nn.Conv2d`进行卷积操作,`nn.BatchNorm2d`进行批归一化操作,`nn.LeakyReLU`进行LeakyReLU激活操作,`nn.AdaptiveAvgPool2d`进行自适应平均池化操作。全连接层部分使用了`nn.Linear`进行线性变换操作,`nn.Sigmoid`进行Sigmoid激活操作。

`forward`函数是判别器模型的前向传播方法。输入`x`经过卷积层和reshape操作后,经过全连接层得到输出`y_`。若有标签`y`,则在最后通过softmax函数将输出映射到[0,1]之间;若无标签,则直接返回输出结果。

最后返回结果。
相关推荐
Lululaurel几秒前
AI编程提示词工程实战指南:从入门到精通
人工智能·python·机器学习·ai·ai编程
JOYCE_Leo162 分钟前
Learning Diffusion Texture Priors for Image Restoration(DTPM)-CVPR2024
深度学习·扩散模型·图像复原
财经三剑客12 分钟前
东风集团股份:11月生产量达21.6万辆 销量19.6万辆
大数据·人工智能·汽车
老蒋新思维15 分钟前
创客匠人峰会新解:高势能 IP 打造 ——AI 时代知识变现的十倍增长密码
大数据·网络·人工智能·tcp/ip·创始人ip·创客匠人·知识变现
Dev7z17 分钟前
基于神经网络的风电机组齿轮箱故障诊断研究与设计
人工智能·深度学习·神经网络
老蒋新思维17 分钟前
创客匠人峰会洞察:AI 时代教育知识变现的重构 —— 从 “刷题记忆” 到 “成长赋能” 的革命
大数据·人工智能·网络协议·tcp/ip·重构·创始人ip·创客匠人
飞鹰@四海17 分钟前
AutoGLM 旧安卓一键变 AI 手机:安装与使用指南
android·人工智能·智能手机
paopao_wu19 分钟前
智普GLM-TTS开源:可控且富含情感的零样本语音合成模型
人工智能·ai·开源·大模型·tts
少林and叔叔21 分钟前
基于yolov11s模型训练与推理测试(VScode开发环境)
ide·人工智能·vscode·yolo·目标检测
serve the people22 分钟前
tensorflow 零基础吃透:RaggedTensor 的评估(访问值的 4 种核心方式)
人工智能·tensorflow