pytorch生成对抗网络

生成对抗网络

import os

import torch

import torchvision

import torch.nn as nn

from torchvision import transforms

from torchvision.utils import save_image

Device configuration

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

超参数

latent_size = 64 # 潜在空间(latent space)的维度数量

hidden_size = 256

image_size = 784

num_epochs = 200

batch_size = 100

sample_dir = 'samples'

Create a directory if not exists

if not os.path.exists(sample_dir):

os.makedirs(sample_dir)

Image processing

transform = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize(mean=(0.5, 0.5, 0.5), # 3 for RGB channels

std=(0.5, 0.5, 0.5))])

transform = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize(mean=[0.5], # 1 for greyscale channels

std=[0.5])])

MNIST dataset

mnist = torchvision.datasets.MNIST(root='./datasets',

train=True,

transform=transform,

download=True)

data_loader = torch.utils.data.DataLoader(dataset=mnist,

batch_size=batch_size,

shuffle=True)

for i,_ in data_loader:

print(i.shape,i.max(),i.min(),torch.unique(_))

break

鉴别器

D = nn.Sequential(

nn.Linear(image_size, hidden_size),

nn.LeakyReLU(0.2),

nn.Linear(hidden_size, hidden_size),

nn.LeakyReLU(0.2),

nn.Linear(hidden_size, 1),

nn.Sigmoid())

生成器

G = nn.Sequential(

nn.Linear(latent_size, hidden_size),

nn.ReLU(),

nn.Linear(hidden_size, hidden_size),

nn.ReLU(),

nn.Linear(hidden_size, image_size),

nn.Tanh())

Device setting

D = D.to(device)

G = G.to(device)

Binary cross entropy loss and optimizer

criterion = nn.BCELoss()

d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)

g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

def denorm(x):

out = (x + 1) / 2

return out.clamp(0, 1) # 裁剪到0-1

def reset_grad():

d_optimizer.zero_grad()

g_optimizer.zero_grad()

torch.zeros(4, 1).shape

torch.randn(4,50).shape

生成对抗中有生成器和鉴别器,生成器是把潜在向量作为输入,网络输出图片的一维向量形式,在模型训练时,

判别器对真实图片和生成图片进行判别(打分),这是个二分类,这里用1表示真,0表示假,所以鉴别器损失就由

真实图片和全1的损失,生成图片和全0的损失,这两个损失组成,之后反向传播,这时候更新的就是鉴别器的参数

鉴别器训练的目的是为了辨别真实和生成,而生成器呢,生成器的目的是让生成的图片尽量接近真实,但是你怎么判断

它接近真实,就是传入鉴别器,鉴别器打分越高,说明它越真实,所以生成器损失就是全1标签和鉴别logits间的误差

特别注意的是:在鉴别器损失反向传播时,更新的只是鉴别器模型中的参数,而生成器损失反向传播时,更新的也只是

#生成器模型中的参数

训练

total_step = len(data_loader) # 总批次数

for epoch in range(num_epochs):

遍历每个批次数据

for i, (images, _) in enumerate(data_loader):

images = images.reshape(batch_size, -1).to(device)

创建稍后用作BCE损失输入的标签(全1表示真,全0表示假)

real_labels = torch.ones(batch_size, 1).to(device)

fake_labels = torch.zeros(batch_size, 1).to(device)

训练鉴别器

outputs = D(images) # 获取鉴别器对真实图片的鉴别分数

d_loss_real = criterion(outputs, real_labels) # 真实鉴别损失

real_score = outputs # 真实鉴别分数

z是随机初始化的生成图片(latent_size是用这个大小的向量表示图片)

z = torch.randn(batch_size, latent_size).to(device)

fake_images = G(z) # 获取生成的图片

outputs = D(fake_images) # 获取鉴别器对生成图片的鉴别分数

d_loss_fake = criterion(outputs, fake_labels) # 计算生成(假)的鉴别损失

fake_score = outputs

Backprop and optimize

d_loss = d_loss_real + d_loss_fake # 这两个加起来是鉴别器损失

reset_grad()

鉴别器损失反向传播

d_loss.backward()

根据梯度更新参数

d_optimizer.step()

训练生成器

随机初始化一个噪音图片(用一定大小的向量表示)

z = torch.randn(batch_size, latent_size).to(device)

通过生成器生成图片

fake_images = G(z)

outputs = D(fake_images) # 鉴别器对生成图片的鉴别得分

生成器的目的是使生成的图片足够真实,也就是最小化全1标签和鉴别logits间的误差

g_loss = criterion(outputs, real_labels)

清理之前梯度,用生成器损失反向传播,用g_optimizer更新参数

reset_grad()

g_loss.backward()

g_optimizer.step()

每隔200批次打印日志

if (i+1) % 200 == 0:

print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'

.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),

real_score.mean().item(), fake_score.mean().item()))

真实图片只需要保存一次

if (epoch+1) == 1:

images = images.reshape(images.size(0), 1, 28, 28)

save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

每个轮次都会保存一次生成器生成的图片

fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)

save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

Save the model checkpoints (把模型中各个层的参数字典保存到磁盘)

torch.save(G.state_dict(), 'G.ckpt')

torch.save(D.state_dict(), 'D.ckpt')

上面是真实图片,下面是经过200个轮次后的生成图片

相关推荐
volcanical41 分钟前
Dataset Distillation with Attention Labels for Fine-tuning BERT
人工智能·深度学习·bert
L_cl41 分钟前
【NLP 17、NLP的基础——分词】
人工智能·自然语言处理
西西弗Sisyphus43 分钟前
大型语言模型(LLMs)演化树 Large Language Models
人工智能·语言模型·自然语言处理·大模型
轻口味2 小时前
【每日学点鸿蒙知识】沙箱目录、图片压缩、characteristicsArray、gm-crypto 国密加解密、通知权限
pytorch·华为·harmonyos
车载诊断技术3 小时前
电子电气架构 --- 什么是EPS?
网络·人工智能·安全·架构·汽车·需求分析
KevinRay_3 小时前
Python超能力:高级技巧让你的代码飞起来
网络·人工智能·python·lambda表达式·列表推导式·python高级技巧
跃跃欲试-迪之3 小时前
animatediff 模型网盘分享
人工智能·stable diffusion
Captain823Jack3 小时前
nlp新词发现——浅析 TF·IDF
人工智能·python·深度学习·神经网络·算法·自然语言处理
被制作时长两年半的个人练习生3 小时前
【AscendC】ReduceSum中指定workLocal大小时如何计算
人工智能·算子开发·ascendc
Captain823Jack4 小时前
w04_nlp大模型训练·中文分词
人工智能·python·深度学习·神经网络·算法·自然语言处理·中文分词