G1 - 生成对抗网络(GAN)

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客

  • 🍖 原作者:K同学啊

    复制代码
      GAN就是让两个AI"互相斗智":一个想"造假",一个想"识假"。通过不断"斗",造假的越来越像真的,识假的也越来越厉害,最后达到一个平衡点,造假的几乎能以假乱真。
     这就像一个学生和老师的关系:学生努力学习,老师不断出难题;学生通过老师的反馈不断进步,老师也通过学生的进步而更了解教学难点。最终,学生能解答几乎所有问题,老师能出最难的题。
    这就是GAN------两个AI互相"斗"出来的神奇结果!
python 复制代码
# ================ 代码功能说明 ================
# 这是一个生成对抗网络(GAN)的完整实现,用来学习生成手写数字图片(MNIST数据集)
# 生成器(Generator):把随机噪声变成手写数字图片
# 判别器(Discriminator):判断图片是真实的还是生成的
# 两者互相"斗法"直到生成器能造出以假乱真的图片

# ================ 1. 准备工作 ================
import argparse  # 用来接收命令行参数(比如训练轮数)
import os  # 操作系统命令(创建文件夹)
import numpy as np  # 科学计算库(处理数字)
import torchvision.transforms as transforms  # 图像处理工具
from torchvision.utils import save_image  # 保存图片
from torch.utils.data import DataLoader  # 加载数据集
from torchvision import datasets  # MNIST数据集
from torch.autograd import Variable  # 为Tensor添加梯度计算功能
import torch.nn as nn  # 神经网络核心模块
import torch  # PyTorch深度学习框架
import ssl  # 解决HTTPS证书问题(防止下载数据集时出错)
ssl._create_default_https_context = ssl._create_unverified_context  # 关闭SSL证书验证

# 创建三个文件夹:
# - images/:保存训练中生成的图片(看效果用)
# - save/:保存最终训练好的模型(以后能直接用)
# - datasets/mnist/:存放下载的MNIST手写数字数据集
os.makedirs("./images/", exist_ok=True)
os.makedirs("./save/", exist_ok=True)
os.makedirs("./datasets/mnist", exist_ok=True)

# ================ 2. 设置训练参数 ================
n_epochs = 50          # 训练50轮(每轮遍历所有数据)
batch_size = 64        # 每次训练用64张图片
lr = 0.0002            # 学习率(控制模型更新速度)
b1 = 0.5               # Adam优化器参数(控制梯度衰减)
b2 = 0.999             # Adam优化器参数
n_cpu = 2              # 使用2个CPU核心加速
latent_dim = 100       # 随机噪声的维度(100个随机数)
img_size = 28          # 图片尺寸(28x28像素)
channels = 1           # 图片通道(黑白图=1通道)
sample_interval = 500  # 每训练500次保存一次生成的图片

# 图片形状:(1, 28, 28) → 总像素数=784
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)  # 1*28*28=784

# 检查是否能用GPU(速度更快):
cuda = True if torch.cuda.is_available() else False
print("是否使用GPU:", cuda)  # 打印结果:True/False

# ================ 3. 下载并处理数据 ================
# 从MNIST下载手写数字数据集(28x28黑白图)
mnist = datasets.MNIST(
    root='./datasets/',  # 保存位置
    train=True,         # 下载训练集
    download=True,      # 自动下载
    transform=transforms.Compose([
        transforms.Resize(img_size),   # 缩放到28x28
        transforms.ToTensor(),         # 转成PyTorch张量
        transforms.Normalize([0.5], [0.5])  # 归一化到[-1,1]
    ]),
)

# 创建数据加载器(每次给64张图片)
dataloader = DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,  # 打乱顺序防止模型记住顺序
)

# ================ 4. 构建判别器(判断真假) ================
class Discriminator(nn.Module):  # 判别器类
    def __init__(self):
        super(Discriminator, self).__init__()
        # 一个简单的全连接神经网络:
        # 输入784个像素 → 512个神经元 → 256个神经元 → 1个输出(0~1概率)
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),      # 784→512
            nn.LeakyReLU(0.2, inplace=True),  # 激活函数(解决梯度消失)
            nn.Linear(512, 256),          # 512→256
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),            # 256→1(输出概率)
            nn.Sigmoid()                  # 0~1概率(1=真图,0=假图)
        )
    
    def forward(self, img):
        # 把图片拉成一维向量(64,784)
        img_flat = img.view(img.size(0), -1)
        # 通过网络得到真假概率
        validity = self.model(img_flat)
        return validity

# ================ 5. 构建生成器(生成假图) ================
class Generator(nn.Module):  # 生成器类
    def __init__(self):
        super(Generator, self).__init__()
        # 辅助函数:创建一个带正则化的神经网络层
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]  # 线性变换
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))  # 正则化(加速训练)
            layers.append(nn.LeakyReLU(0.2, inplace=True))  # 激活函数
            return layers
        
        # 生成器网络结构:
        # 100维噪声 → 128 → 256 → 512 → 1024 → 784(输出)
        # 最后用Tanh让输出在[-1,1]之间(符合归一化后的数据范围)
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),  # 100→128(不用正则化)
            *block(128, 256),                          # 128→256
            *block(256, 512),                          # 256→512
            *block(512, 1024),                         # 512→1024
            nn.Linear(1024, img_area),                 # 1024→784
            nn.Tanh()                                  # 输出归一化到[-1,1]
        )
    
    def forward(self, z):
        # z是100维随机噪声(64个样本)
        imgs = self.model(z)  # 生成图片(784维向量)
        # 重塑成(64,1,28,28)(PyTorch需要的图片格式)
        imgs = imgs.view(imgs.size(0), *img_shape)
        return imgs

# ================ 6. 初始化模型 ================
generator = Generator()  # 创建生成器
discriminator = Discriminator()  # 创建判别器

# 损失函数:衡量真假判断的准确性(二分类交叉熵)
criterion = torch.nn.BCELoss()  # 二分类交叉熵

# 优化器:Adam优化器(比普通梯度下降更快更好)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# 如果有GPU,把模型搬到GPU上加速
if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion = criterion.cuda()

# ================ 7. 训练循环 ================
for epoch in range(n_epochs):  # 训练50轮
    for i, (imgs, _) in enumerate(dataloader):  # 遍历数据集
        # ====== 步骤1:训练判别器 ======
        # 把图片拉成一维(64,784)
        imgs = imgs.view(imgs.size(0), -1)
        # 转成可计算张量(GPU上)
        real_img = Variable(imgs).cuda()
        # 真实图片的标签:全1(表示"这是真图")
        real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()
        # 假图片的标签:全0(表示"这是假图")
        fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()

        # 判别器的损失 = 真图被判断为真 + 假图被判断为假
        # 真图:输入判别器 → 得到概率 → 计算和标签的差距
        real_out = discriminator(real_img)
        loss_real_D = criterion(real_out, real_label)
        # 假图:生成器生成假图 → 判别器判断 → 计算差距
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()  # 生成随机噪声
        fake_img = generator(z).detach()  # 生成假图(detach:不更新生成器参数)
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out, fake_label)
        loss_D = loss_real_D + loss_fake_D  # 总损失

        # 优化判别器:反向传播 + 更新参数
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # ====== 步骤2:训练生成器 ======
        # 生成器的目标:让判别器把假图判断成真图
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()  # 新随机噪声
        fake_img = generator(z)  # 生成假图
        output = discriminator(fake_img)  # 判别器判断假图
        # 生成器损失:希望判别器输出=1(真图)
        loss_G = criterion(output, real_label)

        # 优化生成器:反向传播 + 更新参数
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # ====== 打印训练进度 ======
        if (i + 1) % 300 == 0:  # 每300次打印一次
            print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "
                  f"[D loss: {loss_D.item():.4f}] [G loss: {loss_G.item():.4f}] "
                  f"[D real: {real_out.data.mean():.4f}] [D fake: {fake_out.data.mean():.4f}]")

        # ====== 保存生成的图片(每500次保存一次) ======
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            # 保存前25张生成的图片(5x5网格)
            save_image(fake_img.data[:25], f"./images/{batches_done}.png", nrow=5, normalize=True)

# ================ 8. 保存最终模型 ================
torch.save(generator.state_dict(), './save/generator.pth')  # 保存生成器
torch.save(discriminator.state_dict(), './save/discriminator.pth')  # 保存判别器
bash 复制代码
[Epoch 0/50] [Batch 299/938] [D loss: 1.108700] [G loss: 1.494937] [D real: 0.765423] [D fake: 0.563390]
[Epoch 0/50] [Batch 599/938] [D loss: 0.981047] [G loss: 2.200819] [D real: 0.859328] [D fake: 0.555203]
[Epoch 0/50] [Batch 899/938] [D loss: 1.012156] [G loss: 1.935689] [D real: 0.728062] [D fake: 0.476248]
[Epoch 1/50] [Batch 299/938] [D loss: 1.188978] [G loss: 0.676110] [D real: 0.426300] [D fake: 0.200765]
[Epoch 1/50] [Batch 599/938] [D loss: 1.007571] [G loss: 1.044460] [D real: 0.562748] [D fake: 0.284159]
[Epoch 1/50] [Batch 899/938] [D loss: 1.071741] [G loss: 1.711364] [D real: 0.720821] [D fake: 0.483612]
[Epoch 2/50] [Batch 299/938] [D loss: 0.910406] [G loss: 2.151794] [D real: 0.764064] [D fake: 0.448280]
[Epoch 2/50] [Batch 599/938] [D loss: 0.800963] [G loss: 1.313761] [D real: 0.613358] [D fake: 0.188154]
[Epoch 2/50] [Batch 899/938] [D loss: 1.093633] [G loss: 1.053562] [D real: 0.531550] [D fake: 0.230020]
[Epoch 3/50] [Batch 299/938] [D loss: 0.963498] [G loss: 2.506877] [D real: 0.811666] [D fake: 0.497298]
[Epoch 3/50] [Batch 599/938] [D loss: 1.083450] [G loss: 0.882004] [D real: 0.465563] [D fake: 0.117864]
[Epoch 3/50] [Batch 899/938] [D loss: 0.973209] [G loss: 2.698256] [D real: 0.809422] [D fake: 0.502016]
[Epoch 4/50] [Batch 299/938] [D loss: 0.817019] [G loss: 1.351617] [D real: 0.666476] [D fake: 0.273635]
.......
相关推荐
NAGNIP7 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab8 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab8 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP12 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年12 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼12 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS13 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区14 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈14 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang14 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx