低照度图像增强网络——EnlightenGAN

系列文章目录

GAN生成对抗网络介绍https://blog.csdn.net/m0_58941767/article/details/142704354?spm=1001.2014.3001.5501

循环生成对抗网络------CycleGANhttps://blog.csdn.net/m0_58941767/article/details/142704671?spm=1001.2014.3001.5501


目录

系列文章目录

前言

一、EnlightenGAN的主要特点

二、EnlightenGAN网络结构

三、核心代码实现

1、train.py

2、predect.py

四、推理结果

五、调试好的源码


前言

EnlightenGAN是一种用于低照度图像增强的无监督生成对抗网络。它能够在没有成对训练数据的情况下,通过利用输入图像本身的信息来进行自我正则化,从而实现图像的增强。这种方法特别适用于那些难以获取大量成对低照度和正常光照图像的场景。


**一、**EnlightenGAN的主要特点

  1. 无监督训练:EnlightenGAN不需要成对的低照度和正常光照图像来训练,这使得它能够更容易地适应真实世界的图像增强任务。

  2. 生成器结构:它采用了一个带有自注意力机制的U-Net生成器,这种结构有助于增强图像的局部细节,同时保持整体的光照平衡。

  3. 双判别器结构:EnlightenGAN使用了全局和局部判别器来平衡图像的全局和局部增强。全局判别器关注整体光照差异,而局部判别器则关注图像的细节特征。

  4. 自正则化感知损失:为了在没有成对数据的情况下保持图像内容的特征,EnlightenGAN引入了自特征保持损失,这有助于在增强过程中保持图像的纹理和结构。

  5. 自正则注意机制:通过利用低照度输入图像的光照信息作为自正则化注意力图,EnlightenGAN能够在不依赖外部监督的情况下,指导学习过程。

  6. 灵活性和适应性:由于其无监督设置,EnlightenGAN可以很容易地适应于增强来自不同领域的现实世界低照度图像。

  7. 实验结果:通过广泛的实验,EnlightenGAN在多种评价指标下的表现均优于现有的方法,包括视觉质量、无参考图像质量评估和人类主观研究。

二、EnlightenGAN网络结构

网络结构 = 生成器(带自注意力机制的U-Net)+ 判别器(全局-局部鉴别器)


三、核心代码实现

1、train.py

python 复制代码
import time

import torch

from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
def get_config(config):
    import yaml
    with open(config, 'r') as stream:
        return yaml.safe_load(stream)

def main():
    opt = TrainOptions().parse()
    opt.mode = 'train'
    config = get_config(opt.config)
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    visualizer = Visualizer(opt)

    total_steps = 0

    for epoch in range(1, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter = total_steps - dataset_size * (epoch - 1)
            model.set_input(data)
            model.optimize_parameters(epoch)

            if total_steps % opt.display_freq == 0:
                visualizer.display_current_results(model.get_current_visuals(), epoch)

            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors(epoch)
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)

            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')

        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        if opt.new_lr:
            if epoch == opt.niter:
                model.update_learning_rate()
            elif epoch == (opt.niter + 20):
                model.update_learning_rate()
            elif epoch == (opt.niter + 70):
                model.update_learning_rate()
            elif epoch == (opt.niter + 90):
                model.update_learning_rate()
                model.update_learning_rate()
                model.update_learning_rate()
                model.update_learning_rate()
        else:
            if epoch > opt.niter:
                model.update_learning_rate()

if __name__ == '__main__':
    main()

2、predect.py

python 复制代码
import time
import os

import torch

from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from pdb import set_trace as st
from util import html

def main():
    opt = TestOptions().parse()
    opt.nThreads = 1   # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    if(len(dataset)==0):
        raise ValueError("Dataset is empty. Please check your data path and data loader configuration.")
    model = create_model(opt)
    visualizer = Visualizer(opt)
    # create website
    web_dir = os.path.join("./ablation/", opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
    # test
    print(len(dataset))
    with torch.no_grad():
        for i, data in enumerate(dataset):
            model.set_input(data)
            visuals = model.predict()
            img_path = model.get_image_paths()
            print('process image... %s' % img_path)
            visualizer.save_images(webpage, visuals, img_path)
        webpage.save()

if __name__ == '__main__':
    main()

四、推理结果

可见效果还是不错的。

五、调试好的源码

由于直接下载GitHub上的源码需要进行调试才可以运行,调试过程比较麻烦,我这边提供给大家一个我调试好的源码,方便新手进行训练、测试。

大家扫码关注公众号,回复关键字EnlightenGAN源码即可获取。

相关推荐
Light601 分钟前
智启未来:深度解析Python Transformers库及其应用场景
开发语言·python·深度学习·自然语言处理·预训练模型·transformers库 |·|应用场景
爱的叹息4 分钟前
DeepSeek 大模型 + LlamaIndex + MySQL 数据库 + 知识文档 实现简单 RAG 系统
数据库·人工智能·mysql·langchain
数据智能老司机11 分钟前
构建具备自主性的人工智能系统——在生成式人工智能系统中构建信任
深度学习·llm·aigc
PeterOne14 分钟前
Trae MCP + Obsidian 集成如何缓解开发者的时间损耗
人工智能·trae
sduwcgg44 分钟前
kaggle配置
人工智能·python·机器学习
DolphinScheduler社区1 小时前
白鲸开源与亚马逊云科技携手推动AI-Ready数据架构创新
人工智能·科技·开源·aws·白鲸开源·whalestudio
欣然~1 小时前
借助 OpenCV 和 PyTorch 库,利用卷积神经网络提取图像边缘特征
人工智能·计算机视觉
谦行1 小时前
工欲善其事,必先利其器—— PyTorch 深度学习基础操作
pytorch·深度学习·ai编程
xwz小王子2 小时前
Nature Communications 面向形状可编程磁性软材料的数据驱动设计方法—基于随机设计探索与神经网络的协同优化框架
深度学习
白熊1882 小时前
【计算机视觉】CV实战项目 - 基于YOLOv5的人脸检测与关键点定位系统深度解析
人工智能·yolo·计算机视觉