系列文章目录
GAN生成对抗网络介绍https://blog.csdn.net/m0_58941767/article/details/142704354?spm=1001.2014.3001.5501
循环生成对抗网络------CycleGANhttps://blog.csdn.net/m0_58941767/article/details/142704671?spm=1001.2014.3001.5501
目录
前言
EnlightenGAN是一种用于低照度图像增强的无监督生成对抗网络。它能够在没有成对训练数据的情况下,通过利用输入图像本身的信息来进行自我正则化,从而实现图像的增强。这种方法特别适用于那些难以获取大量成对低照度和正常光照图像的场景。
**一、**EnlightenGAN的主要特点
-
无监督训练:EnlightenGAN不需要成对的低照度和正常光照图像来训练,这使得它能够更容易地适应真实世界的图像增强任务。
-
生成器结构:它采用了一个带有自注意力机制的U-Net生成器,这种结构有助于增强图像的局部细节,同时保持整体的光照平衡。
-
双判别器结构:EnlightenGAN使用了全局和局部判别器来平衡图像的全局和局部增强。全局判别器关注整体光照差异,而局部判别器则关注图像的细节特征。
-
自正则化感知损失:为了在没有成对数据的情况下保持图像内容的特征,EnlightenGAN引入了自特征保持损失,这有助于在增强过程中保持图像的纹理和结构。
-
自正则注意机制:通过利用低照度输入图像的光照信息作为自正则化注意力图,EnlightenGAN能够在不依赖外部监督的情况下,指导学习过程。
-
灵活性和适应性:由于其无监督设置,EnlightenGAN可以很容易地适应于增强来自不同领域的现实世界低照度图像。
-
实验结果:通过广泛的实验,EnlightenGAN在多种评价指标下的表现均优于现有的方法,包括视觉质量、无参考图像质量评估和人类主观研究。
二、EnlightenGAN网络结构
网络结构 = 生成器(带自注意力机制的U-Net)+ 判别器(全局-局部鉴别器)
三、核心代码实现
1、train.py
python
import time
import torch
from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
def get_config(config):
import yaml
with open(config, 'r') as stream:
return yaml.safe_load(stream)
def main():
opt = TrainOptions().parse()
opt.mode = 'train'
config = get_config(opt.config)
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
dataset_size = len(data_loader)
print('#training images = %d' % dataset_size)
model = create_model(opt)
visualizer = Visualizer(opt)
total_steps = 0
for epoch in range(1, opt.niter + opt.niter_decay + 1):
epoch_start_time = time.time()
for i, data in enumerate(dataset):
iter_start_time = time.time()
total_steps += opt.batchSize
epoch_iter = total_steps - dataset_size * (epoch - 1)
model.set_input(data)
model.optimize_parameters(epoch)
if total_steps % opt.display_freq == 0:
visualizer.display_current_results(model.get_current_visuals(), epoch)
if total_steps % opt.print_freq == 0:
errors = model.get_current_errors(epoch)
t = (time.time() - iter_start_time) / opt.batchSize
visualizer.print_current_errors(epoch, epoch_iter, errors, t)
if opt.display_id > 0:
visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)
if total_steps % opt.save_latest_freq == 0:
print('saving the latest model (epoch %d, total_steps %d)' %
(epoch, total_steps))
model.save('latest')
if epoch % opt.save_epoch_freq == 0:
print('saving the model at the end of epoch %d, iters %d' %
(epoch, total_steps))
model.save('latest')
model.save(epoch)
print('End of epoch %d / %d \t Time Taken: %d sec' %
(epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))
if opt.new_lr:
if epoch == opt.niter:
model.update_learning_rate()
elif epoch == (opt.niter + 20):
model.update_learning_rate()
elif epoch == (opt.niter + 70):
model.update_learning_rate()
elif epoch == (opt.niter + 90):
model.update_learning_rate()
model.update_learning_rate()
model.update_learning_rate()
model.update_learning_rate()
else:
if epoch > opt.niter:
model.update_learning_rate()
if __name__ == '__main__':
main()
2、predect.py
python
import time
import os
import torch
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from pdb import set_trace as st
from util import html
def main():
opt = TestOptions().parse()
opt.nThreads = 1 # test code only supports nThreads = 1
opt.batchSize = 1 # test code only supports batchSize = 1
opt.serial_batches = True # no shuffle
opt.no_flip = True # no flip
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
if(len(dataset)==0):
raise ValueError("Dataset is empty. Please check your data path and data loader configuration.")
model = create_model(opt)
visualizer = Visualizer(opt)
# create website
web_dir = os.path.join("./ablation/", opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
# test
print(len(dataset))
with torch.no_grad():
for i, data in enumerate(dataset):
model.set_input(data)
visuals = model.predict()
img_path = model.get_image_paths()
print('process image... %s' % img_path)
visualizer.save_images(webpage, visuals, img_path)
webpage.save()
if __name__ == '__main__':
main()
四、推理结果
可见效果还是不错的。
五、调试好的源码
由于直接下载GitHub上的源码需要进行调试才可以运行,调试过程比较麻烦,我这边提供给大家一个我调试好的源码,方便新手进行训练、测试。
大家扫码关注公众号,回复关键字EnlightenGAN源码即可获取。