dcgan

复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from torchvision import transforms
from PIL import Image
import os

# Generator 定义
class Generator(nn.Module):
    def __init__(self, noise_dim):
        super(Generator, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(noise_dim, 1024 * 4 * 4),
            nn.BatchNorm1d(1024 * 4 * 4),
            nn.ReLU(True)
        )

        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 4, 2, 1),  # 4x4 -> 8x8
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            nn.ConvTranspose2d(512, 256, 4, 2, 1),   # 8x8 -> 16x16
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            nn.ConvTranspose2d(256, 128, 4, 2, 1),   # 16x16 -> 32x32
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, 64, 4, 2, 1),    # 32x32 -> 64x64
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            nn.ConvTranspose2d(64, 3, 4, 2, 1),      # 64x64 -> 128x128
            nn.Tanh()
        )

    def forward(self, noise):
        x = self.fc(noise).view(-1, 1024, 4, 4)
        x = self.deconv_layers(x)
        return x

# Discriminator 定义
class Discriminator(nn.Module):
    def __init__(self, input_channels=3):
        super(Discriminator, self).__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_channels, 64, 4, 2, 1),    # 128x128 -> 64x64
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1),               # 64x64 -> 32x32
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1),              # 32x32 -> 16x16
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, 2, 1),              # 16x16 -> 8x8
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 1, 8, 1, 0),                # 8x8 -> 1x1
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(-1, 1)
        return x

# 数据集定义
class TrafficSignDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        labels_file_path = os.path.join(root_dir, 'labels.txt')
        with open(labels_file_path, 'r') as f:
            lines = f.readlines()
            for line in lines:
                img_name, _ = line.strip().split()
                img_path = os.path.join(root_dir, img_name)
                self.image_paths.append(img_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

# 设置超参数
noise_dim = 100      # 噪声维度
batch_size = 8       # 批大小
lr = 2e-4            # 学习率
num_epochs = 500     # 训练轮数
output_dir = r"C:\Users\sun\Desktop\2024102201\out"  # 生成图像保存路径

if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# 初始化模型
G = Generator(noise_dim=noise_dim).to('cuda')
D = Discriminator(input_channels=3).to('cuda')

# 设置优化器
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.0, 0.9))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.0, 0.9))

# 学习率调度器
scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=50, gamma=0.5)
scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, step_size=50, gamma=0.5)

# 数据预处理
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 加载数据集
root_dir = r"C:\Users\sun\Desktop\2024102201\1"
dataset = TrafficSignDataset(root_dir=root_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 定义损失函数
def discriminator_hinge_loss(real_outputs, fake_outputs):
    real_loss = torch.mean(F.relu(1.0 - real_outputs))
    fake_loss = torch.mean(F.relu(1.0 + fake_outputs))
    return real_loss + fake_loss

def generator_hinge_loss(fake_outputs):
    return -torch.mean(fake_outputs)

# 固定噪声用于生成图像
fixed_noise = torch.randn(64, noise_dim).to('cuda')

# 训练循环
for epoch in range(num_epochs):
    for i, real_images in enumerate(dataloader):
        real_images = real_images.to('cuda')
        batch_size_current = real_images.size(0)

        # ---------------------
        #  训练判别器
        # ---------------------
        optimizer_D.zero_grad()
        noise = torch.randn(batch_size_current, noise_dim).to('cuda')
        fake_images = G(noise)

        real_outputs = D(real_images)
        fake_outputs = D(fake_images.detach())
        d_loss = discriminator_hinge_loss(real_outputs, fake_outputs)
        d_loss.backward()
        optimizer_D.step()

        # ---------------------
        #  训练生成器
        # ---------------------
        optimizer_G.zero_grad()
        fake_outputs = D(fake_images)
        g_loss = generator_hinge_loss(fake_outputs)
        g_loss.backward()
        optimizer_G.step()

    # 更新学习率
    scheduler_G.step()
    scheduler_D.step()

    # 输出损失
    print(f"Epoch [{epoch + 1}/{num_epochs}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

    # 生成并保存图像
    with torch.no_grad():
        fake_images = G(fixed_noise).detach().cpu()
        save_image(fake_images, os.path.join(output_dir, f"epoch_{epoch + 1}.png"), nrow=8, normalize=True)

    # 可选:每隔一定epoch保存一次模型
    if (epoch + 1) % 50 == 0:
        torch.save(G.state_dict(), os.path.join(output_dir, f'generator_epoch_{epoch + 1}.pth'))
        torch.save(D.state_dict(), os.path.join(output_dir, f'discriminator_epoch_{epoch + 1}.pth'))
相关推荐
Jamence1 小时前
【深度学习数学知识】-贝叶斯公式
人工智能·深度学习·概率论
feifeikon1 小时前
机器学习DAY4续:梯度提升与 XGBoost (完)
人工智能·深度学习·机器学习
IT猿手2 小时前
最新高性能多目标优化算法:多目标麋鹿优化算法(MOEHO)求解GLSMOP1-GLSMOP9及工程应用---盘式制动器设计,提供完整MATLAB代码
开发语言·算法·机器学习·matlab·强化学习
取个名字真难呐2 小时前
LossMaskMatrix损失函数掩码矩阵
python·深度学习·矩阵
Kenneth風车3 小时前
【机器学习(九)】分类和回归任务-多层感知机(Multilayer Perceptron,MLP)算法-Sentosa_DSML社区版 (1)111
算法·机器学习·分类
盼小辉丶3 小时前
TensorFlow深度学习实战(2)——使用TensorFlow构建神经网络
深度学习·神经网络·tensorflow
18号房客3 小时前
计算机视觉-人工智能(AI)入门教程一
人工智能·深度学习·opencv·机器学习·计算机视觉·数据挖掘·语音识别
QQ_7781329743 小时前
基于深度学习的图像超分辨率重建
人工智能·机器学习·超分辨率重建
IT古董4 小时前
【漫话机器学习系列】020.正则化强度的倒数C(Inverse of regularization strength)
人工智能·机器学习
进击的小小学生4 小时前
机器学习连载
人工智能·机器学习