pytorch生成对抗网络

生成对抗网络

import os

import torch

import torchvision

import torch.nn as nn

from torchvision import transforms

from torchvision.utils import save_image

Device configuration

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

超参数

latent_size = 64 # 潜在空间(latent space)的维度数量

hidden_size = 256

image_size = 784

num_epochs = 200

batch_size = 100

sample_dir = 'samples'

Create a directory if not exists

if not os.path.exists(sample_dir):

os.makedirs(sample_dir)

Image processing

transform = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize(mean=(0.5, 0.5, 0.5), # 3 for RGB channels

std=(0.5, 0.5, 0.5))])

transform = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize(mean=[0.5], # 1 for greyscale channels

std=[0.5])])

MNIST dataset

mnist = torchvision.datasets.MNIST(root='./datasets',

train=True,

transform=transform,

download=True)

data_loader = torch.utils.data.DataLoader(dataset=mnist,

batch_size=batch_size,

shuffle=True)

for i,_ in data_loader:

print(i.shape,i.max(),i.min(),torch.unique(_))

break

鉴别器

D = nn.Sequential(

nn.Linear(image_size, hidden_size),

nn.LeakyReLU(0.2),

nn.Linear(hidden_size, hidden_size),

nn.LeakyReLU(0.2),

nn.Linear(hidden_size, 1),

nn.Sigmoid())

生成器

G = nn.Sequential(

nn.Linear(latent_size, hidden_size),

nn.ReLU(),

nn.Linear(hidden_size, hidden_size),

nn.ReLU(),

nn.Linear(hidden_size, image_size),

nn.Tanh())

Device setting

D = D.to(device)

G = G.to(device)

Binary cross entropy loss and optimizer

criterion = nn.BCELoss()

d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)

g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

def denorm(x):

out = (x + 1) / 2

return out.clamp(0, 1) # 裁剪到0-1

def reset_grad():

d_optimizer.zero_grad()

g_optimizer.zero_grad()

torch.zeros(4, 1).shape

torch.randn(4,50).shape

生成对抗中有生成器和鉴别器,生成器是把潜在向量作为输入,网络输出图片的一维向量形式,在模型训练时,

判别器对真实图片和生成图片进行判别(打分),这是个二分类,这里用1表示真,0表示假,所以鉴别器损失就由

真实图片和全1的损失,生成图片和全0的损失,这两个损失组成,之后反向传播,这时候更新的就是鉴别器的参数

鉴别器训练的目的是为了辨别真实和生成,而生成器呢,生成器的目的是让生成的图片尽量接近真实,但是你怎么判断

它接近真实,就是传入鉴别器,鉴别器打分越高,说明它越真实,所以生成器损失就是全1标签和鉴别logits间的误差

特别注意的是:在鉴别器损失反向传播时,更新的只是鉴别器模型中的参数,而生成器损失反向传播时,更新的也只是

#生成器模型中的参数

训练

total_step = len(data_loader) # 总批次数

for epoch in range(num_epochs):

遍历每个批次数据

for i, (images, _) in enumerate(data_loader):

images = images.reshape(batch_size, -1).to(device)

创建稍后用作BCE损失输入的标签(全1表示真,全0表示假)

real_labels = torch.ones(batch_size, 1).to(device)

fake_labels = torch.zeros(batch_size, 1).to(device)

训练鉴别器

outputs = D(images) # 获取鉴别器对真实图片的鉴别分数

d_loss_real = criterion(outputs, real_labels) # 真实鉴别损失

real_score = outputs # 真实鉴别分数

z是随机初始化的生成图片(latent_size是用这个大小的向量表示图片)

z = torch.randn(batch_size, latent_size).to(device)

fake_images = G(z) # 获取生成的图片

outputs = D(fake_images) # 获取鉴别器对生成图片的鉴别分数

d_loss_fake = criterion(outputs, fake_labels) # 计算生成(假)的鉴别损失

fake_score = outputs

Backprop and optimize

d_loss = d_loss_real + d_loss_fake # 这两个加起来是鉴别器损失

reset_grad()

鉴别器损失反向传播

d_loss.backward()

根据梯度更新参数

d_optimizer.step()

训练生成器

随机初始化一个噪音图片(用一定大小的向量表示)

z = torch.randn(batch_size, latent_size).to(device)

通过生成器生成图片

fake_images = G(z)

outputs = D(fake_images) # 鉴别器对生成图片的鉴别得分

生成器的目的是使生成的图片足够真实,也就是最小化全1标签和鉴别logits间的误差

g_loss = criterion(outputs, real_labels)

清理之前梯度,用生成器损失反向传播,用g_optimizer更新参数

reset_grad()

g_loss.backward()

g_optimizer.step()

每隔200批次打印日志

if (i+1) % 200 == 0:

print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'

.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),

real_score.mean().item(), fake_score.mean().item()))

真实图片只需要保存一次

if (epoch+1) == 1:

images = images.reshape(images.size(0), 1, 28, 28)

save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

每个轮次都会保存一次生成器生成的图片

fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)

save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

Save the model checkpoints (把模型中各个层的参数字典保存到磁盘)

torch.save(G.state_dict(), 'G.ckpt')

torch.save(D.state_dict(), 'D.ckpt')

上面是真实图片,下面是经过200个轮次后的生成图片

相关推荐
weixin_437497774 小时前
读书笔记:Context Engineering 2.0 (上)
人工智能·nlp
喝拿铁写前端4 小时前
前端开发者使用 AI 的能力层级——从表面使用到工程化能力的真正分水岭
前端·人工智能·程序员
goodfat4 小时前
Win11如何关闭自动更新 Win11暂停系统更新的设置方法【教程】
人工智能·禁止windows更新·win11优化工具
北京领雁科技5 小时前
领雁科技反洗钱案例白皮书暨人工智能在反洗钱系统中的深度应用
人工智能·科技·安全
落叶,听雪5 小时前
河南建站系统哪个好
大数据·人工智能·python
清月电子5 小时前
杰理AC109N系列AC1082 AC1074 AC1090 芯片停产替代及资料说明
人工智能·单片机·嵌入式硬件·物联网
Dev7z5 小时前
非线性MPC在自动驾驶路径跟踪与避障控制中的应用及Matlab实现
人工智能·matlab·自动驾驶
七月shi人5 小时前
AI浪潮下,前端路在何方
前端·人工智能·ai编程
橙汁味的风5 小时前
1隐马尔科夫模型HMM与条件随机场CRF
人工智能·深度学习·机器学习
itwangyang5206 小时前
AIDD-人工智能药物设计-AI 制药编码之战:预测癌症反应,选对方法是关键
人工智能