pytorch生成对抗网络

生成对抗网络

import os

import torch

import torchvision

import torch.nn as nn

from torchvision import transforms

from torchvision.utils import save_image

Device configuration

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

超参数

latent_size = 64 # 潜在空间(latent space)的维度数量

hidden_size = 256

image_size = 784

num_epochs = 200

batch_size = 100

sample_dir = 'samples'

Create a directory if not exists

if not os.path.exists(sample_dir):

os.makedirs(sample_dir)

Image processing

transform = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize(mean=(0.5, 0.5, 0.5), # 3 for RGB channels

std=(0.5, 0.5, 0.5))])

transform = transforms.Compose([

transforms.ToTensor(),

transforms.Normalize(mean=[0.5], # 1 for greyscale channels

std=[0.5])])

MNIST dataset

mnist = torchvision.datasets.MNIST(root='./datasets',

train=True,

transform=transform,

download=True)

data_loader = torch.utils.data.DataLoader(dataset=mnist,

batch_size=batch_size,

shuffle=True)

for i,_ in data_loader:

print(i.shape,i.max(),i.min(),torch.unique(_))

break

鉴别器

D = nn.Sequential(

nn.Linear(image_size, hidden_size),

nn.LeakyReLU(0.2),

nn.Linear(hidden_size, hidden_size),

nn.LeakyReLU(0.2),

nn.Linear(hidden_size, 1),

nn.Sigmoid())

生成器

G = nn.Sequential(

nn.Linear(latent_size, hidden_size),

nn.ReLU(),

nn.Linear(hidden_size, hidden_size),

nn.ReLU(),

nn.Linear(hidden_size, image_size),

nn.Tanh())

Device setting

D = D.to(device)

G = G.to(device)

Binary cross entropy loss and optimizer

criterion = nn.BCELoss()

d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)

g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

def denorm(x):

out = (x + 1) / 2

return out.clamp(0, 1) # 裁剪到0-1

def reset_grad():

d_optimizer.zero_grad()

g_optimizer.zero_grad()

torch.zeros(4, 1).shape

torch.randn(4,50).shape

生成对抗中有生成器和鉴别器,生成器是把潜在向量作为输入,网络输出图片的一维向量形式,在模型训练时,

判别器对真实图片和生成图片进行判别(打分),这是个二分类,这里用1表示真,0表示假,所以鉴别器损失就由

真实图片和全1的损失,生成图片和全0的损失,这两个损失组成,之后反向传播,这时候更新的就是鉴别器的参数

鉴别器训练的目的是为了辨别真实和生成,而生成器呢,生成器的目的是让生成的图片尽量接近真实,但是你怎么判断

它接近真实,就是传入鉴别器,鉴别器打分越高,说明它越真实,所以生成器损失就是全1标签和鉴别logits间的误差

特别注意的是:在鉴别器损失反向传播时,更新的只是鉴别器模型中的参数,而生成器损失反向传播时,更新的也只是

#生成器模型中的参数

训练

total_step = len(data_loader) # 总批次数

for epoch in range(num_epochs):

遍历每个批次数据

for i, (images, _) in enumerate(data_loader):

images = images.reshape(batch_size, -1).to(device)

创建稍后用作BCE损失输入的标签(全1表示真,全0表示假)

real_labels = torch.ones(batch_size, 1).to(device)

fake_labels = torch.zeros(batch_size, 1).to(device)

训练鉴别器

outputs = D(images) # 获取鉴别器对真实图片的鉴别分数

d_loss_real = criterion(outputs, real_labels) # 真实鉴别损失

real_score = outputs # 真实鉴别分数

z是随机初始化的生成图片(latent_size是用这个大小的向量表示图片)

z = torch.randn(batch_size, latent_size).to(device)

fake_images = G(z) # 获取生成的图片

outputs = D(fake_images) # 获取鉴别器对生成图片的鉴别分数

d_loss_fake = criterion(outputs, fake_labels) # 计算生成(假)的鉴别损失

fake_score = outputs

Backprop and optimize

d_loss = d_loss_real + d_loss_fake # 这两个加起来是鉴别器损失

reset_grad()

鉴别器损失反向传播

d_loss.backward()

根据梯度更新参数

d_optimizer.step()

训练生成器

随机初始化一个噪音图片(用一定大小的向量表示)

z = torch.randn(batch_size, latent_size).to(device)

通过生成器生成图片

fake_images = G(z)

outputs = D(fake_images) # 鉴别器对生成图片的鉴别得分

生成器的目的是使生成的图片足够真实,也就是最小化全1标签和鉴别logits间的误差

g_loss = criterion(outputs, real_labels)

清理之前梯度,用生成器损失反向传播,用g_optimizer更新参数

reset_grad()

g_loss.backward()

g_optimizer.step()

每隔200批次打印日志

if (i+1) % 200 == 0:

print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'

.format(epoch, num_epochs, i+1, total_step, d_loss.item(), g_loss.item(),

real_score.mean().item(), fake_score.mean().item()))

真实图片只需要保存一次

if (epoch+1) == 1:

images = images.reshape(images.size(0), 1, 28, 28)

save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

每个轮次都会保存一次生成器生成的图片

fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)

save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

Save the model checkpoints (把模型中各个层的参数字典保存到磁盘)

torch.save(G.state_dict(), 'G.ckpt')

torch.save(D.state_dict(), 'D.ckpt')

上面是真实图片,下面是经过200个轮次后的生成图片

相关推荐
牛客企业服务12 分钟前
2025年AI面试推荐榜单,数字化招聘转型优选
人工智能·python·算法·面试·职场和发展·金融·求职招聘
视觉语言导航42 分钟前
RAL-2025 | 清华大学数字孪生驱动的机器人视觉导航!VR-Robo:面向视觉机器人导航与运动的现实-模拟-现实框架
人工智能·深度学习·机器人·具身智能
**梯度已爆炸**1 小时前
自然语言处理入门
人工智能·自然语言处理
ctrlworks1 小时前
楼宇自控核心功能:实时监控设备运行,快速诊断故障,赋能设备寿命延长
人工智能·ba系统厂商·楼宇自控系统厂家·ibms系统厂家·建筑管理系统厂家·能耗监测系统厂家
BFT白芙堂2 小时前
睿尔曼系列机器人——以创新驱动未来,重塑智能协作新生态(上)
人工智能·机器学习·机器人·协作机器人·复合机器人·睿尔曼机器人
aneasystone本尊2 小时前
使用 MCP 让 Claude Code 集成外部工具
人工智能
静心问道2 小时前
SEW:无监督预训练在语音识别中的性能-效率权衡
人工智能·语音识别
羊小猪~~2 小时前
【NLP入门系列五】中文文本分类案例
人工智能·深度学习·考研·机器学习·自然语言处理·分类·数据挖掘
xwz小王子2 小时前
从LLM到WM:大语言模型如何进化成具身世界模型?
人工智能·语言模型·自然语言处理
我爱一条柴ya2 小时前
【AI大模型】深入理解 Transformer 架构:自然语言处理的革命引擎
人工智能·ai·ai作画·ai编程·ai写作