生成对抗网络GAN(MNIST实现、时间序列实现)

生成对抗网络

生成对抗网络介绍

生成对抗网络(Generative Adversarial Network,简称GAN)是一种深度学习模型,由Ian Goodfellow等人于2014年提出。它由两个主要组件组成:生成器(Generator)和判别器(Discriminator)。GAN的目标是通过两个网络之间的对抗学习来生成逼真的数据。

  1. 生成器(Generator): 生成器是一个神经网络,它接收一个随机噪声向量作为输入,并试图将这个随机噪声转换为逼真的数据样本。在训练过程中,生成器不断试图提高生成样本的质量,使其能够欺骗判别器。初始阶段生成的样本可能不够真实,但随着训练的进行,生成器逐渐学会生成更加逼真的数据样本。
  2. 判别器(Discriminator): 判别器也是一个神经网络,它的任务是区分真实数据样本和由生成器生成的假样本。它类似于一个二分类器,努力将输入样本分为"真实"和"假"的两个类别。在训练过程中,判别器通过不断学习区分真实样本和生成样本,使得判别器的准确率不断提高。

GAN的训练过程是一个对抗过程:

  1. 生成器将随机噪声转换为生成样本,并将这些生成样本传递给判别器。
  2. 判别器根据传递给它的真实样本和生成样本对其进行分类,并输出相应的概率分数。
  3. 根据判别器的输出,生成器试图生成能够欺骗判别器的更逼真的样本。
  4. 这个过程不断重复,直到生成器生成的样本足够逼真,判别器无法准确区分真假样本。

通过这种对抗学习的过程,GAN能够生成高质量的数据样本,广泛应用于图像、音频、文本等领域。然而,训练GAN也存在一些挑战,如训练不稳定、模式崩溃等问题,需要经验丰富的研究人员进行调优和改进。

MNIST---GAN

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

# 定义生成器和判别器的类
class Generator(nn.Module):
    """Fully connected generator that maps a noise vector to a flat sample.

    The stack is ``z_dim -> hidden_dim -> 2 * hidden_dim -> output_dim`` with
    LeakyReLU(0.01) after the first two linear layers and a final Tanh, so
    every output value lies in [-1, 1].
    """

    def __init__(self, z_dim=100, hidden_dim=128, output_dim=784):
        super().__init__()
        wide_dim = hidden_dim * 2
        stack = [
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_dim, wide_dim),
            nn.LeakyReLU(0.01),
            nn.Linear(wide_dim, output_dim),
            nn.Tanh(),
        ]
        self.gen = nn.Sequential(*stack)

    def forward(self, noise):
        """Return generated samples for a batch of noise vectors."""
        samples = self.gen(noise)
        return samples

class Discriminator(nn.Module):
    """Fully connected binary classifier scoring flat samples as real or fake.

    The stack is ``input_dim -> 2 * hidden_dim -> hidden_dim -> 1`` with
    LeakyReLU(0.01) after the first two linear layers and a final Sigmoid,
    so each sample gets a single score in [0, 1].
    """

    def __init__(self, input_dim=784, hidden_dim=128):
        super().__init__()
        wide_dim = hidden_dim * 2
        stack = [
            nn.Linear(input_dim, wide_dim),
            nn.LeakyReLU(0.01),
            nn.Linear(wide_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*stack)

    def forward(self, image):
        """Return one score per row of ``image``, shape (batch, 1)."""
        scores = self.disc(image)
        return scores

# 定义训练函数
def train_gan(generator, discriminator, dataloader, num_epochs=50, z_dim=100, lr=0.0002):
    """Train a generator/discriminator pair with the binary cross-entropy GAN loss.

    Both models are moved in place to CUDA when it is available (CPU
    otherwise) and remain on that device after the function returns.

    Args:
        generator: Module called as ``generator(noise)`` with noise of shape
            ``(batch, z_dim)``; its output is fed to the discriminator.
        discriminator: Module returning probabilities of shape ``(batch, 1)``.
        dataloader: Iterable yielding ``(images, labels)`` batches. Labels are
            ignored; images are flattened to ``(batch, -1)``.
        num_epochs: Number of passes over ``dataloader``.
        z_dim: Size of the noise vector given to the generator.
        lr: Learning rate for both Adam optimizers.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    discriminator.to(device)

    gen_optim = optim.Adam(generator.parameters(), lr=lr)
    disc_optim = optim.Adam(discriminator.parameters(), lr=lr)
    criterion = nn.BCELoss()

    for epoch in range(num_epochs):
        # Reset per epoch so an empty dataloader is detected instead of
        # failing on an undefined loss in the report below.
        gen_loss = disc_loss = None

        for real_images, _ in dataloader:
            batch_size = real_images.size(0)
            real_images = real_images.view(batch_size, -1).to(device)
            ones_labels = torch.ones(batch_size, 1, device=device)
            zeros_labels = torch.zeros(batch_size, 1, device=device)

            # Discriminator step: real samples -> 1, generated samples -> 0.
            disc_optim.zero_grad()
            real_preds = discriminator(real_images)
            real_loss = criterion(real_preds, ones_labels)

            noise = torch.randn(batch_size, z_dim, device=device)
            fake_images = generator(noise)
            # detach() keeps this step from back-propagating into the generator.
            fake_preds = discriminator(fake_images.detach())
            fake_loss = criterion(fake_preds, zeros_labels)

            disc_loss = (real_loss + fake_loss) / 2
            disc_loss.backward()
            disc_optim.step()

            # Generator step: push the discriminator towards labelling fakes as real.
            gen_optim.zero_grad()
            noise = torch.randn(batch_size, z_dim, device=device)
            fake_images = generator(noise)
            preds = discriminator(fake_images)
            gen_loss = criterion(preds, ones_labels)
            gen_loss.backward()
            gen_optim.step()

        if gen_loss is None:
            print(f"Epoch [{epoch+1}/{num_epochs}], no batches in dataloader")
            continue

        # The reported losses are those of the last batch of the epoch.
        print(f"Epoch [{epoch+1}/{num_epochs}], Generator Loss: {gen_loss.item():.4f}, Discriminator Loss: {disc_loss.item():.4f}")

# 主函数
if __name__ == "__main__":
    # Hyperparameters and data loading
    z_dim = 100
    batch_size = 64
    num_epochs = 50
    lr = 0.0002

    # Normalize pixels to [-1, 1] to match the generator's Tanh output range.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    dataset = MNIST(root="data", transform=transform, download=True)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Build the generator and discriminator
    generator = Generator(z_dim=z_dim)
    discriminator = Discriminator()

    # Train the GAN
    train_gan(generator, discriminator, dataloader, num_epochs=num_epochs, z_dim=z_dim, lr=lr)

    # Generate and display a few samples. train_gan moves the generator to
    # CUDA when available, so the noise must be created on the generator's
    # device to avoid a CPU/GPU mismatch.
    num_samples = 16
    device = next(generator.parameters()).device
    with torch.no_grad():
        noise = torch.randn(num_samples, z_dim, device=device)
        generated_images = generator(noise).cpu()
    plt.figure(figsize=(8, 8))
    for i in range(num_samples):
        plt.subplot(4, 4, i + 1)
        plt.imshow(generated_images[i].view(28, 28), cmap='gray')
        plt.axis('off')
    plt.show()

Conditional GAN (CGAN)---时间序列预测

生成对抗网络(GAN)通常用于生成静态数据,例如图像、文本等。然而,要将GAN应用于时间序列预测,则需要对GAN进行适当的修改。在这里,我将向你介绍一个基于GAN的时间序列预测方法——Conditional GAN (CGAN)。

Conditional GAN (CGAN) 是GAN的扩展,它在生成器和判别器的输入中加入条件信息,使得生成器可以生成与给定条件相关的时间序列数据。在时间序列预测任务中,我们将使用历史时间序列数据作为条件信息来预测未来的时间序列值。

python 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

# 准备合成时间序列数据
def generate_time_series(num_samples, num_timesteps):
    """Return random two-sinusoid series with uniform noise.

    Each series is sampled on ``num_timesteps`` evenly spaced points in
    [0, 1] and is the sum of a 0.5-amplitude sine, a 0.2-amplitude sine and
    uniform noise in [-0.05, 0.05). Frequencies and offsets are drawn per
    series from ``np.random.rand``.

    Returns:
        float32 array of shape ``(num_samples, num_timesteps, 1)``.
    """
    draws = np.random.rand(4, num_samples, 1)
    slow_freq, fast_freq, slow_shift, fast_shift = draws
    grid = np.linspace(0, 1, num_timesteps)

    slow_wave = 0.5 * np.sin((grid - slow_shift) * (slow_freq * 10 + 10))
    fast_wave = 0.2 * np.sin((grid - fast_shift) * (fast_freq * 20 + 20))
    jitter = 0.1 * (np.random.rand(num_samples, num_timesteps) - 0.5)

    combined = slow_wave + fast_wave + jitter
    return combined[..., np.newaxis].astype(np.float32)

# 数据预处理
def prepare_data(data, seq_length):
    """Slice series into sliding windows and their next-step targets.

    Args:
        data: Array of shape ``(num_samples, num_timesteps, num_features)``.
        seq_length: Number of time steps per window.

    Returns:
        Tuple ``(X, y)``. ``X[i]`` is the window covering steps
        ``i .. i + seq_length - 1`` of every series, giving shape
        ``(num_windows, num_samples, seq_length, num_features)``. ``y[i]`` is
        step ``i + seq_length`` of every series, giving shape
        ``(num_windows, num_samples, num_features)``. Here
        ``num_windows = num_timesteps - seq_length``.
    """
    _, num_timesteps, _ = data.shape
    starts = range(num_timesteps - seq_length)
    windows = np.array([data[:, s:s + seq_length, :] for s in starts])
    targets = np.array([data[:, s + seq_length, :] for s in starts])
    return windows, targets

# 生成器和判别器的定义
class Generator(nn.Module):
    """Conditional generator: maps noise plus a condition vector to a sample.

    ``input_dim`` is the size of the concatenated ``[noise, condition]``
    vector. The stack is ``input_dim -> hidden_dim -> 2 * hidden_dim ->
    output_dim`` with ReLU after the first two linear layers and no output
    activation.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=64):
        super().__init__()
        wide_dim = hidden_dim * 2
        stack = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, wide_dim),
            nn.ReLU(),
            nn.Linear(wide_dim, output_dim),
        ]
        self.gen = nn.Sequential(*stack)

    def forward(self, noise, condition):
        """Concatenate ``noise`` and ``condition`` along dim 1 and run the network."""
        joined = torch.cat((noise, condition), dim=1)
        return self.gen(joined)

class Discriminator(nn.Module):
    """Conditional discriminator scoring a (sequence, condition) pair.

    ``input_dim`` is the size of the concatenated ``[sequence, condition]``
    vector. The stack is ``input_dim -> 2 * hidden_dim -> hidden_dim -> 1``
    with ReLU after the first two linear layers and a final Sigmoid.
    """

    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        wide_dim = hidden_dim * 2
        stack = [
            nn.Linear(input_dim, wide_dim),
            nn.ReLU(),
            nn.Linear(wide_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        ]
        self.disc = nn.Sequential(*stack)

    def forward(self, sequence, condition):
        """Concatenate ``sequence`` and ``condition`` along dim 1 and score the pair."""
        joined = torch.cat((sequence, condition), dim=1)
        return self.disc(joined)

# 训练CGAN
def train_cgan(generator, discriminator, data, num_epochs=1000, batch_size=64, seq_length=10, noise_dim=20, lr=0.001):
    """Train a conditional GAN that generates the window following a given one.

    Every step draws ``batch_size`` (history, future) pairs. ``history`` is
    the window at a random index ``i`` along axis 0 of ``data`` and ``future``
    is the window at ``i + seq_length``. The history is the condition and the
    future is the real sample. Both are flattened to 2-D ``(batch_size, -1)``
    float32 tensors before they reach the models.

    Both models are moved in place to CUDA when available (CPU otherwise).

    Args:
        generator: Called as ``generator(noise, condition)`` with noise of
            shape ``(batch_size, noise_dim)``. Its output must have the same
            shape as a flattened window.
        discriminator: Called as ``discriminator(sequence, condition)``; must
            return probabilities of shape ``(batch_size, 1)``.
        data: Array of windows indexed by start step along axis 0. A 4-D
            array is read as ``(num_windows, num_series, seq_length,
            num_features)`` (the layout returned by ``prepare_data``) and a
            random series is drawn for every pair. For any other rank,
            ``data[i]`` itself is taken as one window.
        num_epochs: Number of training steps; each step uses one random batch.
        batch_size: Number of (history, future) pairs per step.
        seq_length: Offset, in windows along axis 0, between history and future.
        noise_dim: Size of the noise vector given to the generator.
        lr: Learning rate for both Adam optimizers.

    Raises:
        ValueError: If ``data`` does not have more than ``seq_length`` windows.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator.to(device)
    discriminator.to(device)

    gen_optim = optim.Adam(generator.parameters(), lr=lr)
    disc_optim = optim.Adam(discriminator.parameters(), lr=lr)
    criterion = nn.BCELoss()

    data = np.asarray(data)
    num_starts = data.shape[0] - seq_length
    if num_starts <= 0:
        raise ValueError(
            f"data has {data.shape[0]} windows; more than seq_length={seq_length} are required"
        )

    # Labels are identical for every step, so build them once.
    real_labels = torch.ones(batch_size, 1, device=device)
    fake_labels = torch.zeros(batch_size, 1, device=device)

    for epoch in range(num_epochs):
        start = np.random.randint(0, num_starts, batch_size)
        if data.ndim == 4:
            # Pair each start index with a random series so history and
            # future always come from the same series.
            series = np.random.randint(0, data.shape[1], batch_size)
            history = data[start, series]
            future = data[start + seq_length, series]
        else:
            history = data[start]
            future = data[start + seq_length]

        # The past is the condition; the future is what the generator must imitate.
        condition = torch.tensor(history, dtype=torch.float32).reshape(batch_size, -1).to(device)
        real_data = torch.tensor(future, dtype=torch.float32).reshape(batch_size, -1).to(device)

        # Discriminator step
        disc_optim.zero_grad()
        noise = torch.randn(batch_size, noise_dim, device=device)
        fake_data = generator(noise, condition)
        disc_real_preds = discriminator(real_data, condition)
        # detach() keeps this step from back-propagating into the generator.
        disc_fake_preds = discriminator(fake_data.detach(), condition)

        disc_real_loss = criterion(disc_real_preds, real_labels)
        disc_fake_loss = criterion(disc_fake_preds, fake_labels)
        disc_loss = (disc_real_loss + disc_fake_loss) / 2
        disc_loss.backward()
        disc_optim.step()

        # Generator step: push the discriminator towards labelling fakes as real.
        gen_optim.zero_grad()
        noise = torch.randn(batch_size, noise_dim, device=device)
        fake_data = generator(noise, condition)
        gen_disc_preds = discriminator(fake_data, condition)

        gen_loss = criterion(gen_disc_preds, real_labels)
        gen_loss.backward()
        gen_optim.step()

        if epoch % 100 == 0:
            print(f"Epoch [{epoch}/{num_epochs}], Generator Loss: {gen_loss.item():.4f}, Discriminator Loss: {disc_loss.item():.4f}")

# Entry point
if __name__ == "__main__":
    # Hyperparameters and synthetic time-series dataset
    num_samples = 1000
    num_timesteps = 50
    seq_length = 10
    noise_dim = 20
    lr = 0.001
    num_epochs = 2000

    data = generate_time_series(num_samples, num_timesteps)
    # X has shape (num_windows, num_samples, seq_length, num_features).
    X, _ = prepare_data(data, seq_length)

    # The models work on flattened windows: the generator sees
    # [noise, history] and emits one future window; the discriminator sees
    # [window, history].
    num_features = data.shape[-1]
    window_dim = seq_length * num_features
    generator = Generator(noise_dim + window_dim, window_dim)
    discriminator = Discriminator(2 * window_dim)

    # Train the CGAN
    train_cgan(generator, discriminator, X, num_epochs=num_epochs, batch_size=64, seq_length=seq_length, noise_dim=noise_dim, lr=lr)

    # Predict the next window for the last few series, conditioning on
    # their final seq_length steps. train_cgan may have moved the generator
    # to CUDA, so inputs are created on the generator's device.
    num_predictions = 5
    device = next(generator.parameters()).device
    history = data[-num_predictions:, -seq_length:, :]
    noise = torch.randn(num_predictions, noise_dim, device=device)
    condition = torch.tensor(history).reshape(num_predictions, -1).to(device)
    with torch.no_grad():
        generated = generator(noise, condition)
    generated_data = generated.cpu().numpy().reshape(num_predictions, seq_length, num_features)

    # Plot the known past and the generated continuation
    plt.figure(figsize=(10, 6))
    for i in range(num_predictions):
        plt.plot(range(num_timesteps - seq_length, num_timesteps), history[i, :, 0], 'b-')
        plt.plot(range(num_timesteps, num_timesteps + seq_length), generated_data[i, :, 0], 'r--')
    plt.xlabel('Time')
    plt.ylabel('Value')
    plt.legend(['Past', 'Generated Future'])
    plt.show()
相关推荐
AI技术控8 分钟前
计算机视觉算法实战——驾驶员安全带检测
人工智能·算法·计算机视觉
LucianaiB9 分钟前
基于自然语言处理的垃圾短信识别系统
人工智能·自然语言处理·垃圾短信识别系统
feifeikon1 小时前
大模型GUI系列论文阅读 DAY4续:《Large Language Model Agent for Fake News Detection》
论文阅读·人工智能·语言模型
feifeikon1 小时前
图神经网络系列论文阅读DAY1:《Predicting Tweet Engagement with Graph Neural Networks》
论文阅读·人工智能·神经网络
ZStack开发者社区3 小时前
AI应用、轻量云、虚拟化|云轴科技ZStack参编金融行标与报告
人工智能·科技·金融
存内计算开发者4 小时前
机器人奇点:从宇树科技看2025具身智能发展
深度学习·神经网络·机器学习·计算机视觉·机器人·视觉检测·具身智能
真想骂*5 小时前
人工智能如何重塑音频、视觉及多模态领域的应用格局
人工智能·音视频
赛丽曼7 小时前
机器学习-K近邻算法
人工智能·机器学习·近邻算法
啊波次得饿佛哥9 小时前
7. 计算机视觉
人工智能·计算机视觉·视觉检测
XianxinMao9 小时前
RLHF技术应用探析:从安全任务到高阶能力提升
人工智能·python·算法