低照度图像增强网络——EnlightenGAN

系列文章目录

GAN生成对抗网络介绍https://blog.csdn.net/m0_58941767/article/details/142704354?spm=1001.2014.3001.5501

循环生成对抗网络------CycleGANhttps://blog.csdn.net/m0_58941767/article/details/142704671?spm=1001.2014.3001.5501


目录

系列文章目录

前言

一、EnlightenGAN的主要特点

二、EnlightenGAN网络结构

三、核心代码实现

1、train.py

2、predect.py

四、推理结果

五、调试好的源码


前言

EnlightenGAN是一种用于低照度图像增强的无监督生成对抗网络。它能够在没有成对训练数据的情况下,通过利用输入图像本身的信息来进行自我正则化,从而实现图像的增强。这种方法特别适用于那些难以获取大量成对低照度和正常光照图像的场景。


**一、**EnlightenGAN的主要特点

  1. 无监督训练:EnlightenGAN不需要成对的低照度和正常光照图像来训练,这使得它能够更容易地适应真实世界的图像增强任务。

  2. 生成器结构:它采用了一个带有自注意力机制的U-Net生成器,这种结构有助于增强图像的局部细节,同时保持整体的光照平衡。

  3. 双判别器结构:EnlightenGAN使用了全局和局部判别器来平衡图像的全局和局部增强。全局判别器关注整体光照差异,而局部判别器则关注图像的细节特征。

  4. 自正则化感知损失:为了在没有成对数据的情况下保持图像内容的特征,EnlightenGAN引入了自特征保持损失,这有助于在增强过程中保持图像的纹理和结构。

  5. 自正则注意机制:通过利用低照度输入图像的光照信息作为自正则化注意力图,EnlightenGAN能够在不依赖外部监督的情况下,指导学习过程。

  6. 灵活性和适应性:由于其无监督设置,EnlightenGAN可以很容易地适应于增强来自不同领域的现实世界低照度图像。

  7. 实验结果:通过广泛的实验,EnlightenGAN在多种评价指标下的表现均优于现有的方法,包括视觉质量、无参考图像质量评估和人类主观研究。

二、EnlightenGAN网络结构

网络结构 = 生成器(带自注意力机制的U-Net)+ 判别器(全局-局部鉴别器)


三、核心代码实现

1、train.py

python 复制代码
import time

import torch

from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
def get_config(config):
    import yaml
    with open(config, 'r') as stream:
        return yaml.safe_load(stream)

def main():
    opt = TrainOptions().parse()
    opt.mode = 'train'
    config = get_config(opt.config)
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    visualizer = Visualizer(opt)

    total_steps = 0

    for epoch in range(1, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter = total_steps - dataset_size * (epoch - 1)
            model.set_input(data)
            model.optimize_parameters(epoch)

            if total_steps % opt.display_freq == 0:
                visualizer.display_current_results(model.get_current_visuals(), epoch)

            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors(epoch)
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)

            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')

        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        if opt.new_lr:
            if epoch == opt.niter:
                model.update_learning_rate()
            elif epoch == (opt.niter + 20):
                model.update_learning_rate()
            elif epoch == (opt.niter + 70):
                model.update_learning_rate()
            elif epoch == (opt.niter + 90):
                model.update_learning_rate()
                model.update_learning_rate()
                model.update_learning_rate()
                model.update_learning_rate()
        else:
            if epoch > opt.niter:
                model.update_learning_rate()

if __name__ == '__main__':
    main()

2、predect.py

python 复制代码
import time
import os

import torch

from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from pdb import set_trace as st
from util import html

def main():
    opt = TestOptions().parse()
    opt.nThreads = 1   # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    if(len(dataset)==0):
        raise ValueError("Dataset is empty. Please check your data path and data loader configuration.")
    model = create_model(opt)
    visualizer = Visualizer(opt)
    # create website
    web_dir = os.path.join("./ablation/", opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
    # test
    print(len(dataset))
    with torch.no_grad():
        for i, data in enumerate(dataset):
            model.set_input(data)
            visuals = model.predict()
            img_path = model.get_image_paths()
            print('process image... %s' % img_path)
            visualizer.save_images(webpage, visuals, img_path)
        webpage.save()

if __name__ == '__main__':
    main()

四、推理结果

可见效果还是不错的。

五、调试好的源码

由于直接下载GitHub上的源码需要进行调试才可以运行,调试过程比较麻烦,我这边提供给大家一个我调试好的源码,方便新手进行训练、测试。

大家扫码关注公众号,回复关键字EnlightenGAN源码即可获取。

相关推荐
Watermelo6176 分钟前
通过MongoDB Atlas 实现语义搜索与 RAG——迈向AI的搜索机制
人工智能·深度学习·神经网络·mongodb·机器学习·自然语言处理·数据挖掘
AI算法-图哥18 分钟前
pytorch量化训练
人工智能·pytorch·深度学习·文生图·模型压缩·量化
大山同学20 分钟前
DPGO:异步和并行分布式位姿图优化 2020 RA-L best paper
人工智能·分布式·语言模型·去中心化·slam·感知定位
机器学习之心21 分钟前
时序预测 | 改进图卷积+informer时间序列预测,pytorch架构
人工智能·pytorch·python·时间序列预测·informer·改进图卷积
天飓1 小时前
基于OpenCV的自制Python访客识别程序
人工智能·python·opencv
檀越剑指大厂1 小时前
开源AI大模型工作流神器Flowise本地部署与远程访问
人工智能·开源
声网1 小时前
「人眼视觉不再是视频消费的唯一形式」丨智能编解码和 AI 视频生成专场回顾@RTE2024
人工智能·音视频
newxtc1 小时前
【AiPPT-注册/登录安全分析报告-无验证方式导致安全隐患】
人工智能·安全·ai写作·极验·行为验证
技术仔QAQ1 小时前
【tokenization分词】WordPiece, Byte-Pair Encoding(BPE), Byte-level BPE(BBPE)的原理和代码
人工智能·python·gpt·语言模型·自然语言处理·开源·nlp
神一样的老师1 小时前
去中心化联邦学习与TinyML联合调查:群学习简介
机器学习