pytorch生成对抗网络

人工智能例子汇总:AI常见的算法和例子-CSDN博客

生成对抗网络(GAN,Generative Adversarial Network)是一种深度学习模型,由两个神经网络组成:生成器(Generator)和判别器(Discriminator)。这两个网络通过对抗过程共同训练,从而使生成器能够生成越来越真实的假数据。

GAN的基本工作原理:

  1. 生成器(G):它的任务是生成与真实数据相似的假数据。生成器通常从一个随机噪声(例如,均匀分布或高斯分布的噪声)开始,经过多层神经网络的处理,输出伪造的数据样本。

  2. 判别器(D):它的任务是区分输入数据是来自真实数据分布,还是生成器伪造的假数据。判别器通常是一个二分类器,其输出是一个表示"真实"或"假"的概率值。

训练过程:

  • 对抗过程:生成器和判别器相互博弈。生成器希望生成尽可能像真的数据,以骗过判别器;而判别器希望准确区分真假数据。最终,生成器会通过优化损失函数,使得生成的数据与真实数据尽可能相似,判别器的性能则被提升到一个极限,使得它不能再轻易地区分真假数据。

数学公式:

  • 判别器的目标是最大化其输出的正确分类概率,即区分真假数据。
  • 生成器的目标是最小化其输出的"假数据"被判定为假的概率。

常见的GAN变种:

  1. DCGAN(Deep Convolutional GAN):使用卷积神经网络(CNN)来增强生成器和判别器的表现。
  2. WGAN(Wasserstein GAN):引入了Wasserstein距离,改进了训练稳定性。
  3. CycleGAN:能够在没有成对样本的情况下进行图像到图像的转换,例如将马变成斑马。

以下是一个简化的PyTorch GAN实现的框架,生成一个语音的梅尔频谱(假设已经处理了音频并提取了梅尔频谱特征)

import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import matplotlib.pyplot as plt


# 生成器(Generator)
class Generator(nn.Module):
    def __init__(self, z_dim=100):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 80),  # 80表示梅尔频谱的时间步(例如:80个梅尔频率)
            nn.Tanh()  # 生成梅尔频谱,范围在[-1, 1]之间
        )

    def forward(self, z):
        return self.fc(z)


# 判别器(Discriminator)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(80, 512),  # 输入为梅尔频谱的时间步
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 输出判定是"真"还是"假"
        )

    def forward(self, x):
        return self.fc(x)


# 初始化生成器和判别器
z_dim = 100
generator = Generator(z_dim)
discriminator = Discriminator()

# 优化器
lr = 0.0002
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# 损失函数
criterion = nn.BCELoss()


# 加载数据(假设已经提取了梅尔频谱特征,取一个示例)
def load_example_mel_spectrogram():
    # 假设这是一个真实梅尔频谱的示例,实际数据应从音频文件中提取
    mel = torch.rand((80))  # 生成一个假的梅尔频谱数据
    return mel.unsqueeze(0)  # 扩展维度以适应网络


# 训练GAN
num_epochs = 1000
for epoch in range(num_epochs):
    # 真实数据
    real_data = load_example_mel_spectrogram()
    real_labels = torch.ones(real_data.size(0), 1)  # 标签为1表示真实数据

    # 假数据
    z = torch.randn(real_data.size(0), z_dim)  # 随机噪声
    fake_data = generator(z)
    fake_labels = torch.zeros(real_data.size(0), 1)  # 标签为0表示假数据

    # 训练判别器
    discriminator.zero_grad()
    real_loss = criterion(discriminator(real_data), real_labels)
    fake_loss = criterion(discriminator(fake_data.detach()), fake_labels)
    d_loss = (real_loss + fake_loss) / 2
    d_loss.backward()
    d_optimizer.step()

    # 训练生成器
    generator.zero_grad()
    g_loss = criterion(discriminator(fake_data), real_labels)  # 生成器希望判别器判定为真实
    g_loss.backward()
    g_optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch [{epoch}/{num_epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

    # 可视化生成的梅尔频谱(只显示最后一次生成的结果)
    if epoch == num_epochs - 1:
        plt.figure(figsize=(10, 4))
        plt.imshow(fake_data.detach().numpy(), aspect='auto', origin='lower')
        plt.title(f"Generated Mel Spectrogram - Epoch {epoch}")
        plt.colorbar()
        plt.show()

# 测试阶段:使用训练好的生成器进行语音生成
z_test = torch.randn(1, z_dim)  # 创建一个新的随机噪声向量
generated_mel_spectrogram = generator(z_test)

# 可视化生成的梅尔频谱
plt.figure(figsize=(10, 4))
plt.imshow(generated_mel_spectrogram.detach().numpy(), aspect='auto', origin='lower')
plt.title("Generated Mel Spectrogram from Test Data")
plt.colorbar()
plt.show()

解释:

  1. 测试阶段

    • 在训练完成后,我们使用一个新的随机噪声向量z_test来生成一个新的梅尔频谱。
    • generated_mel_spectrogram = generator(z_test)是生成梅尔频谱的过程。
  2. 可视化

    • 使用plt.imshow()来可视化生成的梅尔频谱图,origin='lower'是确保频谱图正确显示。
    • plt.colorbar()添加颜色条,以便更清晰地理解梅尔频谱的数值范围。

结果:

  • 在训练过程中,你会看到每个epoch的损失值,并在最后一次epoch时显示生成的梅尔频谱。
  • 在测试阶段,生成器会基于随机噪声生成一个新的梅尔频谱并进行可视化,帮助你观察最终模型生成的语音特征。
相关推荐
云空33 分钟前
《DeepSeek 网页/API 性能异常(DeepSeek Web/API Degraded Performance):网络安全日志》
运维·人工智能·web安全·网络安全·开源·网络攻击模型·安全威胁分析
AIGC大时代35 分钟前
对比DeepSeek、ChatGPT和Kimi的学术写作关键词提取能力
论文阅读·人工智能·chatgpt·数据分析·prompt
山晨啊82 小时前
2025年美赛B题-结合Logistic阻滞增长模型和SIR传染病模型研究旅游可持续性-成品论文
人工智能·机器学习
一水鉴天2 小时前
为AI聊天工具添加一个知识系统 之77 详细设计之18 正则表达式 之5
人工智能·正则表达式
davenian2 小时前
DeepSeek-R1 论文. Reinforcement Learning 通过强化学习激励大型语言模型的推理能力
人工智能·深度学习·语言模型·deepseek
X.AI6663 小时前
【大模型LLM面试合集】大语言模型架构_llama系列模型
人工智能·语言模型·llama
CM莫问3 小时前
什么是门控循环单元?
人工智能·pytorch·python·rnn·深度学习·算法·gru
饮马长城窟3 小时前
Paddle和pytorch不可以同时引用
人工智能·pytorch·paddle
机器之心3 小时前
全面梳理200+篇前沿论文,视觉生成模型理解物理世界规律的通关密码,都在这篇综述里了!
人工智能
池佳齐4 小时前
《AI大模型开发笔记》DeepSeek技术创新点
人工智能·笔记