-
🍨 本文为🔗365天深度学习训练营 中的学习记录博客
-
🍖 原作者:K同学啊
GAN就是让两个AI"互相斗智":一个想"造假",一个想"识假"。通过不断"斗",造假的越来越像真的,识假的也越来越厉害,最后达到一个平衡点,造假的几乎能以假乱真。 这就像一个学生和老师的关系:学生努力学习,老师不断出难题;学生通过老师的反馈不断进步,老师也通过学生的进步而更了解教学难点。最终,学生能解答几乎所有问题,老师能出最难的题。 这就是GAN------两个AI互相"斗"出来的神奇结果!
python
# ================ 代码功能说明 ================
# 这是一个生成对抗网络(GAN)的完整实现,用来学习生成手写数字图片(MNIST数据集)
# 生成器(Generator):把随机噪声变成手写数字图片
# 判别器(Discriminator):判断图片是真实的还是生成的
# 两者互相"斗法"直到生成器能造出以假乱真的图片
# ================ 1. 准备工作 ================
import argparse # 用来接收命令行参数(比如训练轮数)
import os # 操作系统命令(创建文件夹)
import numpy as np # 科学计算库(处理数字)
import torchvision.transforms as transforms # 图像处理工具
from torchvision.utils import save_image # 保存图片
from torch.utils.data import DataLoader # 加载数据集
from torchvision import datasets # MNIST数据集
from torch.autograd import Variable # 为Tensor添加梯度计算功能
import torch.nn as nn # 神经网络核心模块
import torch # PyTorch深度学习框架
import ssl # 解决HTTPS证书问题(防止下载数据集时出错)
ssl._create_default_https_context = ssl._create_unverified_context # 关闭SSL证书验证
# 创建三个文件夹:
# - images/:保存训练中生成的图片(看效果用)
# - save/:保存最终训练好的模型(以后能直接用)
# - datasets/mnist/:存放下载的MNIST手写数字数据集
os.makedirs("./images/", exist_ok=True)
os.makedirs("./save/", exist_ok=True)
os.makedirs("./datasets/mnist", exist_ok=True)
# ================ 2. 设置训练参数 ================
n_epochs = 50 # 训练50轮(每轮遍历所有数据)
batch_size = 64 # 每次训练用64张图片
lr = 0.0002 # 学习率(控制模型更新速度)
b1 = 0.5 # Adam优化器参数(控制梯度衰减)
b2 = 0.999 # Adam优化器参数
n_cpu = 2 # 使用2个CPU核心加速
latent_dim = 100 # 随机噪声的维度(100个随机数)
img_size = 28 # 图片尺寸(28x28像素)
channels = 1 # 图片通道(黑白图=1通道)
sample_interval = 500 # 每训练500次保存一次生成的图片
# 图片形状:(1, 28, 28) → 总像素数=784
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape) # 1*28*28=784
# 检查是否能用GPU(速度更快):
cuda = True if torch.cuda.is_available() else False
print("是否使用GPU:", cuda) # 打印结果:True/False
# ================ 3. 下载并处理数据 ================
# 从MNIST下载手写数字数据集(28x28黑白图)
mnist = datasets.MNIST(
root='./datasets/', # 保存位置
train=True, # 下载训练集
download=True, # 自动下载
transform=transforms.Compose([
transforms.Resize(img_size), # 缩放到28x28
transforms.ToTensor(), # 转成PyTorch张量
transforms.Normalize([0.5], [0.5]) # 归一化到[-1,1]
]),
)
# 创建数据加载器(每次给64张图片)
dataloader = DataLoader(
mnist,
batch_size=batch_size,
shuffle=True, # 打乱顺序防止模型记住顺序
)
# ================ 4. 构建判别器(判断真假) ================
class Discriminator(nn.Module): # 判别器类
def __init__(self):
super(Discriminator, self).__init__()
# 一个简单的全连接神经网络:
# 输入784个像素 → 512个神经元 → 256个神经元 → 1个输出(0~1概率)
self.model = nn.Sequential(
nn.Linear(img_area, 512), # 784→512
nn.LeakyReLU(0.2, inplace=True), # 激活函数(解决梯度消失)
nn.Linear(512, 256), # 512→256
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1), # 256→1(输出概率)
nn.Sigmoid() # 0~1概率(1=真图,0=假图)
)
def forward(self, img):
# 把图片拉成一维向量(64,784)
img_flat = img.view(img.size(0), -1)
# 通过网络得到真假概率
validity = self.model(img_flat)
return validity
# ================ 5. 构建生成器(生成假图) ================
class Generator(nn.Module): # 生成器类
def __init__(self):
super(Generator, self).__init__()
# 辅助函数:创建一个带正则化的神经网络层
def block(in_feat, out_feat, normalize=True):
layers = [nn.Linear(in_feat, out_feat)] # 线性变换
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8)) # 正则化(加速训练)
layers.append(nn.LeakyReLU(0.2, inplace=True)) # 激活函数
return layers
# 生成器网络结构:
# 100维噪声 → 128 → 256 → 512 → 1024 → 784(输出)
# 最后用Tanh让输出在[-1,1]之间(符合归一化后的数据范围)
self.model = nn.Sequential(
*block(latent_dim, 128, normalize=False), # 100→128(不用正则化)
*block(128, 256), # 128→256
*block(256, 512), # 256→512
*block(512, 1024), # 512→1024
nn.Linear(1024, img_area), # 1024→784
nn.Tanh() # 输出归一化到[-1,1]
)
def forward(self, z):
# z是100维随机噪声(64个样本)
imgs = self.model(z) # 生成图片(784维向量)
# 重塑成(64,1,28,28)(PyTorch需要的图片格式)
imgs = imgs.view(imgs.size(0), *img_shape)
return imgs
# ================ 6. 初始化模型 ================
generator = Generator() # 创建生成器
discriminator = Discriminator() # 创建判别器
# 损失函数:衡量真假判断的准确性(二分类交叉熵)
criterion = torch.nn.BCELoss() # 二分类交叉熵
# 优化器:Adam优化器(比普通梯度下降更快更好)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))
# 如果有GPU,把模型搬到GPU上加速
if cuda:
generator = generator.cuda()
discriminator = discriminator.cuda()
criterion = criterion.cuda()
# ================ 7. 训练循环 ================
for epoch in range(n_epochs): # 训练50轮
for i, (imgs, _) in enumerate(dataloader): # 遍历数据集
# ====== 步骤1:训练判别器 ======
# 把图片拉成一维(64,784)
imgs = imgs.view(imgs.size(0), -1)
# 转成可计算张量(GPU上)
real_img = Variable(imgs).cuda()
# 真实图片的标签:全1(表示"这是真图")
real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()
# 假图片的标签:全0(表示"这是假图")
fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()
# 判别器的损失 = 真图被判断为真 + 假图被判断为假
# 真图:输入判别器 → 得到概率 → 计算和标签的差距
real_out = discriminator(real_img)
loss_real_D = criterion(real_out, real_label)
# 假图:生成器生成假图 → 判别器判断 → 计算差距
z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda() # 生成随机噪声
fake_img = generator(z).detach() # 生成假图(detach:不更新生成器参数)
fake_out = discriminator(fake_img)
loss_fake_D = criterion(fake_out, fake_label)
loss_D = loss_real_D + loss_fake_D # 总损失
# 优化判别器:反向传播 + 更新参数
optimizer_D.zero_grad()
loss_D.backward()
optimizer_D.step()
# ====== 步骤2:训练生成器 ======
# 生成器的目标:让判别器把假图判断成真图
z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda() # 新随机噪声
fake_img = generator(z) # 生成假图
output = discriminator(fake_img) # 判别器判断假图
# 生成器损失:希望判别器输出=1(真图)
loss_G = criterion(output, real_label)
# 优化生成器:反向传播 + 更新参数
optimizer_G.zero_grad()
loss_G.backward()
optimizer_G.step()
# ====== 打印训练进度 ======
if (i + 1) % 300 == 0: # 每300次打印一次
print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "
f"[D loss: {loss_D.item():.4f}] [G loss: {loss_G.item():.4f}] "
f"[D real: {real_out.data.mean():.4f}] [D fake: {fake_out.data.mean():.4f}]")
# ====== 保存生成的图片(每500次保存一次) ======
batches_done = epoch * len(dataloader) + i
if batches_done % sample_interval == 0:
# 保存前25张生成的图片(5x5网格)
save_image(fake_img.data[:25], f"./images/{batches_done}.png", nrow=5, normalize=True)
# ================ 8. 保存最终模型 ================
torch.save(generator.state_dict(), './save/generator.pth') # 保存生成器
torch.save(discriminator.state_dict(), './save/discriminator.pth') # 保存判别器
bash
[Epoch 0/50] [Batch 299/938] [D loss: 1.108700] [G loss: 1.494937] [D real: 0.765423] [D fake: 0.563390]
[Epoch 0/50] [Batch 599/938] [D loss: 0.981047] [G loss: 2.200819] [D real: 0.859328] [D fake: 0.555203]
[Epoch 0/50] [Batch 899/938] [D loss: 1.012156] [G loss: 1.935689] [D real: 0.728062] [D fake: 0.476248]
[Epoch 1/50] [Batch 299/938] [D loss: 1.188978] [G loss: 0.676110] [D real: 0.426300] [D fake: 0.200765]
[Epoch 1/50] [Batch 599/938] [D loss: 1.007571] [G loss: 1.044460] [D real: 0.562748] [D fake: 0.284159]
[Epoch 1/50] [Batch 899/938] [D loss: 1.071741] [G loss: 1.711364] [D real: 0.720821] [D fake: 0.483612]
[Epoch 2/50] [Batch 299/938] [D loss: 0.910406] [G loss: 2.151794] [D real: 0.764064] [D fake: 0.448280]
[Epoch 2/50] [Batch 599/938] [D loss: 0.800963] [G loss: 1.313761] [D real: 0.613358] [D fake: 0.188154]
[Epoch 2/50] [Batch 899/938] [D loss: 1.093633] [G loss: 1.053562] [D real: 0.531550] [D fake: 0.230020]
[Epoch 3/50] [Batch 299/938] [D loss: 0.963498] [G loss: 2.506877] [D real: 0.811666] [D fake: 0.497298]
[Epoch 3/50] [Batch 599/938] [D loss: 1.083450] [G loss: 0.882004] [D real: 0.465563] [D fake: 0.117864]
[Epoch 3/50] [Batch 899/938] [D loss: 0.973209] [G loss: 2.698256] [D real: 0.809422] [D fake: 0.502016]
[Epoch 4/50] [Batch 299/938] [D loss: 0.817019] [G loss: 1.351617] [D real: 0.666476] [D fake: 0.273635]
.......