About
In recent years, supervised learning with convolutional networks (CNNs) has seen wide adoption in computer vision applications. By comparison, unsupervised learning with CNNs has received far less attention. In this work, we hope to help bridge the gap between the success of CNNs in supervised learning and in unsupervised learning. We introduce a class of CNNs called deep convolutional generative adversarial networks (DCGANs) that have certain architectural constraints, and demonstrate that they are a strong candidate for unsupervised learning. Training on various image datasets, we show convincing evidence that our deep convolutional adversarial pair learns a hierarchy of representations, from object parts to scenes, in both the generator and the discriminator. Additionally, we use the learned features for novel tasks, demonstrating their applicability as general image representations. (https://arxiv.org/pdf/1511.06434.pdf)
Tools
PyTorch and torchvision, with NumPy and Matplotlib used for assembling and saving sample image grids.
Dataset
MNIST: 28x28 grayscale images of handwritten digits, loaded through torchvision.datasets.
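The original post does not show the loading code, so here is a minimal sketch that builds the data_loader used in the training loop below. Pixels are normalized to [-1, 1] to match the generator's Tanh output; the root='data/' download path is an assumption.
python
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

batch_size = 64  # same value as in the hyperparameter section below

# Scale pixel values to [-1, 1] so real images match the Tanh range of the generator.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

mnist = datasets.MNIST(root='data/', train=True, transform=transform, download=True)
# drop_last=True: the fixed-size label tensors (batch_size, 1) used during
# training require every batch to contain exactly batch_size images.
data_loader = DataLoader(dataset=mnist, batch_size=batch_size, shuffle=True, drop_last=True)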
Implementation
Load the required libraries and helper functions
python
import numpy as np
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image
from matplotlib.pyplot import imshow, imsave

# Run on GPU when available; used by all .to(DEVICE) calls below.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
MODEL_NAME = 'DCGAN'  # used to name the sample images saved during training
python
def get_sample_image(G, n_noise):
    """
    Generate 100 samples from G and tile them into a 10x10 grid (280x280 image).
    """
    z = torch.randn(100, n_noise).to(DEVICE)
    y_hat = G(z).view(100, 28, 28)  # (100, 28, 28)
    result = y_hat.cpu().data.numpy()
    img = np.zeros([280, 280])
    for j in range(10):
        # Concatenate 10 images side by side to form row j of the grid.
        img[j*28:(j+1)*28] = np.concatenate([x for x in result[j*10:(j+1)*10]], axis=-1)
    return img
Define the discriminator
python
class Discriminator(nn.Module):
    """
    Convolutional Discriminator for MNIST
    """
    def __init__(self, in_channel=1, num_classes=1):
        super(Discriminator, self).__init__()
        self.conv = nn.Sequential(
            # 28 -> 14
            nn.Conv2d(in_channel, 512, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            # 14 -> 7
            nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            # 7 -> 4
            nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            # 4 -> 1: average over the 4x4 feature map, leaving a 128-dim vector
            nn.AvgPool2d(4),
        )
        self.fc = nn.Sequential(
            # reshape input, 128 -> 1
            nn.Linear(128, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, y=None):
        y_ = self.conv(x)
        y_ = y_.view(y_.size(0), -1)  # flatten (N, 128, 1, 1) to (N, 128)
        y_ = self.fc(y_)
        return y_
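As a quick sanity check (not part of the original tutorial), a dummy batch can be passed through the network to confirm the 28 -> 14 -> 7 -> 4 -> 1 downsampling annotated in the comments; the instance created here is throwaway and gets replaced in the setup section below.
python
D = Discriminator()
x = torch.randn(8, 1, 28, 28)  # dummy batch of 8 MNIST-shaped images
print(D.conv(x).shape)         # torch.Size([8, 128, 1, 1])
print(D(x).shape)              # torch.Size([8, 1]), probabilities in (0, 1)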
Define the generator
python
class Generator(nn.Module):
    """
    Convolutional Generator for MNIST
    """
    def __init__(self, input_size=100, num_classes=784):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, 4*4*512),
            nn.ReLU(),
        )
        # Each ConvTranspose2d upsamples: out = (in - 1)*stride - 2*padding + kernel_size
        self.conv = nn.Sequential(
            # input: 4 by 4, output: 7 by 7
            nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # input: 7 by 7, output: 14 by 14
            nn.ConvTranspose2d(256, 128, 4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # input: 14 by 14, output: 28 by 28
            nn.ConvTranspose2d(128, 1, 4, stride=2, padding=1, bias=False),
            nn.Tanh(),  # outputs in [-1, 1], matching the normalized training data
        )

    def forward(self, x, y=None):
        x = x.view(x.size(0), -1)
        y_ = self.fc(x)
        y_ = y_.view(y_.size(0), 512, 4, 4)  # project noise to a 4x4x512 feature map
        y_ = self.conv(y_)
        return y_
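Again as an optional sanity check (not in the original), verify that a batch of noise vectors comes out as 28x28 images. Applying the output-size formula above gives 4 -> 7 -> 14 -> 28, matching the comments in the layer stack.
python
G = Generator(input_size=100)
z = torch.randn(8, 100)  # batch of 8 noise vectors
print(G(z).shape)        # torch.Size([8, 1, 28, 28]), values in [-1, 1]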
Hyperparameters and model setup
python
batch_size = 64
n_noise = 100  # dimension of the generator's input noise vector

D = Discriminator().to(DEVICE)
G = Generator(n_noise).to(DEVICE)

criterion = nn.BCELoss()
D_opt = torch.optim.Adam(D.parameters(), lr=0.001, betas=(0.5, 0.999))
G_opt = torch.optim.Adam(G.parameters(), lr=0.001, betas=(0.5, 0.999))
max_epoch = 30  # the generator needs more than 20 epochs of training
step = 0
n_critic = 1  # number of discriminator updates per generator update

D_labels = torch.ones([batch_size, 1]).to(DEVICE)   # discriminator target for real images
D_fakes = torch.zeros([batch_size, 1]).to(DEVICE)   # discriminator target for fake images
Model training
python
for epoch in range(max_epoch):
    for idx, (images, labels) in enumerate(data_loader):
        # Train the discriminator: real images should score 1, generated images 0.
        x = images.to(DEVICE)
        x_outputs = D(x)
        D_x_loss = criterion(x_outputs, D_labels)

        z = torch.randn(batch_size, n_noise).to(DEVICE)
        z_outputs = D(G(z))
        D_z_loss = criterion(z_outputs, D_fakes)
        D_loss = D_x_loss + D_z_loss

        D.zero_grad()
        D_loss.backward()
        D_opt.step()

        if step % n_critic == 0:
            # Train the generator: it tries to make D label its samples as real.
            z = torch.randn(batch_size, n_noise).to(DEVICE)
            z_outputs = D(G(z))
            G_loss = criterion(z_outputs, D_labels)

            D.zero_grad()
            G.zero_grad()
            G_loss.backward()
            G_opt.step()

        if step % 500 == 0:
            print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, max_epoch, step, D_loss.item(), G_loss.item()))

        if step % 1000 == 0:
            G.eval()
            img = get_sample_image(G, n_noise)
            imsave('./{}_step{}.jpg'.format(MODEL_NAME, str(step).zfill(3)), img, cmap='gray')
            G.train()

        step += 1
Inspect the generated samples
python
# render a 10x10 grid of generated digits
G.eval()
imshow(get_sample_image(G, n_noise), cmap='gray')
Save the model and optimizer state
python
def save_checkpoint(state, file_name='checkpoint.pth.tar'):
    torch.save(state, file_name)

# Saving the raw parameters only:
# torch.save(D.state_dict(), 'D_c.pkl')
# torch.save(G.state_dict(), 'G_c.pkl')

save_checkpoint({'epoch': epoch + 1, 'state_dict': D.state_dict(), 'optimizer': D_opt.state_dict()}, 'D_dc.pth.tar')
save_checkpoint({'epoch': epoch + 1, 'state_dict': G.state_dict(), 'optimizer': G_opt.state_dict()}, 'G_dc.pth.tar')
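To resume training or reuse the generator later, the checkpoints can be restored as sketched below, assuming the file names and dictionary keys used above.
python
# Restore the generator and its optimizer from the saved checkpoint.
checkpoint = torch.load('G_dc.pth.tar', map_location=DEVICE)
G.load_state_dict(checkpoint['state_dict'])
G_opt.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch']  # epoch to resume training from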
Applications
As a mature generative model, DCGAN can be used to synthesize data for natural images, medical images, and medical electrophysiological signals, serving as a form of data augmentation. At the same time, how to limit the adverse effects that augmented data may have on downstream tasks is a question that also deserves attention.