pytorch生成对抗网络

生成对抗网络

import os

import torch

import torchvision

import torch.nn as nn

from torchvision import transforms

from torchvision.utils import save_image

Device configuration

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

超参数

latent_size = 64 # 潜在空间(latent space)的维度数量

hidden_size = 256

image_size = 784

num_epochs = 200

batch_size = 100

sample_dir = 'samples'

Create a directory if not exists

if not os.path.exists(sample_dir):

os.makedirs(sample_dir)

Image processing

transform = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize(mean=(0.5, 0.5, 0.5), # 3 for RGB channels

std=(0.5, 0.5, 0.5))])

transform = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize(mean=[0.5], # 1 for greyscale channels

std=[0.5])])

MNIST dataset

mnist = torchvision.datasets.MNIST(root='./datasets',

train=True,

transform=transform,

download=True)

data_loader = torch.utils.data.DataLoader(dataset=mnist,

batch_size=batch_size,

shuffle=True)

for i,_ in data_loader:

print(i.shape,i.max(),i.min(),torch.unique(_))

break

鉴别器

D = nn.Sequential(

nn.Linear(image_size, hidden_size),

nn.LeakyReLU(0.2),

nn.Linear(hidden_size, hidden_size),

nn.LeakyReLU(0.2),

nn.Linear(hidden_size, 1),

nn.Sigmoid())

生成器

G = nn.Sequential(

nn.Linear(latent_size, hidden_size),

nn.ReLU(),

nn.Linear(hidden_size, hidden_size),

nn.ReLU(),

nn.Linear(hidden_size, image_size),

nn.Tanh())

Device setting

D = D.to(device)

G = G.to(device)

Binary cross entropy loss and optimizer

criterion = nn.BCELoss()

d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)

g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

def denorm(x):

out = (x + 1) / 2

return out.clamp(0, 1) # 裁剪到0-1

def reset_grad():

d_optimizer.zero_grad()

g_optimizer.zero_grad()

torch.zeros(4, 1).shape

torch.randn(4,50).shape

生成对抗中有生成器和鉴别器,生成器是把潜在向量作为输入,网络输出图片的一维向量形式,在模型训练时,

判别器对真实图片和生成图片进行判别(打分),这是个二分类,这里用1表示真,0表示假,所以鉴别器损失就由

真实图片和全1的损失,生成图片和全0的损失,这两个损失组成,之后反向传播,这时候更新的就是鉴别器的参数

鉴别器训练的目的是为了辨别真实和生成,而生成器呢,生成器的目的是让生成的图片尽量接近真实,但是你怎么判断

它接近真实,就是传入鉴别器,鉴别器打分越高,说明它越真实,所以生成器损失就是全1标签和鉴别logits间的误差

特别注意的是:在鉴别器损失反向传播时,更新的只是鉴别器模型中的参数,而生成器损失反向传播时,更新的也只是

#生成器模型中的参数

训练

total_step = len(data_loader) # 总批次数

for epoch in range(num_epochs):

遍历每个批次数据

for i, (images, _) in enumerate(data_loader):

images = images.reshape(batch_size, -1).to(device)

创建稍后用作BCE损失输入的标签(全1表示真,全0表示假)

real_labels = torch.ones(batch_size, 1).to(device)

fake_labels = torch.zeros(batch_size, 1).to(device)

训练鉴别器

outputs = D(images) # 获取鉴别器对真实图片的鉴别分数

d_loss_real = criterion(outputs, real_labels) # 真实鉴别损失

real_score = outputs # 真实鉴别分数

z是随机初始化的生成图片(latent_size是用这个大小的向量表示图片)

z = torch.randn(batch_size, latent_size).to(device)

fake_images = G(z) # 获取生成的图片

outputs = D(fake_images) # 获取鉴别器对生成图片的鉴别分数

d_loss_fake = criterion(outputs, fake_labels) # 计算生成(假)的鉴别损失

fake_score = outputs

Backprop and optimize

d_loss = d_loss_real + d_loss_fake # 这两个加起来是鉴别器损失

reset_grad()

鉴别器损失反向传播

d_loss.backward()

根据梯度更新参数

d_optimizer.step()

训练生成器

随机初始化一个噪音图片(用一定大小的向量表示)

z = torch.randn(batch_size, latent_size).to(device)

通过生成器生成图片

fake_images = G(z)

outputs = D(fake_images) # 鉴别器对生成图片的鉴别得分

生成器的目的是使生成的图片足够真实,也就是最小化全1标签和鉴别logits间的误差

g_loss = criterion(outputs, real_labels)

清理之前梯度,用生成器损失反向传播,用g_optimizer更新参数

reset_grad()

g_loss.backward()

g_optimizer.step()

每隔200批次打印日志

if (i+1) % 200 == 0:

print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'

.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),

real_score.mean().item(), fake_score.mean().item()))

真实图片只需要保存一次

if (epoch+1) == 1:

images = images.reshape(images.size(0), 1, 28, 28)

save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

每个轮次都会保存一次生成器生成的图片

fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)

save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

Save the model checkpoints (把模型中各个层的参数字典保存到磁盘)

torch.save(G.state_dict(), 'G.ckpt')

torch.save(D.state_dict(), 'D.ckpt')

上面是真实图片,下面是经过200个轮次后的生成图片

相关推荐
Coder_Boy_1 天前
基于SpringAI的智能AIOps项目:部署相关容器化部署管理技术
人工智能·spring boot·k8s·运维开发
极新1 天前
生数科技商业化总监陈鹤天:视频生成破瓶颈,AI赋能漫剧产业|2025极新AIGC峰会演讲实录
人工智能·科技·aigc
u1301301 天前
开源版 NotebookLM:Open Notebook 深度体验与部署指南
人工智能·开源
说私域1 天前
基于开源AI大模型AI智能名片S2B2C商城小程序的内容价值生成与多点选择传播策略研究
人工智能·微信·小程序·开源
说私域1 天前
数据分析能力在开源AI智能名片链动2+1模式多商户商城小程序中的价值与应用研究
人工智能·数据分析·开源
Coder_Boy_1 天前
基于SpringAI企业级智能教学考试平台试卷管理模块全业务闭环方案
java·大数据·人工智能·spring boot·springboot
拾荒的小海螺1 天前
开源项目:Z-Image 轻量高效的开源 AI 图像生成模型
人工智能·开源
dagouaofei1 天前
实测!6款AI自动生成PPT工具体验分享
人工智能·python·powerpoint
newrank_kk1 天前
下一代品牌战略:把智汇GEO作为核心品牌AI形象管理工具
大数据·人工智能
传感器与混合集成电路1 天前
面向航天、深地与核工业场景的高可靠电源方案设计要点
人工智能·物联网