生成式AI系列 —— DCGAN生成手写数字

1、模型构建

1.1 构建生成器

python 复制代码
# 导入软件包
import torch
import torch.nn as nn

class Generator(nn.Module):

    def __init__(self, z_dim=20, image_size=256):
        super(Generator, self).__init__()

        self.layer1 = nn.Sequential(
            nn.ConvTranspose2d(z_dim, image_size * 32,
                               kernel_size=4, stride=1),
            nn.BatchNorm2d(image_size * 32),
            nn.ReLU(inplace=True))

        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 32, image_size * 16,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size * 16),
            nn.ReLU(inplace=True))

        self.layer3 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 16, image_size * 8,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size * 8),
            nn.ReLU(inplace=True))

        self.layer4 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 8, image_size *4,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size * 4),
            nn.ReLU(inplace=True))
        self.layer5 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 4, image_size * 2,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size * 2),
            nn.ReLU(inplace=True))
        self.layer6 = nn.Sequential(
            nn.ConvTranspose2d(image_size * 2, image_size,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(image_size),
            nn.ReLU(inplace=True))
        self.last = nn.Sequential(
            nn.ConvTranspose2d(image_size, 3, kernel_size=4,
                               stride=2, padding=1),
            nn.Tanh())
        # 注意:因为是黑白图像,所以只有一个输出通道

    def forward(self, z):
        out = self.layer1(z)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.layer6(out)
        out = self.last(out)

        return out
    
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    G = Generator(z_dim=20, image_size=256)

    # 输入的随机数
    input_z = torch.randn(1, 20)

    # 将张量尺寸变形为(1,20,1,1)
    input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)

    #输出假图像
    fake_images = G(input_z)
    print(fake_images.shape)
    img_transformed = fake_images[0].detach().numpy().transpose(1, 2, 0)
    plt.imshow(img_transformed)
    plt.show()

1.1 构建判别器

python 复制代码
class Discriminator(nn.Module):

    def __init__(self, z_dim=20, image_size=256):
        super(Discriminator, self).__init__()

        self.layer1 = nn.Sequential(
            nn.Conv2d(3, image_size, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))
       #注意:因为是黑白图像,所以输入通道只有一个

        self.layer2 = nn.Sequential(
            nn.Conv2d(image_size, image_size*2, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))

        self.layer3 = nn.Sequential(
            nn.Conv2d(image_size*2, image_size*4, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))

        self.layer4 = nn.Sequential(
            nn.Conv2d(image_size*4, image_size*8, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))
        
        self.layer5 = nn.Sequential(
            nn.Conv2d(image_size*8, image_size*16, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))
        
        self.layer6 = nn.Sequential(
            nn.Conv2d(image_size*16, image_size*32, kernel_size=4,
                      stride=2, padding=1),
            nn.LeakyReLU(0.1, inplace=True))
        
        self.last = nn.Conv2d(image_size*32, 1, kernel_size=4, stride=1)

    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.layer5(out)
        out = self.layer6(out)
        out = self.last(out)

        return out

    
if __name__ == "__main__":
    #确认程序执行
    D = Discriminator(z_dim=20, image_size=64)

    #生成伪造图像
    input_z = torch.randn(1, 20)
    input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)
    fake_images = G(input_z)

    #将伪造的图像输入判别器D中
    d_out = D(fake_images)

    #将输出值d_out乘以Sigmoid函数,将其转换成0~1的值
    print(torch.sigmoid(d_out))

2、数据集构建

python 复制代码
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), ".")))
import time
from PIL import Image
import torch
import torch.utils.data as data
import torch.nn as nn

from torchvision import transforms
from model.DCGAN import Generator, Discriminator
from matplotlib import pyplot as plt


def make_datapath_list(root):
    """创建用于学习和验证的图像数据及标注数据的文件路径列表。 """

    train_img_list = list() #保存图像文件的路径

    for img_idx in range(200):
        img_path = f"{root}/img_7_{str(img_idx)}.jpg"
        train_img_list.append(img_path)

        img_path = f"{root}/img_8_{str(img_idx)}.jpg"
        train_img_list.append(img_path)

    return train_img_list


class ImageTransform:
    """图像的预处理类"""

    def __init__(self, mean, std):
        self.data_transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize(mean, std)]
        )

    def __call__(self, img):
        return self.data_transform(img)


class GAN_Img_Dataset(data.Dataset):
    """图像的 Dataset 类,继承自 PyTorchd 的 Dataset 类"""

    def __init__(self, file_list, transform):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        '''返回图像的张数'''
        return len(self.file_list)

    def __getitem__(self, index):
        '''获取经过预处理后的图像的张量格式的数据'''

        img_path = self.file_list[index]
        img = Image.open(img_path)  # [ 高度 ][ 宽度 ] 黑白

        # 图像的预处理
        img_transformed = self.transform(img)

        return img_transformed


# 创建DataLoader并确认执行结果

# 创建文件列表
root = "./img_78"
train_img_list = make_datapath_list(root)

# 创建Dataset
mean = (0.5)
std = (0.5)
train_dataset = GAN_Img_Dataset(
    file_list=train_img_list, transform=ImageTransform(mean, std)
)

# 创建DataLoader
batch_size = 2
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True
)

# 确认执行结果
batch_iterator = iter(train_dataloader)  # 转换为迭代器
imges = next(batch_iterator)  # 取出位于第一位的元素
print(imges.size())  # torch.Size([64, 1, 64, 64])

数据请在访问链接获取:

3、train接口实现

python 复制代码
def train_model(G, D, dataloader, num_epochs):
    # 确认是否能够使用GPU加速
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("使用设备:", device)

    # 设置最优化算法
    g_lr, d_lr = 0.0001, 0.0004
    beta1, beta2 = 0.0, 0.9
    g_optimizer = torch.optim.Adam(G.parameters(), g_lr, [beta1, beta2])
    d_optimizer = torch.optim.Adam(D.parameters(), d_lr, [beta1, beta2])

    # 定义误差函数
    criterion = nn.BCEWithLogitsLoss(reduction='mean')

    # 使用硬编码的参数
    z_dim = 20
    mini_batch_size = 8

    # 将网络载入GPU中
    G.to(device)
    D.to(device)

    G.train()  # 将模式设置为训练模式
    D.train()  # 将模式设置为训练模式

    # 如果网络相对固定,则开启加速
    torch.backends.cudnn.benchmark = True

    # 图像张数
    num_train_imgs = len(dataloader.dataset)
    batch_size = dataloader.batch_size

    # 设置迭代计数器
    iteration = 1
    logs = []

    # epoch循环
    for epoch in range(num_epochs):
        # 保存开始时间
        t_epoch_start = time.time()
        epoch_g_loss = 0.0  # epoch的损失总和
        epoch_d_loss = 0.0  # epoch的损失总和

        print('-------------')
        print('Epoch {}/{}'.format(epoch, num_epochs))
        print('-------------')
        print('(train)')

        # 以minibatch为单位从数据加载器中读取数据的循环
        for imges in dataloader:
            # --------------------
            # 1.判别器D的学习
            # --------------------
            # 如果小批次的尺寸设置为1,会导致批次归一化处理产生错误,因此需要避免
            if imges.size()[0] == 1:
                continue

            # 如果能使用GPU,则将数据送入GPU中
            imges = imges.to(device)

            # 创建正确答案标签和伪造数据标签
            # 在epoch最后的迭代中,小批次的数量会减少
            mini_batch_size = imges.size()[0]
            label_real = torch.full((mini_batch_size,), 1).to(device)
            label_fake = torch.full((mini_batch_size,), 0).to(device)

            # 对真正的图像进行判定
            d_out_real = D(imges)

            # 生成伪造图像并进行判定
            input_z = torch.randn(mini_batch_size, z_dim).to(device)
            input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)
            fake_images = G(input_z)
            d_out_fake = D(fake_images)

            # 计算误差
            d_loss_real = criterion(d_out_real.view(-1), label_real.to(torch.float))
            d_loss_fake = criterion(d_out_fake.view(-1), label_fake.to(torch.float))
            d_loss = d_loss_real + d_loss_fake

            # 反向传播处理
            g_optimizer.zero_grad()
            d_optimizer.zero_grad()

            d_loss.backward()
            d_optimizer.step()

            # --------------------
            # 2.生成器G的学习
            # --------------------
            # 生成伪造图像并进行判定
            input_z = torch.randn(mini_batch_size, z_dim).to(device)
            input_z = input_z.view(input_z.size(0), input_z.size(1), 1, 1)
            fake_images = G(input_z)
            d_out_fake = D(fake_images)

            # 计算误差
            g_loss = criterion(d_out_fake.view(-1), label_real.to(torch.float))

            # 反向传播处理
            g_optimizer.zero_grad()
            d_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            # --------------------
            # 3.记录结果
            # --------------------
            epoch_d_loss += d_loss.item()
            epoch_g_loss += g_loss.item()
            iteration += 1

        # epoch的每个phase的loss和准确率
        t_epoch_finish = time.time()
        print('-------------')
        print(
            'epoch {} || Epoch_D_Loss:{:.4f} ||Epoch_G_Loss:{:.4f}'.format(
                epoch, epoch_d_loss / batch_size, epoch_g_loss / batch_size
            )
        )
        print('timer:  {:.4f} sec.'.format(t_epoch_finish - t_epoch_start))
        t_epoch_start = time.time()

    return G, D

4、训练

python 复制代码
G = Generator(z_dim=20, image_size=64)
D = Discriminator(z_dim=20, image_size=64)
# 定义误差函数
criterion = nn.BCEWithLogitsLoss(reduction='mean')
num_epochs = 200
G_update, D_update = train_model(
    G, D, dataloader=train_dataloader, num_epochs=num_epochs
)

5、测试

python 复制代码
# 将生成的图像和训练数据可视化
# 反复执行本单元中的代码,直到生成感觉良好的图像为止

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 生成用于输入的随机数
batch_size = 8
z_dim = 20
fixed_z = torch.randn(batch_size, z_dim)
fixed_z = fixed_z.view(fixed_z.size(0), fixed_z.size(1), 1, 1)

# 生成图像
fake_images = G_update(fixed_z.to(device))

# 训练数据
imges = next(iter(train_dataloader))  # 取出位于第一位的元素

# 输出结果
fig = plt.figure(figsize=(15, 6))
for i in range(0, 5):
    # 将训练数据放入上层
    plt.subplot(2, 5, i + 1)
    plt.imshow(imges[i][0].cpu().detach().numpy(), 'gray')

    # 将生成数据放入下层
    plt.subplot(2, 5, 5 + i + 1)
    plt.imshow(fake_images[i][0].cpu().detach().numpy(), 'gray')
相关推荐
CryptoPP4 分钟前
使用WebSocket实时获取印度股票数据源(无调用次数限制)实战
后端·python·websocket·网络协议·区块链
树叶@5 分钟前
Python数据分析7
开发语言·python
老胖闲聊1 小时前
Python Rio 【图像处理】库简介
开发语言·图像处理·python
码界奇点1 小时前
Python Flask文件处理与异常处理实战指南
开发语言·python·自然语言处理·flask·python3.11
浠寒AI2 小时前
智能体模式篇(上)- 深入 ReAct:LangGraph构建能自主思考与行动的 AI
人工智能·python
weixin_505154462 小时前
数字孪生在建设智慧城市中可以起到哪些作用或帮助?
大数据·人工智能·智慧城市·数字孪生·数据可视化
Best_Me072 小时前
深度学习模块缝合
人工智能·深度学习
YuTaoShao2 小时前
【论文阅读】YOLOv8在单目下视多车目标检测中的应用
人工智能·yolo·目标检测
行云流水剑3 小时前
【学习记录】如何使用 Python 提取 PDF 文件中的内容
python·学习·pdf
算家计算3 小时前
字节开源代码模型——Seed-Coder 本地部署教程,模型自驱动数据筛选,让每行代码都精准落位!
人工智能·开源