import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.datasets import load_iris
import warnings
忽略不必要的警告信息
warnings.filterwarnings("ignore")
--------------------------
1. 配置训练参数与设备
--------------------------
潜在空间维度(生成器的输入维度)
latent_dim = 10
训练总轮数(GAN通常需要较多迭代才能收敛)
train_epochs = 10000
批次大小(根据数据集规模调整)
batch_size = 32
学习率(控制参数更新幅度)
learning_rate = 0.0002
Adam优化器的动量参数(影响收敛稳定性)
beta1 = 0.5
自动选择运算设备(优先GPU,没有则用CPU)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"当前使用设备: {device}")
--------------------------
2. 数据加载与预处理
--------------------------
加载鸢尾花数据集
iris_dataset = load_iris()
提取特征数据和标签
features = iris_dataset.data
labels = iris_dataset.target
只选取Setosa类别(标签为0)的数据进行训练
setosa_features = features[labels == 0]
将数据缩放到[-1, 1]区间(配合生成器的Tanh输出激活)
scaler = MinMaxScaler(feature_range=(-1, 1))
scaled_features = scaler.fit_transform(setosa_features)
转换为PyTorch张量并创建数据加载器
注意:必须转为float类型才能与模型参数兼容
data_tensor = torch.from_numpy(scaled_features).float()
dataset = TensorDataset(data_tensor)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
打印数据基本信息
print(f"训练样本数量: {len(scaled_features)}")
print(f"特征维度: {scaled_features.shape[1]}") # 鸢尾花数据集固定为4维特征
--------------------------
3. 定义生成器和判别器
--------------------------
class Generator(nn.Module):
"""生成器:将随机噪声转换为模拟的鸢尾花特征数据"""
def init(self):
super(Generator, self).init()
简单的全连接网络结构
self.net = nn.Sequential(
nn.Linear(latent_dim, 16), # 从潜在空间映射到16维
nn.ReLU(), # 激活函数增加非线性
nn.Linear(16, 32), # 进一步映射到32维
nn.ReLU(),
nn.Linear(32, 4), # 输出4维特征(与真实数据一致)
nn.Tanh() # 确保输出在[-1, 1]范围内
)
def forward(self, x):
前向传播:输入噪声,输出生成的数据
return self.net(x)
class Discriminator(nn.Module):
"""判别器:区分输入数据是真实样本还是生成器伪造的"""
def init(self):
super(Discriminator, self).init()
简单的全连接网络结构
self.net = nn.Sequential(
nn.Linear(4, 32), # 输入4维特征
nn.LeakyReLU(0.2), # LeakyReLU避免梯度消失问题
nn.Linear(32, 16), # 压缩到16维
nn.LeakyReLU(0.2),
nn.Linear(16, 1), # 输出单个概率值
nn.Sigmoid() # 将输出压缩到[0,1](表示真实数据的概率)
)
def forward(self, x):
前向传播:输入数据,输出判断概率
return self.net(x)
初始化模型并移动到运算设备
generator = Generator().to(device)
discriminator = Discriminator().to(device)
打印模型结构
print("\n生成器结构:")
print(generator)
print("\n判别器结构:")
print(discriminator)
--------------------------
4. 配置训练组件
--------------------------
定义损失函数(二元交叉熵,适合二分类问题)
criterion = nn.BCELoss()
定义优化器(分别优化生成器和判别器)
gen_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
dis_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))
--------------------------
5. 开始训练
--------------------------
print("\n--- 训练开始 ---")
for epoch in range(train_epochs):
遍历数据加载器中的每一批次
for batch_idx, (real_data,) in enumerate(data_loader):
将真实数据移动到运算设备
real_data = real_data.to(device)
current_batch_size = real_data.size(0) # 获取当前批次的实际样本数(最后一批可能不满)
创建标签:真实数据标为1,生成数据标为0
real_labels = torch.ones(current_batch_size, 1).to(device)
fake_labels = torch.zeros(current_batch_size, 1).to(device)
--------------------
训练判别器
--------------------
dis_optimizer.zero_grad() # 清空判别器的梯度缓存
1. 用真实数据训练
real_output = discriminator(real_data)
计算真实数据的损失(希望判别器能认出真实数据)
loss_real = criterion(real_output, real_labels)
2. 用生成的数据训练
生成随机噪声(作为生成器的输入)
noise = torch.randn(current_batch_size, latent_dim).to(device)
生成假数据,并阻断梯度流向生成器(避免影响生成器参数)
fake_data = generator(noise).detach()
fake_output = discriminator(fake_data)
计算假数据的损失(希望判别器能认出假数据)
loss_fake = criterion(fake_output, fake_labels)
总损失反向传播并更新判别器参数
dis_loss = loss_real + loss_fake
dis_loss.backward()
dis_optimizer.step()
--------------------
训练生成器
--------------------
gen_optimizer.zero_grad() # 清空生成器的梯度缓存
重新生成假数据(这次需要计算生成器的梯度)
noise = torch.randn(current_batch_size, latent_dim).to(device)
fake_data = generator(noise)
fake_output = discriminator(fake_data)
生成器的损失:希望判别器把假数据当成真的(所以标签用real_labels)
gen_loss = criterion(fake_output, real_labels)
gen_loss.backward()
gen_optimizer.step()
每1000轮打印一次训练状态
if (epoch + 1) % 1000 == 0:
print(
f"轮次 [{epoch+1}/{train_epochs}], "
f"判别器损失: {dis_loss.item():.4f}, "
f"生成器损失: {gen_loss.item():.4f}"
)
print("\n--- 训练完成 ---")