pytorch生成对抗网络

人工智能例子汇总:AI常见的算法和例子-CSDN博客

生成对抗网络(GAN,Generative Adversarial Network)是一种深度学习模型,由两个神经网络组成:生成器(Generator)和判别器(Discriminator)。这两个网络通过对抗过程共同训练,从而使生成器能够生成越来越真实的假数据。

GAN的基本工作原理:

  1. 生成器(G):它的任务是生成与真实数据相似的假数据。生成器通常从一个随机噪声(例如,均匀分布或高斯分布的噪声)开始,经过多层神经网络的处理,输出伪造的数据样本。

  2. 判别器(D):它的任务是区分输入数据是来自真实数据分布,还是生成器伪造的假数据。判别器通常是一个二分类器,其输出是一个表示"真实"或"假"的概率值。

训练过程:

  • 对抗过程:生成器和判别器相互博弈。生成器希望生成尽可能像真的数据,以骗过判别器;而判别器希望准确区分真假数据。最终,生成器会通过优化损失函数,使得生成的数据与真实数据尽可能相似,判别器的性能则被提升到一个极限,使得它不能再轻易地区分真假数据。

数学公式:

  • 判别器的目标是最大化其输出的正确分类概率,即区分真假数据。
  • 生成器的目标是最小化其输出的"假数据"被判定为假的概率。

常见的GAN变种:

  1. DCGAN(Deep Convolutional GAN):使用卷积神经网络(CNN)来增强生成器和判别器的表现。
  2. WGAN(Wasserstein GAN):引入了Wasserstein距离,改进了训练稳定性。
  3. CycleGAN:能够在没有成对样本的情况下进行图像到图像的转换,例如将马变成斑马。

以下是一个简化的PyTorch GAN实现的框架,生成一个语音的梅尔频谱(假设已经处理了音频并提取了梅尔频谱特征)

复制代码
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import matplotlib.pyplot as plt


# 生成器(Generator)
class Generator(nn.Module):
    def __init__(self, z_dim=100):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(z_dim, 128),
            nn.ReLU(),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 80),  # 80表示梅尔频谱的时间步(例如:80个梅尔频率)
            nn.Tanh()  # 生成梅尔频谱,范围在[-1, 1]之间
        )

    def forward(self, z):
        return self.fc(z)


# 判别器(Discriminator)
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(80, 512),  # 输入为梅尔频谱的时间步
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()  # 输出判定是"真"还是"假"
        )

    def forward(self, x):
        return self.fc(x)


# 初始化生成器和判别器
z_dim = 100
generator = Generator(z_dim)
discriminator = Discriminator()

# 优化器
lr = 0.0002
g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))

# 损失函数
criterion = nn.BCELoss()


# 加载数据(假设已经提取了梅尔频谱特征,取一个示例)
def load_example_mel_spectrogram():
    # 假设这是一个真实梅尔频谱的示例,实际数据应从音频文件中提取
    mel = torch.rand((80))  # 生成一个假的梅尔频谱数据
    return mel.unsqueeze(0)  # 扩展维度以适应网络


# 训练GAN
num_epochs = 1000
for epoch in range(num_epochs):
    # 真实数据
    real_data = load_example_mel_spectrogram()
    real_labels = torch.ones(real_data.size(0), 1)  # 标签为1表示真实数据

    # 假数据
    z = torch.randn(real_data.size(0), z_dim)  # 随机噪声
    fake_data = generator(z)
    fake_labels = torch.zeros(real_data.size(0), 1)  # 标签为0表示假数据

    # 训练判别器
    discriminator.zero_grad()
    real_loss = criterion(discriminator(real_data), real_labels)
    fake_loss = criterion(discriminator(fake_data.detach()), fake_labels)
    d_loss = (real_loss + fake_loss) / 2
    d_loss.backward()
    d_optimizer.step()

    # 训练生成器
    generator.zero_grad()
    g_loss = criterion(discriminator(fake_data), real_labels)  # 生成器希望判别器判定为真实
    g_loss.backward()
    g_optimizer.step()

    if epoch % 100 == 0:
        print(f"Epoch [{epoch}/{num_epochs}], D Loss: {d_loss.item()}, G Loss: {g_loss.item()}")

    # 可视化生成的梅尔频谱(只显示最后一次生成的结果)
    if epoch == num_epochs - 1:
        plt.figure(figsize=(10, 4))
        plt.imshow(fake_data.detach().numpy(), aspect='auto', origin='lower')
        plt.title(f"Generated Mel Spectrogram - Epoch {epoch}")
        plt.colorbar()
        plt.show()

# 测试阶段:使用训练好的生成器进行语音生成
z_test = torch.randn(1, z_dim)  # 创建一个新的随机噪声向量
generated_mel_spectrogram = generator(z_test)

# 可视化生成的梅尔频谱
plt.figure(figsize=(10, 4))
plt.imshow(generated_mel_spectrogram.detach().numpy(), aspect='auto', origin='lower')
plt.title("Generated Mel Spectrogram from Test Data")
plt.colorbar()
plt.show()

解释:

  1. 测试阶段

    • 在训练完成后,我们使用一个新的随机噪声向量z_test来生成一个新的梅尔频谱。
    • generated_mel_spectrogram = generator(z_test)是生成梅尔频谱的过程。
  2. 可视化

    • 使用plt.imshow()来可视化生成的梅尔频谱图,origin='lower'是确保频谱图正确显示。
    • plt.colorbar()添加颜色条,以便更清晰地理解梅尔频谱的数值范围。

结果:

  • 在训练过程中,你会看到每个epoch的损失值,并在最后一次epoch时显示生成的梅尔频谱。
  • 在测试阶段,生成器会基于随机噪声生成一个新的梅尔频谱并进行可视化,帮助你观察最终模型生成的语音特征。
相关推荐
G皮T3 小时前
【人工智能】ChatGPT、DeepSeek-R1、DeepSeek-V3 辨析
人工智能·chatgpt·llm·大语言模型·deepseek·deepseek-v3·deepseek-r1
九年义务漏网鲨鱼3 小时前
【大模型学习 | MINIGPT-4原理】
人工智能·深度学习·学习·语言模型·多模态
元宇宙时间3 小时前
Playfun即将开启大型Web3线上活动,打造沉浸式GameFi体验生态
人工智能·去中心化·区块链
开发者工具分享3 小时前
文本音频违规识别工具排行榜(12选)
人工智能·音视频
产品经理独孤虾4 小时前
人工智能大模型如何助力电商产品经理打造高效的商品工业属性画像
人工智能·机器学习·ai·大模型·产品经理·商品画像·商品工业属性
老任与码4 小时前
Spring AI Alibaba(1)——基本使用
java·人工智能·后端·springaialibaba
蹦蹦跳跳真可爱5894 小时前
Python----OpenCV(图像増强——高通滤波(索贝尔算子、沙尔算子、拉普拉斯算子),图像浮雕与特效处理)
人工智能·python·opencv·计算机视觉
雷羿 LexChien4 小时前
从 Prompt 管理到人格稳定:探索 Cursor AI 编辑器如何赋能 Prompt 工程与人格风格设计(上)
人工智能·python·llm·编辑器·prompt
两棵雪松5 小时前
如何通过向量化技术比较两段文本是否相似?
人工智能
heart000_15 小时前
128K 长文本处理实战:腾讯混元 + 云函数 SCF 构建 PDF 摘要生成器
人工智能·自然语言处理·pdf