import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from torchvision import transforms
from PIL import Image
import os
# Generator definition
class Generator(nn.Module):
    """DCGAN-style generator that maps a noise vector to a 128x128 image.

    A fully connected layer projects the noise to a 1024x4x4 feature map,
    which five stride-2 transposed convolutions upsample to 128x128. The
    final Tanh keeps every output value in [-1, 1].

    Args:
        noise_dim: Size of the input noise vector.
        out_channels: Number of channels of the generated image
            (defaults to 3, mirroring ``Discriminator(input_channels=3)``).
    """

    def __init__(self, noise_dim, out_channels=3):
        super().__init__()
        # Project the noise vector to a flattened 1024 x 4 x 4 feature map.
        self.fc = nn.Sequential(
            nn.Linear(noise_dim, 1024 * 4 * 4),
            nn.BatchNorm1d(1024 * 4 * 4),
            nn.ReLU(True),
        )
        # Each ConvTranspose2d(kernel=4, stride=2, padding=1) doubles H and W.
        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 4, 2, 1),  # 4x4 -> 8x8
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, 4, 2, 1),  # 8x8 -> 16x16
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),  # 16x16 -> 32x32
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),  # 32x32 -> 64x64
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, out_channels, 4, 2, 1),  # 64x64 -> 128x128
            nn.Tanh(),
        )

    def forward(self, noise):
        """Generate images from ``noise``.

        Args:
            noise: Tensor of shape ``(batch, noise_dim)``.

        Returns:
            Tensor of shape ``(batch, out_channels, 128, 128)`` with values
            in [-1, 1].
        """
        # Un-flatten the projection into a 4x4 map with 1024 channels.
        x = self.fc(noise).view(-1, 1024, 4, 4)
        x = self.deconv_layers(x)
        return x
# Discriminator definition
class Discriminator(nn.Module):
    """Convolutional critic that scores 128x128 images.

    Four stride-2 convolutions shrink the input from 128x128 to 8x8, and a
    final 8x8 convolution collapses it to a single value. No activation
    follows the last convolution, so the output is a raw, unbounded score.

    Args:
        input_channels: Number of channels of the input images (default 3).
    """

    def __init__(self, input_channels=3):
        super().__init__()
        # Each Conv2d(kernel=4, stride=2, padding=1) halves H and W.
        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_channels, 64, 4, 2, 1),  # 128x128 -> 64x64
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),  # 64x64 -> 32x32
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1),  # 32x32 -> 16x16
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1),  # 16x16 -> 8x8
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1, 8, 1, 0),  # 8x8 -> 1x1
        )

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Tensor of shape ``(batch, input_channels, 128, 128)``.

        Returns:
            Tensor of shape ``(batch, 1)`` holding one raw score per image.
        """
        x = self.conv_layers(x)
        # Flatten the (batch, 1, 1, 1) map into a column of scores.
        x = x.view(-1, 1)
        return x
# Dataset definition
class TrafficSignDataset(Dataset):
    """Image dataset listed by a ``labels.txt`` file inside ``root_dir``.

    Every non-blank line of ``labels.txt`` must hold exactly two
    whitespace-separated fields: an image file name (relative to
    ``root_dir``) and a label. The label is ignored; only images are served.

    Args:
        root_dir: Directory containing ``labels.txt`` and the image files.
        transform: Optional callable applied to each loaded PIL image.

    Raises:
        ValueError: If a non-blank line does not have exactly two fields.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        labels_file_path = os.path.join(root_dir, 'labels.txt')
        with open(labels_file_path, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    # Blank lines (e.g. a trailing newline) are harmless.
                    continue
                if len(fields) != 2:
                    raise ValueError(
                        f"{labels_file_path}:{line_no}: expected "
                        f"'<image> <label>', got {line.strip()!r}"
                    )
                img_name = fields[0]
                self.image_paths.append(os.path.join(root_dir, img_name))

    def __len__(self):
        """Return the number of listed images."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply the transform, if any."""
        img_path = self.image_paths[idx]
        # convert() returns an independent image, so the underlying file can
        # be closed as soon as the block exits.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image
# Hyperparameters
noise_dim = 100  # size of the generator's input noise vector
batch_size = 8  # samples per mini-batch
lr = 2e-4  # initial learning rate for both optimizers
num_epochs = 500  # number of training epochs
output_dir = r"C:\Users\sun\Desktop\2024102201\out"  # where generated images and checkpoints go
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(output_dir, exist_ok=True)

# Use the GPU when one is available; otherwise fall back to the CPU so the
# script still starts on machines without CUDA.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Models
G = Generator(noise_dim=noise_dim).to(device)
D = Discriminator(input_channels=3).to(device)

# Optimizers (Adam with beta1 = 0.0, beta2 = 0.9 for both networks)
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.0, 0.9))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.0, 0.9))

# Learning-rate schedulers: halve the rate every 50 scheduler steps
scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=50, gamma=0.5)
scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, step_size=50, gamma=0.5)

# Preprocessing: resize to the 128x128 resolution the networks are built
# for, then scale pixels from [0, 1] to [-1, 1] to match the generator's
# Tanh output range.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Dataset and loader
root_dir = r"C:\Users\sun\Desktop\2024102201\1"
dataset = TrafficSignDataset(root_dir=root_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Loss functions
def discriminator_hinge_loss(real_outputs, fake_outputs):
    """Return the hinge loss for the discriminator.

    Real scores are penalised when they fall below +1 and fake scores when
    they rise above -1: ``mean(relu(1 - real)) + mean(relu(1 + fake))``.

    Args:
        real_outputs: Discriminator scores for real images.
        fake_outputs: Discriminator scores for generated images.

    Returns:
        Scalar tensor with the summed real and fake hinge terms.
    """
    real_loss = torch.mean(F.relu(1.0 - real_outputs))
    fake_loss = torch.mean(F.relu(1.0 + fake_outputs))
    return real_loss + fake_loss
def generator_hinge_loss(fake_outputs):
    """Return the generator loss: the negated mean discriminator score.

    Args:
        fake_outputs: Discriminator scores for generated images.

    Returns:
        Scalar tensor; minimising it raises the scores of generated images.
    """
    return -torch.mean(fake_outputs)
# Run everything on whichever device the generator's parameters live on.
device = next(G.parameters()).device

# Fixed noise, reused every epoch so the saved sample grids are comparable.
fixed_noise = torch.randn(64, noise_dim, device=device)

# Training loop
for epoch in range(num_epochs):
    d_loss = g_loss = None
    for real_images in dataloader:
        real_images = real_images.to(device)
        batch_size_current = real_images.size(0)
        if batch_size_current < 2:
            # The generator's BatchNorm1d raises in training mode on a
            # single-sample batch (possible for the last, partial batch).
            continue

        # ---------------------
        # Train the discriminator
        # ---------------------
        optimizer_D.zero_grad()
        noise = torch.randn(batch_size_current, noise_dim, device=device)
        fake_images = G(noise)
        real_outputs = D(real_images)
        # detach(): the discriminator step must not backpropagate into G.
        fake_outputs = D(fake_images.detach())
        d_loss = discriminator_hinge_loss(real_outputs, fake_outputs)
        d_loss.backward()
        optimizer_D.step()

        # ---------------------
        # Train the generator
        # ---------------------
        optimizer_G.zero_grad()
        # Re-score the same fakes with the freshly updated discriminator,
        # this time keeping the graph back to G.
        fake_outputs = D(fake_images)
        g_loss = generator_hinge_loss(fake_outputs)
        g_loss.backward()
        optimizer_G.step()

    if d_loss is None:
        raise RuntimeError(
            "No batch with at least 2 samples was produced; "
            "check the dataset size and batch_size."
        )

    # Update the learning rates once per epoch.
    scheduler_G.step()
    scheduler_D.step()

    # Report the losses of the last trained batch of this epoch.
    print(f"Epoch [{epoch + 1}/{num_epochs}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

    # Generate and save a grid of samples from the fixed noise.
    # NOTE(review): G stays in training mode here, so BatchNorm uses batch
    # statistics for the samples — confirm this is intended before adding
    # G.eval().
    with torch.no_grad():
        fake_images = G(fixed_noise).cpu()
    save_image(fake_images, os.path.join(output_dir, f"epoch_{epoch + 1}.png"), nrow=8, normalize=True)

    # Optional: checkpoint both models every 50 epochs.
    if (epoch + 1) % 50 == 0:
        torch.save(G.state_dict(), os.path.join(output_dir, f'generator_epoch_{epoch + 1}.pth'))
        torch.save(D.state_dict(), os.path.join(output_dir, f'discriminator_epoch_{epoch + 1}.pth'))
dcgan
yyfhq2024-10-30 22:59
相关推荐
堇舟2 小时前
斯皮尔曼相关(Spearman correlation)系数这个男人是小帅3 小时前
【图神经网络】 AM-GCN论文精讲(全网最细致篇)放松吃羊肉4 小时前
【约束优化】一次搞定拉格朗日,对偶问题,弱对偶定理,Slater条件和KKT条件YRr YRr5 小时前
深度学习:正则化(Regularization)详细解释yyfhq5 小时前
rescorediff思通数据5 小时前
AI助力医疗数据自动化:诊断报告识别与管理(●'◡'●)知6 小时前
基于树莓派的安保巡逻机器人--(一、快速人脸录入与精准人脸识别)迷路爸爸1807 小时前
深入理解Allan方差:用体重数据分析误差的时间尺度与稳定性T0uken8 小时前
【机器学习】过拟合与欠拟合