生成对抗网络GAN(MNIST实现、时间序列实现)

生成对抗网络

生成对抗网络介绍

生成对抗网络(Generative Adversarial Network,简称GAN)是一种深度学习模型,由Ian Goodfellow等人于2014年提出。它由两个主要组件组成:生成器(Generator)和判别器(Discriminator)。GAN的目标是通过两个网络之间的对抗学习来生成逼真的数据。

  1. 生成器(Generator): 生成器是一个神经网络,它接收一个随机噪声向量作为输入,并试图将这个随机噪声转换为逼真的数据样本。在训练过程中,生成器不断试图提高生成样本的质量,使其能够欺骗判别器。初始阶段生成的样本可能不够真实,但随着训练的进行,生成器逐渐学会生成更加逼真的数据样本。
  2. 判别器(Discriminator): 判别器也是一个神经网络,它的任务是区分真实数据样本和由生成器生成的假样本。它类似于一个二分类器,努力将输入样本分为"真实"和"假"的两个类别。在训练过程中,判别器通过不断学习区分真实样本和生成样本,使得判别器的准确率不断提高。

GAN的训练过程是一个对抗过程:

  1. 生成器将随机噪声转换为生成样本,并将这些生成样本传递给判别器。
  2. 判别器对接收到的真实样本和生成样本进行分类,并输出相应的概率分数。
  3. 根据判别器的输出,生成器试图生成能够欺骗判别器的更逼真的样本。
  4. 这个过程不断重复,直到生成器生成的样本足够逼真,判别器无法准确区分真假样本。

通过这种对抗学习的过程,GAN能够生成高质量的数据样本,广泛应用于图像、音频、文本等领域。然而,训练GAN也存在一些挑战,如训练不稳定、模式崩溃等问题,需要经验丰富的研究人员进行调优和改进。

MNIST---GAN

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# 定义生成器和判别器的类
class Generator(nn.Module):
    """Fully connected generator that maps a noise vector to a flat sample.

    The network is ``z_dim -> hidden_dim -> 2 * hidden_dim -> output_dim``
    with LeakyReLU(0.01) after the first two linear layers and a final Tanh,
    so every output value is bounded to [-1, 1].
    """

    def __init__(self, z_dim=100, hidden_dim=128, output_dim=784):
        super().__init__()
        wide_dim = hidden_dim * 2
        layers = [
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_dim, wide_dim),
            nn.LeakyReLU(0.01),
            nn.Linear(wide_dim, output_dim),
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, noise):
        """Run ``noise`` through the layer stack and return the result."""
        generated = self.gen(noise)
        return generated

class Discriminator(nn.Module):
    """Fully connected binary classifier for flat samples.

    The network is ``input_dim -> 2 * hidden_dim -> hidden_dim -> 1`` with
    LeakyReLU(0.01) after the first two linear layers and a final Sigmoid,
    so the single output per sample lies in [0, 1].
    """

    def __init__(self, input_dim=784, hidden_dim=128):
        super().__init__()
        wide_dim = hidden_dim * 2
        layers = [
            nn.Linear(input_dim, wide_dim),
            nn.LeakyReLU(0.01),
            nn.Linear(wide_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*layers)

    def forward(self, image):
        """Run ``image`` through the layer stack and return the score."""
        score = self.disc(image)
        return score

# 定义训练函数
def train_gan(generator, discriminator, dataloader, num_epochs=50, z_dim=100, lr=0.0002):
    """Train a generator/discriminator pair with the binary cross-entropy GAN loss.

    Both networks are moved in place to CUDA when it is available, otherwise
    they stay on the CPU. Every batch performs one discriminator update
    followed by one generator update. After each epoch the losses of the
    last batch of that epoch are printed.

    Args:
        generator: module mapping ``(batch, z_dim)`` noise to flat samples.
        discriminator: module mapping flat samples to ``(batch, 1)``
            probabilities in [0, 1], as required by ``nn.BCELoss``.
        dataloader: iterable yielding ``(samples, labels)`` batches. The
            labels are ignored; samples are flattened to ``(batch, -1)``.
        num_epochs: number of passes over ``dataloader``.
        z_dim: size of the noise vector fed to the generator.
        lr: learning rate used by both Adam optimizers.

    Raises:
        ValueError: if ``dataloader`` yields no batches during an epoch.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    discriminator.to(device)

    gen_optim = optim.Adam(generator.parameters(), lr=lr)
    disc_optim = optim.Adam(discriminator.parameters(), lr=lr)
    criterion = nn.BCELoss()

    for epoch in range(num_epochs):
        # Reset so an epoch without batches is detected instead of reusing
        # (or failing on) loss values from a previous epoch.
        gen_loss = disc_loss = None

        for real_images, _ in dataloader:
            batch_size = real_images.size(0)
            real_images = real_images.view(batch_size, -1).to(device)
            # Allocate targets directly on the training device; the batch
            # size can differ for the last batch, so build them per batch.
            ones_labels = torch.ones(batch_size, 1, device=device)
            zeros_labels = torch.zeros(batch_size, 1, device=device)

            # Discriminator step: real samples -> 1, generated samples -> 0.
            disc_optim.zero_grad()
            real_preds = discriminator(real_images)
            real_loss = criterion(real_preds, ones_labels)

            noise = torch.randn(batch_size, z_dim, device=device)
            fake_images = generator(noise)
            # detach() keeps this step from back-propagating into the generator.
            fake_preds = discriminator(fake_images.detach())
            fake_loss = criterion(fake_preds, zeros_labels)

            disc_loss = (real_loss + fake_loss) / 2
            disc_loss.backward()
            disc_optim.step()

            # Generator step: push the discriminator to score fresh fakes as real.
            gen_optim.zero_grad()
            noise = torch.randn(batch_size, z_dim, device=device)
            fake_images = generator(noise)
            preds = discriminator(fake_images)
            gen_loss = criterion(preds, ones_labels)
            gen_loss.backward()
            gen_optim.step()

        if gen_loss is None:
            raise ValueError("dataloader yielded no batches; nothing to train on")

        print(f"Epoch [{epoch+1}/{num_epochs}], Generator Loss: {gen_loss.item():.4f}, Discriminator Loss: {disc_loss.item():.4f}")

# 主函数
if __name__ == "__main__":
    # Hyper-parameters and data loading.
    z_dim = 100
    batch_size = 64
    num_epochs = 50
    lr = 0.0002

    # Normalize with mean 0.5 / std 0.5 so real pixels are rescaled to the
    # same [-1, 1] range as the generator's Tanh output.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    dataset = MNIST(root="data", transform=transform, download=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Build the two networks.
    generator = Generator(z_dim=z_dim)
    discriminator = Discriminator()

    # Train the GAN.
    train_gan(generator, discriminator, dataloader, num_epochs=num_epochs, z_dim=z_dim, lr=lr)

    # Generate and show a 4x4 grid of samples. train_gan moves the generator
    # to the GPU when one is available, so the noise has to be created on
    # whatever device the generator's parameters now live on.
    num_samples = 16
    device = next(generator.parameters()).device
    noise = torch.randn(num_samples, z_dim, device=device)
    with torch.no_grad():  # sampling only, no gradients needed
        generated_images = generator(noise).cpu()
    plt.figure(figsize=(8, 8))
    for i in range(num_samples):
        plt.subplot(4, 4, i + 1)
        plt.imshow(generated_images[i].view(28, 28), cmap='gray')
        plt.axis('off')
    plt.show()

Conditional GAN (CGAN)---时间序列预测

生成对抗网络(GAN)通常用于生成静态数据,例如图像、文本等。然而,要将GAN应用于时间序列预测,则需要对GAN进行适当的修改。在这里,我将向你介绍一个基于GAN的时间序列预测方法——Conditional GAN (CGAN)。

Conditional GAN (CGAN) 是GAN的扩展,它在生成器和判别器的输入中加入条件信息,使得生成器可以生成与给定条件相关的时间序列数据。在时间序列预测任务中,我们将使用历史时间序列数据作为条件信息来预测未来的时间序列值。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# 准备合成时间序列数据
def generate_time_series(num_samples, num_timesteps):
    """Create a batch of noisy synthetic series built from two sine waves.

    Each series is evaluated on ``num_timesteps`` evenly spaced points of
    [0, 1] and is the sum of a 0.5-amplitude sine, a 0.2-amplitude sine
    (each with its own random frequency and offset per series) and uniform
    noise in [-0.05, 0.05).

    Returns:
        A float32 array of shape ``(num_samples, num_timesteps, 1)``.
    """
    # One (num_samples, 1) column per random parameter, so each broadcasts
    # against the time axis.
    freq_a, freq_b, shift_a, shift_b = np.random.rand(4, num_samples, 1)
    t = np.linspace(0, 1, num_timesteps)

    slow_wave = 0.5 * np.sin((t - shift_a) * (freq_a * 10 + 10))
    fast_wave = 0.2 * np.sin((t - shift_b) * (freq_b * 20 + 20))
    jitter = 0.1 * (np.random.rand(num_samples, num_timesteps) - 0.5)

    combined = slow_wave + fast_wave + jitter
    return combined[..., np.newaxis].astype(np.float32)

# 数据预处理
def prepare_data(data, seq_length):
    """Cut every series into sliding input windows and next-step targets.

    Args:
        data: array of shape ``(num_samples, num_timesteps, num_features)``.
        seq_length: number of time steps in each input window.

    Returns:
        A tuple ``(X, y)``. With ``W = num_timesteps - seq_length`` windows,
        ``X`` has shape ``(W, num_samples, seq_length, num_features)`` and
        ``y`` has shape ``(W, num_samples, num_features)``; the leading axis
        is the window start step, and ``y[i]`` is the step right after
        window ``X[i]``. When ``W <= 0`` both are empty arrays of shape
        ``(0,)``.
    """
    # Unpacking three values also rejects input that is not 3-D.
    _, num_timesteps, _ = data.shape
    starts = range(num_timesteps - seq_length)

    windows = [data[:, s:s + seq_length, :] for s in starts]
    targets = [data[:, s + seq_length, :] for s in starts]
    return np.array(windows), np.array(targets)

# 生成器和判别器的定义
class Generator(nn.Module):
    """Conditional generator: maps noise plus a condition vector to a sample.

    The network is ``input_dim -> hidden_dim -> 2 * hidden_dim -> output_dim``
    with ReLU after the first two linear layers and no output activation.
    ``input_dim`` must equal the combined width of the noise and condition
    tensors, because they are concatenated along dim 1 before the first layer.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=64):
        super().__init__()
        wide_dim = hidden_dim * 2
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, wide_dim),
            nn.ReLU(),
            nn.Linear(wide_dim, output_dim),
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, noise, condition):
        """Concatenate ``noise`` and ``condition`` on dim 1 and run the network."""
        joint = torch.cat([noise, condition], dim=1)
        return self.gen(joint)

class Discriminator(nn.Module):
    """Conditional discriminator: scores a sequence given its condition.

    The network is ``input_dim -> 2 * hidden_dim -> hidden_dim -> 1`` with
    ReLU after the first two linear layers and a final Sigmoid, so the score
    lies in [0, 1]. ``input_dim`` must equal the combined width of the
    sequence and condition tensors, because they are concatenated along
    dim 1 before the first layer.
    """

    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        wide_dim = hidden_dim * 2
        layers = [
            nn.Linear(input_dim, wide_dim),
            nn.ReLU(),
            nn.Linear(wide_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*layers)

    def forward(self, sequence, condition):
        """Concatenate ``sequence`` and ``condition`` on dim 1 and score them."""
        joint = torch.cat([sequence, condition], dim=1)
        return self.disc(joint)

# 训练CGAN
def train_cgan(generator, discriminator, data, num_epochs=1000, batch_size=64, seq_length=10, noise_dim=20, lr=0.001):
    """Train a conditional GAN that generates the window following a given window.

    ``data`` holds stride-1 sliding windows indexed by start step on axis 0,
    so ``data[i + seq_length]`` is the window that immediately follows
    ``data[i]``. Every iteration samples ``batch_size`` such pairs, uses the
    earlier window as the condition and the later window as the real sample,
    then performs one discriminator update and one generator update.

    Both networks are moved in place to CUDA when it is available. Losses are
    printed every 100 iterations.

    Args:
        generator: module called as ``generator(noise, condition)`` with 2-D
            tensors; its output width must equal the flattened window size.
        discriminator: module called as ``discriminator(sequence, condition)``
            with 2-D tensors, returning ``(batch, 1)`` probabilities in
            [0, 1] as required by ``nn.BCELoss``.
        data: array-like of windows. A 4-D array is read as
            ``(num_windows, num_series, seq_length, num_features)`` (the
            layout of ``X`` from ``prepare_data``) and a random series is
            drawn for every pair. Any other layout is read as
            ``(num_windows, ...)`` for a single series. Each selected window
            is flattened to one vector.
        num_epochs: number of training iterations (one random batch each).
        batch_size: number of window pairs per iteration.
        seq_length: window length in time steps; also the offset between a
            condition window and its target window.
        noise_dim: size of the noise vector fed to the generator.
        lr: learning rate used by both Adam optimizers.

    Raises:
        ValueError: if ``data`` does not contain more than ``seq_length``
            windows along axis 0.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    discriminator.to(device)

    # The noise is float32, so the windows must be float32 as well.
    windows = np.asarray(data, dtype=np.float32)
    if windows.ndim == 0 or windows.shape[0] <= seq_length:
        raise ValueError("data must contain more than seq_length windows along axis 0")
    has_series_axis = windows.ndim == 4
    last_start = windows.shape[0] - seq_length  # exclusive upper bound

    gen_optim = optim.Adam(generator.parameters(), lr=lr)
    disc_optim = optim.Adam(discriminator.parameters(), lr=lr)
    criterion = nn.BCELoss()

    # The batch size is fixed, so the targets can be built once.
    real_labels = torch.ones(batch_size, 1, device=device)
    fake_labels = torch.zeros(batch_size, 1, device=device)

    for epoch in range(num_epochs):
        # Sample (past window, following window) pairs.
        start = np.random.randint(0, last_start, batch_size)
        if has_series_axis:
            series = np.random.randint(0, windows.shape[1], batch_size)
            past = windows[start, series]
            future = windows[start + seq_length, series]
        else:
            past = windows[start]
            future = windows[start + seq_length]
        # Flatten each window so it can be concatenated with the 2-D noise.
        condition = torch.from_numpy(past).reshape(batch_size, -1).to(device)
        real_data = torch.from_numpy(future).reshape(batch_size, -1).to(device)

        # Discriminator step: real futures -> 1, generated futures -> 0.
        disc_optim.zero_grad()
        noise = torch.randn(batch_size, noise_dim, device=device)
        # detach() keeps this step from back-propagating into the generator.
        fake_data = generator(noise, condition).detach()
        disc_real_preds = discriminator(real_data, condition)
        disc_fake_preds = discriminator(fake_data, condition)

        disc_real_loss = criterion(disc_real_preds, real_labels)
        disc_fake_loss = criterion(disc_fake_preds, fake_labels)
        disc_loss = (disc_real_loss + disc_fake_loss) / 2
        disc_loss.backward()
        disc_optim.step()

        # Generator step: push the discriminator to score fresh fakes as real.
        gen_optim.zero_grad()
        noise = torch.randn(batch_size, noise_dim, device=device)
        fake_data = generator(noise, condition)
        gen_disc_preds = discriminator(fake_data, condition)

        gen_loss = criterion(gen_disc_preds, real_labels)
        gen_loss.backward()
        gen_optim.step()

        if epoch % 100 == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Generator Loss: {gen_loss.item():.4f}, Discriminator Loss: {disc_loss.item():.4f}")

# Script entry point
if __name__ == "__main__":
    # Hyper-parameters and the synthetic time-series data set.
    num_samples = 1000
    num_timesteps = 50
    seq_length = 10
    noise_dim = 20
    lr = 0.001
    num_epochs = 2000

    data = generate_time_series(num_samples, num_timesteps)
    X, _ = prepare_data(data, seq_length)

    # The generator sees noise plus one flattened past window and emits one
    # flattened future window; the discriminator sees a (future, past) pair.
    num_features = X.shape[-1]
    window_dim = seq_length * num_features
    generator = Generator(noise_dim + window_dim, window_dim)
    discriminator = Discriminator(window_dim + window_dim)

    # Train the CGAN.
    train_cgan(generator, discriminator, X, num_epochs=num_epochs, batch_size=64, seq_length=seq_length, noise_dim=noise_dim, lr=lr)

    # Predict the steps that follow the last window of a few series.
    # train_cgan may have moved the generator to the GPU, so inputs are
    # created on whatever device its parameters now live on.
    num_predictions = 5
    device = next(generator.parameters()).device
    past = data[-num_predictions:, -seq_length:, :]
    noise = torch.randn(num_predictions, noise_dim, device=device)
    condition = torch.tensor(past).reshape(num_predictions, -1).to(device)
    with torch.no_grad():  # inference only, no gradients needed
        generated_data = generator(noise, condition).cpu().numpy()
    generated_data = generated_data.reshape(num_predictions, seq_length, num_features)

    # Plot each observed window followed by its generated continuation.
    past_steps = range(num_timesteps - seq_length, num_timesteps)
    future_steps = range(num_timesteps, num_timesteps + seq_length)
    plt.figure(figsize=(10, 6))
    for i in range(num_predictions):
        plt.plot(past_steps, past[i, :, 0], 'b-')
        plt.plot(future_steps, generated_data[i, :, 0], 'r--')
    plt.xlabel('Time')
    plt.ylabel('Value')
    plt.legend(['Past', 'Generated Future'])
    plt.show()
相关推荐
NAGNIP5 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab7 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab7 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP10 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年10 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼11 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS11 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区12 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈12 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang13 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx