day53

import torch

import torch.nn as nn

import torch.optim as optim

from torch.utils.data import DataLoader, TensorDataset

import numpy as np

from sklearn.preprocessing import MinMaxScaler

from sklearn.datasets import load_iris

import warnings

忽略不必要的警告信息

warnings.filterwarnings("ignore")

--------------------------

1. 配置训练参数与设备

--------------------------

潜在空间维度(生成器的输入维度)

latent_dim = 10

训练总轮数(GAN通常需要较多迭代才能收敛)

train_epochs = 10000

批次大小(根据数据集规模调整)

batch_size = 32

学习率(控制参数更新幅度)

learning_rate = 0.0002

Adam优化器的动量参数(影响收敛稳定性)

beta1 = 0.5

自动选择运算设备(优先GPU,没有则用CPU)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"当前使用设备: {device}")

--------------------------

2. 数据加载与预处理

--------------------------

加载鸢尾花数据集

iris_dataset = load_iris()

提取特征数据和标签

features = iris_dataset.data

labels = iris_dataset.target

只选取Setosa类别(标签为0)的数据进行训练

setosa_features = features[labels == 0]

将数据缩放到[-1, 1]区间(配合生成器的Tanh输出激活)

scaler = MinMaxScaler(feature_range=(-1, 1))

scaled_features = scaler.fit_transform(setosa_features)

转换为PyTorch张量并创建数据加载器

注意:必须转为float类型才能与模型参数兼容

data_tensor = torch.from_numpy(scaled_features).float()

dataset = TensorDataset(data_tensor)

data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

打印数据基本信息

print(f"训练样本数量: {len(scaled_features)}")

print(f"特征维度: {scaled_features.shape[1]}") # 鸢尾花数据集固定为4维特征

--------------------------

3. 定义生成器和判别器

--------------------------

class Generator(nn.Module):

"""生成器:将随机噪声转换为模拟的鸢尾花特征数据"""

def init(self):

super(Generator, self).init()

简单的全连接网络结构

self.net = nn.Sequential(

nn.Linear(latent_dim, 16), # 从潜在空间映射到16维

nn.ReLU(), # 激活函数增加非线性

nn.Linear(16, 32), # 进一步映射到32维

nn.ReLU(),

nn.Linear(32, 4), # 输出4维特征(与真实数据一致)

nn.Tanh() # 确保输出在[-1, 1]范围内

)

def forward(self, x):

前向传播:输入噪声,输出生成的数据

return self.net(x)

class Discriminator(nn.Module):

"""判别器:区分输入数据是真实样本还是生成器伪造的"""

def init(self):

super(Discriminator, self).init()

简单的全连接网络结构

self.net = nn.Sequential(

nn.Linear(4, 32), # 输入4维特征

nn.LeakyReLU(0.2), # LeakyReLU避免梯度消失问题

nn.Linear(32, 16), # 压缩到16维

nn.LeakyReLU(0.2),

nn.Linear(16, 1), # 输出单个概率值

nn.Sigmoid() # 将输出压缩到[0,1](表示真实数据的概率)

)

def forward(self, x):

前向传播:输入数据,输出判断概率

return self.net(x)

初始化模型并移动到运算设备

generator = Generator().to(device)

discriminator = Discriminator().to(device)

打印模型结构

print("\n生成器结构:")

print(generator)

print("\n判别器结构:")

print(discriminator)

--------------------------

4. 配置训练组件

--------------------------

定义损失函数(二元交叉熵,适合二分类问题)

criterion = nn.BCELoss()

定义优化器(分别优化生成器和判别器)

gen_optimizer = optim.Adam(generator.parameters(), lr=learning_rate, betas=(beta1, 0.999))

dis_optimizer = optim.Adam(discriminator.parameters(), lr=learning_rate, betas=(beta1, 0.999))

--------------------------

5. 开始训练

--------------------------

print("\n--- 训练开始 ---")

for epoch in range(train_epochs):

遍历数据加载器中的每一批次

for batch_idx, (real_data,) in enumerate(data_loader):

将真实数据移动到运算设备

real_data = real_data.to(device)

current_batch_size = real_data.size(0) # 获取当前批次的实际样本数(最后一批可能不满)

创建标签:真实数据标为1,生成数据标为0

real_labels = torch.ones(current_batch_size, 1).to(device)

fake_labels = torch.zeros(current_batch_size, 1).to(device)

--------------------

训练判别器

--------------------

dis_optimizer.zero_grad() # 清空判别器的梯度缓存

1. 用真实数据训练

real_output = discriminator(real_data)

计算真实数据的损失(希望判别器能认出真实数据)

loss_real = criterion(real_output, real_labels)

2. 用生成的数据训练

生成随机噪声(作为生成器的输入)

noise = torch.randn(current_batch_size, latent_dim).to(device)

生成假数据,并阻断梯度流向生成器(避免影响生成器参数)

fake_data = generator(noise).detach()

fake_output = discriminator(fake_data)

计算假数据的损失(希望判别器能认出假数据)

loss_fake = criterion(fake_output, fake_labels)

总损失反向传播并更新判别器参数

dis_loss = loss_real + loss_fake

dis_loss.backward()

dis_optimizer.step()

--------------------

训练生成器

--------------------

gen_optimizer.zero_grad() # 清空生成器的梯度缓存

重新生成假数据(这次需要计算生成器的梯度)

noise = torch.randn(current_batch_size, latent_dim).to(device)

fake_data = generator(noise)

fake_output = discriminator(fake_data)

生成器的损失:希望判别器把假数据当成真的(所以标签用real_labels)

gen_loss = criterion(fake_output, real_labels)

gen_loss.backward()

gen_optimizer.step()

每1000轮打印一次训练状态

if (epoch + 1) % 1000 == 0:

print(

f"轮次 [{epoch+1}/{train_epochs}], "

f"判别器损失: {dis_loss.item():.4f}, "

f"生成器损失: {gen_loss.item():.4f}"

)

print("\n--- 训练完成 ---")

相关推荐
万千思绪4 分钟前
【PyCharm 2025.1.2配置debug】
ide·python·pycharm
微风粼粼2 小时前
程序员在线接单
java·jvm·后端·python·eclipse·tomcat·dubbo
云天徽上2 小时前
【PaddleOCR】OCR表格识别数据集介绍,包含PubTabNet、好未来表格识别、WTW中文场景表格等数据,持续更新中......
python·ocr·文字识别·表格识别·paddleocr·pp-ocrv5
你怎么知道我是队长2 小时前
python-input内置函数
开发语言·python
叹一曲当时只道是寻常2 小时前
Python实现优雅的目录结构打印工具
python
hbwhmama3 小时前
python高级变量XIII
python
费弗里4 小时前
Python全栈应用开发利器Dash 3.x新版本介绍(3)
python·dash
dme.4 小时前
Javascript之DOM操作
开发语言·javascript·爬虫·python·ecmascript
加油吧zkf4 小时前
AI大模型如何重塑软件开发流程?——结合目标检测的深度实践与代码示例
开发语言·图像处理·人工智能·python·yolo
t_hj4 小时前
python规划
python