DCGAN

A self-contained PyTorch training script: a DCGAN-style generator and discriminator trained with hinge loss to generate 128×128 traffic-sign images.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from torchvision import transforms
from PIL import Image
import os

# Generator definition
class Generator(nn.Module):
    def __init__(self, noise_dim):
        super(Generator, self).__init__()

        self.fc = nn.Sequential(
            nn.Linear(noise_dim, 1024 * 4 * 4),
            nn.BatchNorm1d(1024 * 4 * 4),
            nn.ReLU(True)
        )

        self.deconv_layers = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 4, 2, 1),  # 4x4 -> 8x8
            nn.BatchNorm2d(512),
            nn.ReLU(True),

            nn.ConvTranspose2d(512, 256, 4, 2, 1),   # 8x8 -> 16x16
            nn.BatchNorm2d(256),
            nn.ReLU(True),

            nn.ConvTranspose2d(256, 128, 4, 2, 1),   # 16x16 -> 32x32
            nn.BatchNorm2d(128),
            nn.ReLU(True),

            nn.ConvTranspose2d(128, 64, 4, 2, 1),    # 32x32 -> 64x64
            nn.BatchNorm2d(64),
            nn.ReLU(True),

            nn.ConvTranspose2d(64, 3, 4, 2, 1),      # 64x64 -> 128x128
            nn.Tanh()
        )

    def forward(self, noise):
        x = self.fc(noise).view(-1, 1024, 4, 4)
        x = self.deconv_layers(x)
        return x

# Discriminator definition
class Discriminator(nn.Module):
    def __init__(self, input_channels=3):
        super(Discriminator, self).__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(input_channels, 64, 4, 2, 1),    # 128x128 -> 64x64
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1),               # 64x64 -> 32x32
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1),              # 32x32 -> 16x16
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, 2, 1),              # 16x16 -> 8x8
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(512, 1, 8, 1, 0),                # 8x8 -> 1x1
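            # No sigmoid here: the discriminator outputs a raw score, as required by the hinge loss below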
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(-1, 1)
        return x

# Dataset definition
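# Expects <root_dir>/labels.txt with one "<image filename> <label>" entry per line;
# only the filenames are used, since GAN training needs no class labels.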
class TrafficSignDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        labels_file_path = os.path.join(root_dir, 'labels.txt')
        with open(labels_file_path, 'r') as f:
            lines = f.readlines()
            for line in lines:
                img_name, _ = line.strip().split()
                img_path = os.path.join(root_dir, img_name)
                self.image_paths.append(img_path)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

# Hyperparameters
noise_dim = 100      # dimensionality of the input noise vector
batch_size = 8       # batch size
lr = 2e-4            # learning rate
num_epochs = 500     # number of training epochs
output_dir = r"C:\Users\sun\Desktop\2024102201\out"  # directory for generated images

if not os.path.exists(output_dir):
    os.makedirs(output_dir)

# Initialize the models
G = Generator(noise_dim=noise_dim).to('cuda')
D = Discriminator(input_channels=3).to('cuda')

# Optimizers
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(0.0, 0.9))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(0.0, 0.9))
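# Note: betas=(0.0, 0.9) is the setting commonly used with hinge-loss GANs (e.g. SAGAN);
# the original DCGAN recipe pairs BCE loss with betas=(0.5, 0.999).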

# Learning-rate schedulers
scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=50, gamma=0.5)
scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, step_size=50, gamma=0.5)

# Data preprocessing
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # scale RGB to [-1, 1] to match the Tanh output
])

# Load the dataset
root_dir = r"C:\Users\sun\Desktop\2024102201\1"
dataset = TrafficSignDataset(root_dir=root_dir, transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)  # drop_last avoids a 1-sample final batch, which BatchNorm cannot handle in training mode

# Loss functions (hinge loss)
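# Discriminator hinge loss: L_D = E[max(0, 1 - D(x_real))] + E[max(0, 1 + D(G(z)))]
# Generator hinge loss:     L_G = -E[D(G(z))]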
def discriminator_hinge_loss(real_outputs, fake_outputs):
    real_loss = torch.mean(F.relu(1.0 - real_outputs))
    fake_loss = torch.mean(F.relu(1.0 + fake_outputs))
    return real_loss + fake_loss

def generator_hinge_loss(fake_outputs):
    return -torch.mean(fake_outputs)

# Fixed noise for visualizing generator progress across epochs
fixed_noise = torch.randn(64, noise_dim).to('cuda')

# Training loop
for epoch in range(num_epochs):
    for i, real_images in enumerate(dataloader):
        real_images = real_images.to('cuda')
        batch_size_current = real_images.size(0)

        # ---------------------
        #  Train the discriminator
        # ---------------------
        optimizer_D.zero_grad()
        noise = torch.randn(batch_size_current, noise_dim).to('cuda')
        fake_images = G(noise)

        real_outputs = D(real_images)
        fake_outputs = D(fake_images.detach())
        d_loss = discriminator_hinge_loss(real_outputs, fake_outputs)
        d_loss.backward()
        optimizer_D.step()

        # ---------------------
        #  Train the generator
        # ---------------------
        optimizer_G.zero_grad()
        fake_outputs = D(fake_images)
        g_loss = generator_hinge_loss(fake_outputs)
        g_loss.backward()
        optimizer_G.step()

    # Step the learning-rate schedulers
    scheduler_G.step()
    scheduler_D.step()

    # Log the losses from the last batch of the epoch
    print(f"Epoch [{epoch + 1}/{num_epochs}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

    # Generate and save sample images
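    # Note: G is still in train() mode, so BatchNorm uses the statistics of this 64-sample
    # noise batch; wrapping the call in G.eval()/G.train() is a common alternative.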
    with torch.no_grad():
        fake_images = G(fixed_noise).detach().cpu()
        save_image(fake_images, os.path.join(output_dir, f"epoch_{epoch + 1}.png"), nrow=8, normalize=True)

    # Optional: save model checkpoints every 50 epochs
    if (epoch + 1) % 50 == 0:
        torch.save(G.state_dict(), os.path.join(output_dir, f'generator_epoch_{epoch + 1}.pth'))
        torch.save(D.state_dict(), os.path.join(output_dir, f'discriminator_epoch_{epoch + 1}.pth'))
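
Once training finishes, a saved generator checkpoint can be reloaded for standalone sampling. A minimal sketch, assuming one of the checkpoints written by the loop above (e.g. generator_epoch_500.pth in output_dir) and the Generator class defined earlier:

import os
import torch
from torchvision.utils import save_image

# Rebuild the generator with the same noise dimension used during training
G = Generator(noise_dim=100).to('cuda')
G.load_state_dict(torch.load(os.path.join(output_dir, 'generator_epoch_500.pth')))
G.eval()  # use running BatchNorm statistics instead of per-batch statistics

# Sample 16 images and save them as a 4x4 grid
with torch.no_grad():
    z = torch.randn(16, 100, device='cuda')
    samples = G(z).cpu()
save_image(samples, os.path.join(output_dir, 'samples.png'), nrow=4, normalize=True)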