GAN原理 & 代码解读

模型架构

代码

数据准备

python 复制代码
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision import datasets
import torch.nn as nn
import torch

# 创建文件夹存放图片
os.makedirs("data", exist_ok=True)
python 复制代码
transform = transforms.Compose([
    transforms.ToTensor(), #它会进行0-1归一化,h方向/h,w方向/w。 然后将图片格式转换为 (channel,h,w)
    transforms.Normalize(0.5,0.5),#把数据归一化为均值为0.5,方差为0.5,图像的数值范围变成-1到1
])
python 复制代码
# 下载训练数据后对图片进行transform里的toTensor和用均值方差归一化
train_dataset = datasets.MNIST('data',
                               train=True,
                               transform=transform,
                               download=True)
dataloader = torch.utils.data.DataLoader(train_dataset,batch_size=64,shuffle=True)

定义生成器

python 复制代码
'''
    输入:正态分布随机数噪声(长度为100)
    输出:生成的图片,(1,28,28)
    中间过程:
        linear1: 100 -> 256
        linear2: 256 -> 512
        linear3: 512 -> 28*28
        reshape: 28x28 -> (1,28,28)
'''
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__() # super().__init__() 是调用父类的__init__函数
        self.model = nn.Sequential(nn.Linear(100,256),nn.ReLU(),
                                   nn.Linear(256,512),nn.ReLU(),
                                    # 最后一层用tanh激活,将数据压缩到-1到1
                                   nn.Linear(512,28*28),nn.Tanh())
    def forward(self,x):
        img = self.model(x)
        img = img.view(-1,28,28,1) # 得到的是28*28=784,把它reshape为 (批量,h,w,channel)
        return img

定义判别器

python 复制代码
'''
    判别器
    输入:(1,28,28)的图片
    输出:二分类的概率值 用sigmoid压缩到0-1之间
    内容:
    判别器 推荐使用LeakyRelu,因为生成器难以训练,Relu的负值直接变成0没有梯度了
'''
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.model = nn.Sequential(
            nn.Linear(28*28,512),nn.LeakyReLU(),
            nn.Linear(512,256),nn.LeakyReLU(),
            nn.Linear(256,1),nn.Sigmoid(),
        )
    def forward(self,x):
        x = x.view(-1,28*28)
        x = self.model(x)
        return x

初始化模型,优化器及损失计算函数

python 复制代码
device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device) # 初始化并放到了相应的设备上
dis = Discriminator().to(device)
dis_optim = torch.optim.Adam(dis.parameters(),lr=0.0001)
gen_optim = torch.optim.Adam(gen.parameters(),lr=0.0001)
bce_loss = torch.nn.BCELoss()

画生成器生成的图的绘图函数

python 复制代码
def gen_img_plot(model,epoch,test_input):
    prediction = model(test_input).detach().cpu().numpy() # 放在内存上 并转换为Numpy
    prediction = np.squeeze(prediction) # np.squeeze是一个numpy函数,删除数组中形状为1的维度
    fig = plt.figure(figsize=(4,4))
    for i in range(16): # 迭代这n张图片
        plt.subplot(4,4,i+1)
        plt.imshow((prediction[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]
        plt.axis('off')
    plt.show()

显示图片的函数

python 复制代码
def img_plot(img):
    img = np.squeeze(img) # np.squeeze是一个numpy函数,删除数组中形状为1的维度
    fig = plt.figure(figsize=(4,4))
    for i in range(16): # 迭代这n张图片
        plt.subplot(4,4,i+1)
        plt.imshow((img[i] + 1) / 2) # 生成器生成的图片是-1到1之间的,无法绘图。通过 (原+1)/2把[-1,1]压缩到[0,1]
        plt.axis('off')
    plt.show()

定义训练函数

python 复制代码
def train(num_epoch,test_input):
    D_loss = []
    G_loss = []
    # 训练循环
    for epoch in range(num_epoch):
        d_epoch_loss = 0
        g_epoch_loss = 0
        count = len(dataloader) # 返回批次数
        for step,(img,_) in enumerate(dataloader): # _是标签数据,img是(批次,h,w),每次取的img形状为(64,1,28,28)
            # print(f'step={step},img.shape={img.shape}')
            # img_plot(img)
            img = img.to(device)
            size = img.size(0) # 得到一个批次的图片
            random_noise = torch.randn(size,100,device=device) # 生成器的输入

            '''一. 训练判别器'''
            '''用真实图片训练判别器'''
            dis_optim.zero_grad()
            real_output = dis(img) # 对判别取输入真实的图片,输出对真实图片的预测结果
            # 判别器在真实图像上的损失
            d_real_loss = bce_loss(real_output,
                                   # torch.ones_like(real_output) 创建一个根real_loss一样形状的全1数组,作为标签。
                                   torch.ones_like(real_output))
            d_real_loss.backward()

            '''用生成的图片训练判别器'''
            gen_img = gen(random_noise)
            # 因为此时是为了训练判别器,所以不能让生成器的梯度参与进来。所以用detach()取出无梯度的tensor
            fake_output = dis(gen_img.detach())
            d_fake_loss = bce_loss(fake_output,
                                   torch.zeros_like(fake_output))
            d_fake_loss.backward()
            d_loss = d_real_loss+d_fake_loss
            dis_optim.step() # 对参数进行优化

            '''二.训练生成器'''
            gen_optim.zero_grad()
            # 刚才是去掉生成器生成的图片的梯度,来训练判别器。此处不需要去掉梯度。让判别器进行判别
            fake_output = dis(gen_img)
            # 思想:目的是生成越来越逼真的图片瞒过判别器,让判别器判定生成的图片是真实的图片。
            # 实现方法:把判别器的结果输入到bce_loss,用1作为标签,看判别器把生成的图片判别为真的损失。
            g_loss = bce_loss(fake_output,
                              torch.ones_like(fake_output))
            g_loss.backward()
            gen_optim.step()

            # 计算一个epoch的损失
            with torch.no_grad(): #  禁止梯度计算和参数更新
                d_epoch_loss +=d_loss
                g_epoch_loss +=g_loss
        # 计算整体loss每个epoch的平均Loss
        with torch.no_grad(): #  禁止梯度计算和参数更新
            d_epoch_loss /= count
            g_epoch_loss /= count
            D_loss.append(d_epoch_loss)
            G_loss.append(g_epoch_loss)
            print('Epoch:', epoch+1)
            print(f'd_epoch_loss={d_epoch_loss}')
            print(f'g_epoch_loss={g_epoch_loss}')
            # 将16个长度为100的噪音输入到生成器并画图
            gen_img_plot(gen,test_input)

开始训练

python 复制代码
'''开始计时'''
start_time = time.time()

'''开始训练'''
test_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
print(test_input)
num_epoch = 50
train(num_epoch,test_input)
# 保存训练50次的参数
torch.save(gen.state_dict(),'gen_weights.pth')
torch.save(dis.state_dict(),'dis_weights.pth')

'''计时结束'''
end_time = time.time()
run_time = end_time - start_time
# 将输出的秒数保留两位小数
if int(run_time)<60:
    print(f'{round(run_time,2)}s')
else:
    print(f'{round(run_time/60,2)}minutes')

结果可视化

加载训练好的参数

python 复制代码
gen.load_state_dict(torch.load('/opt/software/computer_vision/codes/My_codes/paper_codes/GAN/weights/gen_weights.pth'))

用训练好的生成器生成图片并画图

python 复制代码
test_new_input = torch.randn(16,100,device=device) # 生成16个 长度为100的正太分布随机数。放到GPU中 作为输入
gen_img_plot(gen,test_new_input)

GAN的生成是随机的,不同的噪声,生成不同的数字

相关推荐
CodeCraft Studio几秒前
全新AI增强Demo发布:DHTMLX Gantt与Diagram如何通过LLM更智能地构建项目与组织结构
人工智能·ai·项目管理·甘特图·dhtmlx gantt·gantt·llm大模型
孟祥_成都12 分钟前
不易懂你打我!写给前端和小白的 大模型(ChatGPT) 工作基本原理!
前端·人工智能
斯文~13 分钟前
【AI论文速递】SymAgent:知识图谱复杂推理的agent框架
人工智能·深度学习·llm·nlp·知识图谱
搬砖者(视觉算法工程师)14 分钟前
检索增强生成(RAG)与大语言模型微调(Fine-tuning)的差异、优势及使用场景详解
人工智能·语言模型·自然语言处理
明月照山海-1 小时前
机器学习周报二十三
人工智能·机器学习
科研面壁者2 小时前
SPSS——独立样本T检验
数据库·人工智能·机器学习·信息可视化·数据分析·spss·数据处理
ToTensor3 小时前
Tree of Thoughts:让大语言模型像人类一样思考
人工智能·语言模型·自然语言处理
shangjian0076 小时前
AI大模型-评价指标-相关术语
人工智能·算法
江河地笑6 小时前
opencv、cmake、vcpkg
人工智能·opencv·计算机视觉
海边夕阳20067 小时前
【每天一个AI小知识】:什么是卷积神经网络?
人工智能·经验分享·深度学习·神经网络·机器学习·cnn