10251114

复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.utils import save_image
from torchvision import transforms
from PIL import Image
import os


class SelfAttention(nn.Module):
    """Self-attention over the spatial positions of a 4-D feature map.

    Queries and keys are produced by 1x1 convolutions that reduce the channel
    count to ``in_dim // 8``; values keep all ``in_dim`` channels.  All three
    convolutions are wrapped in spectral normalisation.  The attended features
    are scaled by the learnable scalar ``gamma`` (initialised to zero, so the
    layer starts out as the identity) and added back to the input.
    """

    def __init__(self, in_dim):
        super().__init__()
        reduced_dim = in_dim // 8
        self.query = nn.utils.spectral_norm(nn.Conv2d(in_dim, reduced_dim, kernel_size=1))
        self.key = nn.utils.spectral_norm(nn.Conv2d(in_dim, reduced_dim, kernel_size=1))
        self.value = nn.utils.spectral_norm(nn.Conv2d(in_dim, in_dim, kernel_size=1))
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``gamma * attention(x) + x``; the output has the shape of ``x``."""
        n, channels, dim2, dim3 = x.size()
        positions = dim2 * dim3

        # Flatten the two spatial axes so each position becomes one token.
        queries = self.query(x).view(n, -1, positions).transpose(1, 2)  # (n, P, C/8)
        keys = self.key(x).view(n, -1, positions)                       # (n, C/8, P)
        values = self.value(x).view(n, -1, positions)                   # (n, C, P)

        # weights[b, i, j]: how much position i attends to position j.
        weights = F.softmax(torch.bmm(queries, keys), dim=-1)

        attended = torch.bmm(values, weights.transpose(1, 2))
        attended = attended.view(n, channels, dim2, dim3)
        return self.gamma * attended + x


class Generator(nn.Module):
    """Conditional generator: (noise, label vector) -> 3-channel image.

    The concatenated noise/label vector is projected to a 1024x4x4 feature
    map, which five stride-2 transposed convolutions upsample to 128x128.
    Self-attention is applied at the 8x8 and 64x64 stages, and the final
    ``Tanh`` bounds the output to [-1, 1].
    """

    def __init__(self, noise_dim, label_dim):
        super().__init__()
        self.label_dim = label_dim

        seed_features = 1024 * 4 * 4
        self.fc = nn.Sequential(
            nn.Linear(noise_dim + label_dim, seed_features),
            nn.BatchNorm1d(seed_features),
            nn.ReLU(True),
        )

        def upsample(in_channels, out_channels):
            # kernel 4, stride 2, padding 1: doubles the spatial resolution.
            return nn.utils.spectral_norm(
                nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1)
            )

        # Modules are created in the same order (and end up at the same
        # Sequential indices) as a literal listing, so parameter names in the
        # state dict and the initialisation sequence are unaffected.
        stages = [
            # 4x4 -> 8x8
            upsample(1024, 512), nn.BatchNorm2d(512), nn.ReLU(True),
            SelfAttention(512),
            # 8x8 -> 16x16
            upsample(512, 256), nn.BatchNorm2d(256), nn.ReLU(True),
            # 16x16 -> 32x32
            upsample(256, 128), nn.BatchNorm2d(128), nn.ReLU(True),
            # 32x32 -> 64x64
            upsample(128, 64), nn.BatchNorm2d(64), nn.ReLU(True),
            SelfAttention(64),
            # 64x64 -> 128x128
            upsample(64, 3),
            nn.Tanh(),
        ]
        self.deconv_layers = nn.Sequential(*stages)

    def forward(self, noise, labels):
        """Generate images from ``noise`` conditioned on ``labels``.

        ``noise`` and ``labels`` are concatenated along dim 1, so they must be
        2-D with ``noise_dim`` and ``label_dim`` columns respectively.
        """
        conditioned = torch.cat((noise, labels), dim=1)
        feature_map = self.fc(conditioned).view(-1, 1024, 4, 4)
        return self.deconv_layers(feature_map)


class Discriminator(nn.Module):
    """Conditional critic: (image, label vector) -> one unbounded score.

    The label vector is broadcast to one constant plane per class and
    concatenated to the image channels.  Four stride-2 convolutions (each
    spectrally normalised, with self-attention after the third) reduce the
    spatial size by a factor of 16, and a spectrally normalised linear layer
    maps the flattened features to a single score.

    The linear layer expects ``512 * 8 * 8`` features, i.e. 64 spatial
    positions after the convolutions (for example a 128x128 input).
    """

    def __init__(self, input_channels, label_dim):
        super(Discriminator, self).__init__()
        self.label_dim = label_dim

        # The label planes are stacked on top of the image channels.
        self.conv1 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(input_channels + label_dim, 64, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.conv2 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(64, 128, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.conv3 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(128, 256, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.self_attn = SelfAttention(256)

        self.conv4 = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(256, 512, 4, 2, 1)),
            nn.LeakyReLU(0.2, inplace=True)
        )

        self.fc = nn.utils.spectral_norm(nn.Linear(512 * 8 * 8, 1))

    def forward(self, x, labels):
        """Score a batch of images given their label vectors.

        Args:
            x: image batch of shape (batch, input_channels, H, W).
            labels: label vectors reshapeable to (batch, label_dim).

        Returns:
            Tensor of shape (batch, 1) with raw (un-squashed) scores.
        """
        batch_size, _, height, width = x.size()
        # Use height and width independently so the label planes always match
        # the image, including non-square inputs.
        labels = labels.view(batch_size, self.label_dim, 1, 1)
        labels = labels.expand(batch_size, self.label_dim, height, width)
        x = torch.cat([x, labels], dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.self_attn(x)
        x = self.conv4(x)
        x = x.view(batch_size, -1)
        x = self.fc(x)
        return x


class TrafficSignDataset(Dataset):
    """Image dataset described by a plain-text labels file.

    Each non-empty line of ``labels_file`` has the form
    ``<image name> <integer label>``.  The label is the last
    whitespace-separated token, so image names may contain spaces.  Image
    names are resolved relative to ``root_dir``.
    """

    def __init__(self, root_dir, labels_file, transform=None):
        """Read ``labels_file`` and record every image path with its label.

        Args:
            root_dir: directory that image names are joined to.
            labels_file: path of the text file listing images and labels.
            transform: optional callable applied to each loaded PIL image.

        Raises:
            ValueError: if a non-empty line is not ``<name> <int label>``.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.image_paths = []
        self.labels = []
        with open(labels_file, 'r') as f:
            for line_no, line in enumerate(f, start=1):
                entry = line.strip()
                if not entry:
                    # Blank lines (typically a trailing newline) carry no sample.
                    continue
                try:
                    img_name, label_text = entry.rsplit(None, 1)
                    label = int(label_text)
                except ValueError as err:
                    raise ValueError(
                        f"{labels_file}:{line_no}: expected "
                        f"'<image name> <integer label>', got {entry!r}"
                    ) from err
                self.image_paths.append(os.path.join(root_dir, img_name))
                self.labels.append(label)

    def __len__(self):
        """Return the number of (image, label) pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an RGB image and return ``(image, label)``."""
        # The context manager closes the file as soon as the pixels have been
        # read, instead of leaving the handle to the garbage collector.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')
        label = self.labels[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label


# Hyperparameters
noise_dim = 100  # length of the generator's noise vector
label_dim = 58  # length of the one-hot label vector (number of classes)
batch_size = 8  # mini-batch size
lr = 2e-4  # initial learning rate for both optimizers
num_epochs = 500
n_critic = 5  # the generator is updated on every n_critic-th batch
lambda_gp = 10  # weight of the gradient penalty in the discriminator loss
output_dir = r"C:\Users\sun\Desktop\2024102201\out"  # where generated sample grids are saved

# exist_ok=True creates the directory only when it is missing, without a
# separate (race-prone) existence check.
os.makedirs(output_dir, exist_ok=True)

G = Generator(noise_dim=noise_dim, label_dim=label_dim).to('cuda')
D = Discriminator(input_channels=3, label_dim=label_dim).to('cuda')

beta1 = 0.0
beta2 = 0.9
optimizer_G = optim.Adam(G.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_D = optim.Adam(D.parameters(), lr=lr, betas=(beta1, beta2))
# Halve both learning rates every 50 scheduler steps.
scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=50, gamma=0.5)
scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, step_size=50, gamma=0.5)

# Resize to the 128x128 resolution the networks are built for and rescale
# pixel values from [0, 1] to [-1, 1], the range of the generator's Tanh.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

root_dir = r"C:\Users\sun\Desktop\2024102201\1"
labels_file = r"C:\Users\sun\Desktop\2024102201\1\labels.txt"  # path of the labels file
dataset = TrafficSignDataset(root_dir=root_dir, labels_file=labels_file, transform=transform)
# drop_last=True discards the final incomplete batch so every batch has
# exactly batch_size samples.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)


def discriminator_hinge_loss(real_outputs, fake_outputs):
    """Hinge loss for the discriminator.

    Real scores are penalised while they are below +1 and fake scores while
    they are above -1; the two batch means are summed.
    """
    penalty_on_real = F.relu(1.0 - real_outputs).mean()
    penalty_on_fake = F.relu(1.0 + fake_outputs).mean()
    return penalty_on_real + penalty_on_fake


def generator_hinge_loss(fake_outputs):
    """Generator loss: the negated mean of the discriminator's fake scores."""
    mean_score = fake_outputs.mean()
    return mean_score.neg()


def compute_gradient_penalty(D, real_samples, fake_samples, labels):
    """Gradient penalty evaluated on random real/fake interpolations.

    For every sample a mixing coefficient is drawn uniformly from [0, 1),
    the real and fake samples are blended with it, and ``D`` is evaluated on
    the blend together with ``labels``.  The result is the batch mean of
    ``(||dD/dx||_2 - 1) ** 2``, where the norm is taken over each sample's
    flattened input gradient.  The graph of the gradient is kept so the
    penalty itself can be back-propagated.
    """
    device = real_samples.device
    n_samples = real_samples.size(0)

    # One coefficient per sample, broadcast over the remaining three axes.
    mix = torch.rand(n_samples, 1, 1, 1).to(device)
    blended = mix * real_samples + (1 - mix) * fake_samples
    blended = blended.requires_grad_(True)

    scores = D(blended, labels)
    # Seeding with ones yields the gradient of the summed scores, i.e. each
    # sample's own input gradient.
    seed = torch.ones(scores.size()).to(device)
    (grads,) = torch.autograd.grad(
        outputs=scores,
        inputs=blended,
        grad_outputs=seed,
        create_graph=True,
        retain_graph=True,
    )

    per_sample_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
    return ((per_sample_norm - 1) ** 2).mean()


def _one_hot(indices, num_classes):
    """Return float one-hot rows (len(indices), num_classes) on indices' device."""
    return F.one_hot(indices, num_classes).float()


# Run on whatever device the models were placed on instead of repeating a
# device literal at every allocation.
device = next(G.parameters()).device

# Fixed noise and labels (cycling through the classes) so the sample grid
# saved after each epoch is generated from the same inputs.
fixed_noise = torch.randn(64, noise_dim, device=device)
fixed_labels_idx = torch.arange(64, device=device) % label_dim
fixed_labels_one_hot = _one_hot(fixed_labels_idx, label_dim)

if len(dataloader) == 0:
    # Without this check the first epoch would finish with no losses defined
    # and fail with a NameError in the progress print below.
    raise RuntimeError(
        "The dataloader yields no batches; make sure the dataset contains "
        "at least one full batch."
    )

for epoch in range(num_epochs):
    for i, (real_images, real_labels_idx) in enumerate(dataloader):
        real_images = real_images.to(device)
        real_labels_idx = real_labels_idx.to(device)
        batch_size_current = real_images.size(0)
        real_labels_one_hot = _one_hot(real_labels_idx, label_dim)

        # The generator is only updated on every n_critic-th batch.
        update_generator = i % n_critic == 0

        # ---- Discriminator step ----
        optimizer_D.zero_grad()
        noise = torch.randn(batch_size_current, noise_dim, device=device)
        fake_labels_idx = torch.randint(0, label_dim, (batch_size_current,), device=device)
        fake_labels_one_hot = _one_hot(fake_labels_idx, label_dim)
        # The generator graph is needed only when the generator will be
        # updated from these images; otherwise skip recording it.
        with torch.set_grad_enabled(update_generator):
            fake_images = G(noise, fake_labels_one_hot)

        real_outputs = D(real_images, real_labels_one_hot)
        fake_outputs = D(fake_images.detach(), fake_labels_one_hot)
        d_loss = discriminator_hinge_loss(real_outputs, fake_outputs)
        gradient_penalty = compute_gradient_penalty(D, real_images, fake_images.detach(), real_labels_one_hot)
        d_loss += lambda_gp * gradient_penalty
        d_loss.backward()
        optimizer_D.step()

        # ---- Generator step ----
        if update_generator:
            optimizer_G.zero_grad()
            # Re-score the same fake batch with the freshly updated D, this
            # time keeping the graph back into G.  Gradients that land on D's
            # parameters here are cleared by optimizer_D.zero_grad() at the
            # start of the next batch.
            fake_outputs = D(fake_images, fake_labels_one_hot)
            g_loss = generator_hinge_loss(fake_outputs)
            g_loss.backward()
            optimizer_G.step()

    scheduler_G.step()
    scheduler_D.step()

    # Reports the losses of the last discriminator / generator step of the epoch.
    print(f"Epoch [{epoch + 1}/{num_epochs}], D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")

    # Save an 8-per-row grid of samples generated from the fixed inputs.
    with torch.no_grad():
        fake_images = G(fixed_noise, fixed_labels_one_hot)
        save_image(fake_images, os.path.join(output_dir, f"epoch_{epoch + 1}.png"), nrow=8, normalize=True)


torch.save(G.state_dict(), 'generator.pth')
torch.save(D.state_dict(), 'discriminator.pth')
相关推荐
lxmyzzs8 分钟前
基于深度学习CenterPoint的3D目标检测部署实战
人工智能·深度学习·目标检测·自动驾驶·ros·激光雷达·3d目标检测
念念01071 小时前
数学建模竞赛中评价类相关模型
python·数学建模·因子分析·topsis
云天徽上2 小时前
【数据可视化-94】2025 亚洲杯总决赛数据可视化分析:澳大利亚队 vs 中国队
python·信息可视化·数据挖掘·数据分析·数据可视化·pyecharts
☺����2 小时前
实现自己的AI视频监控系统-第一章-视频拉流与解码2
开发语言·人工智能·python·音视频
王者鳜錸2 小时前
PYTHON让繁琐的工作自动化-函数
开发语言·python·自动化
算法_小学生3 小时前
循环神经网络(RNN, Recurrent Neural Network)
人工智能·rnn·深度学习
xiao助阵3 小时前
python实现梅尔频率倒谱系数(MFCC) 除了傅里叶变换和离散余弦变换
开发语言·python
努力还债的学术吗喽4 小时前
【速通】深度学习模型调试系统化方法论:从问题定位到性能优化
人工智能·深度学习·学习·调试·模型·方法论
麻辣清汤5 小时前
结合BI多维度异常分析(日期-> 商家/渠道->日期(商家/渠道))
数据库·python·sql·finebi
钢铁男儿5 小时前
Python 正则表达式(正则表达式和Python 语言)
python·mysql·正则表达式