pytorch生成对抗网络

生成对抗网络

import os

import torch

import torchvision

import torch.nn as nn

from torchvision import transforms

from torchvision.utils import save_image

Device configuration

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

超参数

latent_size = 64 # 潜在空间(latent space)的维度数量

hidden_size = 256

image_size = 784

num_epochs = 200

batch_size = 100

sample_dir = 'samples'

Create a directory if not exists

if not os.path.exists(sample_dir):

os.makedirs(sample_dir)

Image processing

transform = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize(mean=(0.5, 0.5, 0.5), # 3 for RGB channels

std=(0.5, 0.5, 0.5))])

transform = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize(mean=0.5, # 1 for greyscale channels

std=0.5)])

MNIST dataset

mnist = torchvision.datasets.MNIST(root='./datasets',

train=True,

transform=transform,

download=True)

data_loader = torch.utils.data.DataLoader(dataset=mnist,

batch_size=batch_size,

shuffle=True)

for i,_ in data_loader:

print(i.shape,i.max(),i.min(),torch.unique(_))

break

鉴别器

D = nn.Sequential(

nn.Linear(image_size, hidden_size),

nn.LeakyReLU(0.2),

nn.Linear(hidden_size, hidden_size),

nn.LeakyReLU(0.2),

nn.Linear(hidden_size, 1),

nn.Sigmoid())

生成器

G = nn.Sequential(

nn.Linear(latent_size, hidden_size),

nn.ReLU(),

nn.Linear(hidden_size, hidden_size),

nn.ReLU(),

nn.Linear(hidden_size, image_size),

nn.Tanh())

Device setting

D = D.to(device)

G = G.to(device)

Binary cross entropy loss and optimizer

criterion = nn.BCELoss()

d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)

g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

def denorm(x):

out = (x + 1) / 2

return out.clamp(0, 1) # 裁剪到0-1

def reset_grad():

d_optimizer.zero_grad()

g_optimizer.zero_grad()

torch.zeros(4, 1).shape

torch.randn(4,50).shape

生成对抗中有生成器和鉴别器,生成器是把潜在向量作为输入,网络输出图片的一维向量形式,在模型训练时,

判别器对真实图片和生成图片进行判别(打分),这是个二分类,这里用1表示真,0表示假,所以鉴别器损失就由

真实图片和全1的损失,生成图片和全0的损失,这两个损失组成,之后反向传播,这时候更新的就是鉴别器的参数

鉴别器训练的目的是为了辨别真实和生成,而生成器呢,生成器的目的是让生成的图片尽量接近真实,但是你怎么判断

它接近真实,就是传入鉴别器,鉴别器打分越高,说明它越真实,所以生成器损失就是全1标签和鉴别logits间的误差

特别注意的是:在鉴别器损失反向传播时,更新的只是鉴别器模型中的参数,而生成器损失反向传播时,更新的也只是

#生成器模型中的参数

训练

total_step = len(data_loader) # 总批次数

for epoch in range(num_epochs):

遍历每个批次数据

for i, (images, _) in enumerate(data_loader):

images = images.reshape(batch_size, -1).to(device)

创建稍后用作BCE损失输入的标签(全1表示真,全0表示假)

real_labels = torch.ones(batch_size, 1).to(device)

fake_labels = torch.zeros(batch_size, 1).to(device)

训练鉴别器

outputs = D(images) # 获取鉴别器对真实图片的鉴别分数

d_loss_real = criterion(outputs, real_labels) # 真实鉴别损失

real_score = outputs # 真实鉴别分数

z是随机初始化的生成图片(latent_size是用这个大小的向量表示图片)

z = torch.randn(batch_size, latent_size).to(device)

fake_images = G(z) # 获取生成的图片

outputs = D(fake_images) # 获取鉴别器对生成图片的鉴别分数

d_loss_fake = criterion(outputs, fake_labels) # 计算生成(假)的鉴别损失

fake_score = outputs

Backprop and optimize

d_loss = d_loss_real + d_loss_fake # 这两个加起来是鉴别器损失

reset_grad()

鉴别器损失反向传播

d_loss.backward()

根据梯度更新参数

d_optimizer.step()

训练生成器

随机初始化一个噪音图片(用一定大小的向量表示)

z = torch.randn(batch_size, latent_size).to(device)

通过生成器生成图片

fake_images = G(z)

outputs = D(fake_images) # 鉴别器对生成图片的鉴别得分

生成器的目的是使生成的图片足够真实,也就是最小化全1标签和鉴别logits间的误差

g_loss = criterion(outputs, real_labels)

清理之前梯度,用生成器损失反向传播,用g_optimizer更新参数

reset_grad()

g_loss.backward()

g_optimizer.step()

每隔200批次打印日志

if (i+1) % 200 == 0:

print('Epoch {}/{}, Step {}/{}, d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'

.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),

real_score.mean().item(), fake_score.mean().item()))

真实图片只需要保存一次

if (epoch+1) == 1:

images = images.reshape(images.size(0), 1, 28, 28)

save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

每个轮次都会保存一次生成器生成的图片

fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)

save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

Save the model checkpoints (把模型中各个层的参数字典保存到磁盘)

torch.save(G.state_dict(), 'G.ckpt')

torch.save(D.state_dict(), 'D.ckpt')

上面是真实图片,下面是经过200个轮次后的生成图片

相关推荐
星浩AI6 分钟前
合规项目大模型如何部署?硬件选型 + vLLM/LMDeploy 实战
pytorch·后端·llm
云烟成雨TD14 分钟前
Agent Scope Java 2.x 系列【10】技能(Skill)
java·人工智能·agent
GDAL16 分钟前
书签栏的 AI 转型:用 bge-small-zh-v1.5 重塑书签管理
人工智能·书签栏
青山如墨雨如画21 分钟前
【北邮-无线通信中的人工智能】物理层技术中AI的应用实践:基于KNN的调制识别(1)理论基础
人工智能·python·机器学习·matlab·jupyter
xhtdj37 分钟前
智源大会圆桌大模型没有终局具身智能可能是中国的 AlphaGo 时刻
人工智能·clickhouse·安全·动态规划
HavenlonLabs39 分钟前
区块链解决信任分布,AI 需要解决能力控制
人工智能·安全·区块链
良枫42 分钟前
01 “自进化 Agent”是什么
人工智能
LaughingZhu43 分钟前
Product Hunt 每日热榜 | 2026-06-12
人工智能·经验分享·深度学习·神经网络·产品运营
数据门徒44 分钟前
神经网络原理 第十一章:植根于统计力学的随机机器和它们的逼近
人工智能·深度学习·神经网络
AI 编程助手GPT1 小时前
用 Python 做一个世界杯赛前分析脚本:以巴西 vs 摩洛哥为例
开发语言·网络·人工智能·python·chatgpt