低照度图像增强网络——EnlightenGAN

系列文章目录

GAN生成对抗网络介绍https://blog.csdn.net/m0_58941767/article/details/142704354?spm=1001.2014.3001.5501

循环生成对抗网络------CycleGANhttps://blog.csdn.net/m0_58941767/article/details/142704671?spm=1001.2014.3001.5501


目录

系列文章目录

前言

一、EnlightenGAN的主要特点

二、EnlightenGAN网络结构

三、核心代码实现

1、train.py

2、predect.py

四、推理结果

五、调试好的源码


前言

EnlightenGAN是一种用于低照度图像增强的无监督生成对抗网络。它能够在没有成对训练数据的情况下,通过利用输入图像本身的信息来进行自我正则化,从而实现图像的增强。这种方法特别适用于那些难以获取大量成对低照度和正常光照图像的场景。


**一、**EnlightenGAN的主要特点

  1. 无监督训练:EnlightenGAN不需要成对的低照度和正常光照图像来训练,这使得它能够更容易地适应真实世界的图像增强任务。

  2. 生成器结构:它采用了一个带有自注意力机制的U-Net生成器,这种结构有助于增强图像的局部细节,同时保持整体的光照平衡。

  3. 双判别器结构:EnlightenGAN使用了全局和局部判别器来平衡图像的全局和局部增强。全局判别器关注整体光照差异,而局部判别器则关注图像的细节特征。

  4. 自正则化感知损失:为了在没有成对数据的情况下保持图像内容的特征,EnlightenGAN引入了自特征保持损失,这有助于在增强过程中保持图像的纹理和结构。

  5. 自正则注意机制:通过利用低照度输入图像的光照信息作为自正则化注意力图,EnlightenGAN能够在不依赖外部监督的情况下,指导学习过程。

  6. 灵活性和适应性:由于其无监督设置,EnlightenGAN可以很容易地适应于增强来自不同领域的现实世界低照度图像。

  7. 实验结果:通过广泛的实验,EnlightenGAN在多种评价指标下的表现均优于现有的方法,包括视觉质量、无参考图像质量评估和人类主观研究。

二、EnlightenGAN网络结构

网络结构 = 生成器(带自注意力机制的U-Net)+ 判别器(全局-局部鉴别器)


三、核心代码实现

1、train.py

python 复制代码
import time

import torch

from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
def get_config(config):
    import yaml
    with open(config, 'r') as stream:
        return yaml.safe_load(stream)

def main():
    opt = TrainOptions().parse()
    opt.mode = 'train'
    config = get_config(opt.config)
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    visualizer = Visualizer(opt)

    total_steps = 0

    for epoch in range(1, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter = total_steps - dataset_size * (epoch - 1)
            model.set_input(data)
            model.optimize_parameters(epoch)

            if total_steps % opt.display_freq == 0:
                visualizer.display_current_results(model.get_current_visuals(), epoch)

            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors(epoch)
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)

            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')

        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        if opt.new_lr:
            if epoch == opt.niter:
                model.update_learning_rate()
            elif epoch == (opt.niter + 20):
                model.update_learning_rate()
            elif epoch == (opt.niter + 70):
                model.update_learning_rate()
            elif epoch == (opt.niter + 90):
                model.update_learning_rate()
                model.update_learning_rate()
                model.update_learning_rate()
                model.update_learning_rate()
        else:
            if epoch > opt.niter:
                model.update_learning_rate()

if __name__ == '__main__':
    main()

2、predect.py

python 复制代码
import time
import os

import torch

from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from pdb import set_trace as st
from util import html

def main():
    opt = TestOptions().parse()
    opt.nThreads = 1   # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    if(len(dataset)==0):
        raise ValueError("Dataset is empty. Please check your data path and data loader configuration.")
    model = create_model(opt)
    visualizer = Visualizer(opt)
    # create website
    web_dir = os.path.join("./ablation/", opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
    # test
    print(len(dataset))
    with torch.no_grad():
        for i, data in enumerate(dataset):
            model.set_input(data)
            visuals = model.predict()
            img_path = model.get_image_paths()
            print('process image... %s' % img_path)
            visualizer.save_images(webpage, visuals, img_path)
        webpage.save()

if __name__ == '__main__':
    main()

四、推理结果

可见效果还是不错的。

五、调试好的源码

由于直接下载GitHub上的源码需要进行调试才可以运行,调试过程比较麻烦,我这边提供给大家一个我调试好的源码,方便新手进行训练、测试。

大家扫码关注公众号,回复关键字EnlightenGAN源码即可获取。

相关推荐
机器之心23 分钟前
VAE时代终结?谢赛宁团队「RAE」登场,表征自编码器或成DiT训练新基石
人工智能·openai
机器之心24 分钟前
Sutton判定「LLM是死胡同」后,新访谈揭示AI困境
人工智能·openai
大模型真好玩28 分钟前
低代码Agent开发框架使用指南(四)—Coze大模型和插件参数配置最佳实践
人工智能·agent·coze
jerryinwuhan28 分钟前
基于大语言模型(LLM)的城市时间、空间与情感交织分析:面向智能城市的情感动态预测与空间优化
人工智能·语言模型·自然语言处理
落雪财神意42 分钟前
股指10月想法
大数据·人工智能·金融·区块链·期股
中杯可乐多加冰42 分钟前
无代码开发实践|基于业务流能力快速开发市场监管系统,实现投诉处理快速响应
人工智能·低代码
渣渣盟44 分钟前
解密NLP:从入门到精通
人工智能·python·nlp
新智元1 小时前
万亿级思考模型,蚂蚁首次开源!20 万亿 token 搅局开源 AI
人工智能·openai
算家计算1 小时前
阿里开源最强视觉模型家族轻量版:仅4B/8B参数,性能逼近72B旗舰版
人工智能·开源·资讯
MarkHD1 小时前
Dify从入门到精通 第16天 工作流进阶 - 分支与判断:构建智能路由客服机器人
人工智能·机器人