G1 - 生成对抗网络(GAN)

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客

  • 🍖 原作者:K同学啊

    复制代码
      GAN就是让两个AI"互相斗智":一个想"造假",一个想"识假"。通过不断"斗",造假的越来越像真的,识假的也越来越厉害,最后达到一个平衡点,造假的几乎能以假乱真。
     这就像一个学生和老师的关系:学生努力学习,老师不断出难题;学生通过老师的反馈不断进步,老师也通过学生的进步而更了解教学难点。最终,学生能解答几乎所有问题,老师能出最难的题。
    这就是GAN------两个AI互相"斗"出来的神奇结果!
python 复制代码
# ================ 代码功能说明 ================
# 这是一个生成对抗网络(GAN)的完整实现,用来学习生成手写数字图片(MNIST数据集)
# 生成器(Generator):把随机噪声变成手写数字图片
# 判别器(Discriminator):判断图片是真实的还是生成的
# 两者互相"斗法"直到生成器能造出以假乱真的图片

# ================ 1. 准备工作 ================
import argparse  # 用来接收命令行参数(比如训练轮数)
import os  # 操作系统命令(创建文件夹)
import numpy as np  # 科学计算库(处理数字)
import torchvision.transforms as transforms  # 图像处理工具
from torchvision.utils import save_image  # 保存图片
from torch.utils.data import DataLoader  # 加载数据集
from torchvision import datasets  # MNIST数据集
from torch.autograd import Variable  # 为Tensor添加梯度计算功能
import torch.nn as nn  # 神经网络核心模块
import torch  # PyTorch深度学习框架
import ssl  # 解决HTTPS证书问题(防止下载数据集时出错)
ssl._create_default_https_context = ssl._create_unverified_context  # 关闭SSL证书验证

# 创建三个文件夹:
# - images/:保存训练中生成的图片(看效果用)
# - save/:保存最终训练好的模型(以后能直接用)
# - datasets/mnist/:存放下载的MNIST手写数字数据集
os.makedirs("./images/", exist_ok=True)
os.makedirs("./save/", exist_ok=True)
os.makedirs("./datasets/mnist", exist_ok=True)

# ================ 2. 设置训练参数 ================
n_epochs = 50          # 训练50轮(每轮遍历所有数据)
batch_size = 64        # 每次训练用64张图片
lr = 0.0002            # 学习率(控制模型更新速度)
b1 = 0.5               # Adam优化器参数(控制梯度衰减)
b2 = 0.999             # Adam优化器参数
n_cpu = 2              # 使用2个CPU核心加速
latent_dim = 100       # 随机噪声的维度(100个随机数)
img_size = 28          # 图片尺寸(28x28像素)
channels = 1           # 图片通道(黑白图=1通道)
sample_interval = 500  # 每训练500次保存一次生成的图片

# 图片形状:(1, 28, 28) → 总像素数=784
img_shape = (channels, img_size, img_size)
img_area = np.prod(img_shape)  # 1*28*28=784

# 检查是否能用GPU(速度更快):
cuda = True if torch.cuda.is_available() else False
print("是否使用GPU:", cuda)  # 打印结果:True/False

# ================ 3. 下载并处理数据 ================
# 从MNIST下载手写数字数据集(28x28黑白图)
mnist = datasets.MNIST(
    root='./datasets/',  # 保存位置
    train=True,         # 下载训练集
    download=True,      # 自动下载
    transform=transforms.Compose([
        transforms.Resize(img_size),   # 缩放到28x28
        transforms.ToTensor(),         # 转成PyTorch张量
        transforms.Normalize([0.5], [0.5])  # 归一化到[-1,1]
    ]),
)

# 创建数据加载器(每次给64张图片)
dataloader = DataLoader(
    mnist,
    batch_size=batch_size,
    shuffle=True,  # 打乱顺序防止模型记住顺序
)

# ================ 4. 构建判别器(判断真假) ================
class Discriminator(nn.Module):  # 判别器类
    def __init__(self):
        super(Discriminator, self).__init__()
        # 一个简单的全连接神经网络:
        # 输入784个像素 → 512个神经元 → 256个神经元 → 1个输出(0~1概率)
        self.model = nn.Sequential(
            nn.Linear(img_area, 512),      # 784→512
            nn.LeakyReLU(0.2, inplace=True),  # 激活函数(解决梯度消失)
            nn.Linear(512, 256),          # 512→256
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),            # 256→1(输出概率)
            nn.Sigmoid()                  # 0~1概率(1=真图,0=假图)
        )
    
    def forward(self, img):
        # 把图片拉成一维向量(64,784)
        img_flat = img.view(img.size(0), -1)
        # 通过网络得到真假概率
        validity = self.model(img_flat)
        return validity

# ================ 5. 构建生成器(生成假图) ================
class Generator(nn.Module):  # 生成器类
    def __init__(self):
        super(Generator, self).__init__()
        # 辅助函数:创建一个带正则化的神经网络层
        def block(in_feat, out_feat, normalize=True):
            layers = [nn.Linear(in_feat, out_feat)]  # 线性变换
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))  # 正则化(加速训练)
            layers.append(nn.LeakyReLU(0.2, inplace=True))  # 激活函数
            return layers
        
        # 生成器网络结构:
        # 100维噪声 → 128 → 256 → 512 → 1024 → 784(输出)
        # 最后用Tanh让输出在[-1,1]之间(符合归一化后的数据范围)
        self.model = nn.Sequential(
            *block(latent_dim, 128, normalize=False),  # 100→128(不用正则化)
            *block(128, 256),                          # 128→256
            *block(256, 512),                          # 256→512
            *block(512, 1024),                         # 512→1024
            nn.Linear(1024, img_area),                 # 1024→784
            nn.Tanh()                                  # 输出归一化到[-1,1]
        )
    
    def forward(self, z):
        # z是100维随机噪声(64个样本)
        imgs = self.model(z)  # 生成图片(784维向量)
        # 重塑成(64,1,28,28)(PyTorch需要的图片格式)
        imgs = imgs.view(imgs.size(0), *img_shape)
        return imgs

# ================ 6. 初始化模型 ================
generator = Generator()  # 创建生成器
discriminator = Discriminator()  # 创建判别器

# 损失函数:衡量真假判断的准确性(二分类交叉熵)
criterion = torch.nn.BCELoss()  # 二分类交叉熵

# 优化器:Adam优化器(比普通梯度下降更快更好)
optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(b1, b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(b1, b2))

# 如果有GPU,把模型搬到GPU上加速
if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()
    criterion = criterion.cuda()

# ================ 7. 训练循环 ================
for epoch in range(n_epochs):  # 训练50轮
    for i, (imgs, _) in enumerate(dataloader):  # 遍历数据集
        # ====== 步骤1:训练判别器 ======
        # 把图片拉成一维(64,784)
        imgs = imgs.view(imgs.size(0), -1)
        # 转成可计算张量(GPU上)
        real_img = Variable(imgs).cuda()
        # 真实图片的标签:全1(表示"这是真图")
        real_label = Variable(torch.ones(imgs.size(0), 1)).cuda()
        # 假图片的标签:全0(表示"这是假图")
        fake_label = Variable(torch.zeros(imgs.size(0), 1)).cuda()

        # 判别器的损失 = 真图被判断为真 + 假图被判断为假
        # 真图:输入判别器 → 得到概率 → 计算和标签的差距
        real_out = discriminator(real_img)
        loss_real_D = criterion(real_out, real_label)
        # 假图:生成器生成假图 → 判别器判断 → 计算差距
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()  # 生成随机噪声
        fake_img = generator(z).detach()  # 生成假图(detach:不更新生成器参数)
        fake_out = discriminator(fake_img)
        loss_fake_D = criterion(fake_out, fake_label)
        loss_D = loss_real_D + loss_fake_D  # 总损失

        # 优化判别器:反向传播 + 更新参数
        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()

        # ====== 步骤2:训练生成器 ======
        # 生成器的目标:让判别器把假图判断成真图
        z = Variable(torch.randn(imgs.size(0), latent_dim)).cuda()  # 新随机噪声
        fake_img = generator(z)  # 生成假图
        output = discriminator(fake_img)  # 判别器判断假图
        # 生成器损失:希望判别器输出=1(真图)
        loss_G = criterion(output, real_label)

        # 优化生成器:反向传播 + 更新参数
        optimizer_G.zero_grad()
        loss_G.backward()
        optimizer_G.step()

        # ====== 打印训练进度 ======
        if (i + 1) % 300 == 0:  # 每300次打印一次
            print(f"[Epoch {epoch}/{n_epochs}] [Batch {i}/{len(dataloader)}] "
                  f"[D loss: {loss_D.item():.4f}] [G loss: {loss_G.item():.4f}] "
                  f"[D real: {real_out.data.mean():.4f}] [D fake: {fake_out.data.mean():.4f}]")

        # ====== 保存生成的图片(每500次保存一次) ======
        batches_done = epoch * len(dataloader) + i
        if batches_done % sample_interval == 0:
            # 保存前25张生成的图片(5x5网格)
            save_image(fake_img.data[:25], f"./images/{batches_done}.png", nrow=5, normalize=True)

# ================ 8. 保存最终模型 ================
torch.save(generator.state_dict(), './save/generator.pth')  # 保存生成器
torch.save(discriminator.state_dict(), './save/discriminator.pth')  # 保存判别器
bash 复制代码
[Epoch 0/50] [Batch 299/938] [D loss: 1.108700] [G loss: 1.494937] [D real: 0.765423] [D fake: 0.563390]
[Epoch 0/50] [Batch 599/938] [D loss: 0.981047] [G loss: 2.200819] [D real: 0.859328] [D fake: 0.555203]
[Epoch 0/50] [Batch 899/938] [D loss: 1.012156] [G loss: 1.935689] [D real: 0.728062] [D fake: 0.476248]
[Epoch 1/50] [Batch 299/938] [D loss: 1.188978] [G loss: 0.676110] [D real: 0.426300] [D fake: 0.200765]
[Epoch 1/50] [Batch 599/938] [D loss: 1.007571] [G loss: 1.044460] [D real: 0.562748] [D fake: 0.284159]
[Epoch 1/50] [Batch 899/938] [D loss: 1.071741] [G loss: 1.711364] [D real: 0.720821] [D fake: 0.483612]
[Epoch 2/50] [Batch 299/938] [D loss: 0.910406] [G loss: 2.151794] [D real: 0.764064] [D fake: 0.448280]
[Epoch 2/50] [Batch 599/938] [D loss: 0.800963] [G loss: 1.313761] [D real: 0.613358] [D fake: 0.188154]
[Epoch 2/50] [Batch 899/938] [D loss: 1.093633] [G loss: 1.053562] [D real: 0.531550] [D fake: 0.230020]
[Epoch 3/50] [Batch 299/938] [D loss: 0.963498] [G loss: 2.506877] [D real: 0.811666] [D fake: 0.497298]
[Epoch 3/50] [Batch 599/938] [D loss: 1.083450] [G loss: 0.882004] [D real: 0.465563] [D fake: 0.117864]
[Epoch 3/50] [Batch 899/938] [D loss: 0.973209] [G loss: 2.698256] [D real: 0.809422] [D fake: 0.502016]
[Epoch 4/50] [Batch 299/938] [D loss: 0.817019] [G loss: 1.351617] [D real: 0.666476] [D fake: 0.273635]
.......
相关推荐
三万棵雪松2 小时前
【AI小智后端部分(二)】
人工智能·ai小智·opus编码
愚公搬代码2 小时前
【愚公系列】《扣子开发 AI Agent 智能体应用》031-实战案例:多 Agent 模式开发旅游助手
人工智能·旅游
Elastic 中国社区官方博客2 小时前
Jina 模型的介绍,它们的功能,以及在 Elasticsearch 中的使用
大数据·人工智能·elasticsearch·搜索引擎·ai·全文检索·jina
大得3692 小时前
gpt-oss:20b大模型知识库,ai大模型
人工智能·python·gpt
2401_841495642 小时前
【机器学习】生成对抗网络(GAN)
人工智能·python·深度学习·神经网络·算法·机器学习·生成对抗网络
Hcoco_me2 小时前
大模型面试题24:小白版InfoNCE原理
人工智能·rnn·深度学习·自然语言处理·word2vec
无水先生2 小时前
图像处理方向的问题总结
图像处理·人工智能
阿正的梦工坊2 小时前
二次预训练与微调的区别
人工智能·深度学习·机器学习·大模型·llm
小宇的天下2 小时前
Calibre eqDRC(方程化 DRC)核心技术解析与实战指南(14-2)
人工智能·机器学习·支持向量机