低照度图像增强网络——EnlightenGAN

系列文章目录

GAN生成对抗网络介绍https://blog.csdn.net/m0_58941767/article/details/142704354?spm=1001.2014.3001.5501

循环生成对抗网络------CycleGANhttps://blog.csdn.net/m0_58941767/article/details/142704671?spm=1001.2014.3001.5501


目录

系列文章目录

前言

一、EnlightenGAN的主要特点

二、EnlightenGAN网络结构

三、核心代码实现

1、train.py

2、predect.py

四、推理结果

五、调试好的源码


前言

EnlightenGAN是一种用于低照度图像增强的无监督生成对抗网络。它能够在没有成对训练数据的情况下,通过利用输入图像本身的信息来进行自我正则化,从而实现图像的增强。这种方法特别适用于那些难以获取大量成对低照度和正常光照图像的场景。


**一、**EnlightenGAN的主要特点

  1. 无监督训练:EnlightenGAN不需要成对的低照度和正常光照图像来训练,这使得它能够更容易地适应真实世界的图像增强任务。

  2. 生成器结构:它采用了一个带有自注意力机制的U-Net生成器,这种结构有助于增强图像的局部细节,同时保持整体的光照平衡。

  3. 双判别器结构:EnlightenGAN使用了全局和局部判别器来平衡图像的全局和局部增强。全局判别器关注整体光照差异,而局部判别器则关注图像的细节特征。

  4. 自正则化感知损失:为了在没有成对数据的情况下保持图像内容的特征,EnlightenGAN引入了自特征保持损失,这有助于在增强过程中保持图像的纹理和结构。

  5. 自正则注意机制:通过利用低照度输入图像的光照信息作为自正则化注意力图,EnlightenGAN能够在不依赖外部监督的情况下,指导学习过程。

  6. 灵活性和适应性:由于其无监督设置,EnlightenGAN可以很容易地适应于增强来自不同领域的现实世界低照度图像。

  7. 实验结果:通过广泛的实验,EnlightenGAN在多种评价指标下的表现均优于现有的方法,包括视觉质量、无参考图像质量评估和人类主观研究。

二、EnlightenGAN网络结构

网络结构 = 生成器(带自注意力机制的U-Net)+ 判别器(全局-局部鉴别器)


三、核心代码实现

1、train.py

python 复制代码
import time

import torch

from options.train_options import TrainOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from PIL import ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True
def get_config(config):
    import yaml
    with open(config, 'r') as stream:
        return yaml.safe_load(stream)

def main():
    opt = TrainOptions().parse()
    opt.mode = 'train'
    config = get_config(opt.config)
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    dataset_size = len(data_loader)
    print('#training images = %d' % dataset_size)

    model = create_model(opt)
    visualizer = Visualizer(opt)

    total_steps = 0

    for epoch in range(1, opt.niter + opt.niter_decay + 1):
        epoch_start_time = time.time()
        for i, data in enumerate(dataset):
            iter_start_time = time.time()
            total_steps += opt.batchSize
            epoch_iter = total_steps - dataset_size * (epoch - 1)
            model.set_input(data)
            model.optimize_parameters(epoch)

            if total_steps % opt.display_freq == 0:
                visualizer.display_current_results(model.get_current_visuals(), epoch)

            if total_steps % opt.print_freq == 0:
                errors = model.get_current_errors(epoch)
                t = (time.time() - iter_start_time) / opt.batchSize
                visualizer.print_current_errors(epoch, epoch_iter, errors, t)
                if opt.display_id > 0:
                    visualizer.plot_current_errors(epoch, float(epoch_iter)/dataset_size, opt, errors)

            if total_steps % opt.save_latest_freq == 0:
                print('saving the latest model (epoch %d, total_steps %d)' %
                      (epoch, total_steps))
                model.save('latest')

        if epoch % opt.save_epoch_freq == 0:
            print('saving the model at the end of epoch %d, iters %d' %
                  (epoch, total_steps))
            model.save('latest')
            model.save(epoch)

        print('End of epoch %d / %d \t Time Taken: %d sec' %
              (epoch, opt.niter + opt.niter_decay, time.time() - epoch_start_time))

        if opt.new_lr:
            if epoch == opt.niter:
                model.update_learning_rate()
            elif epoch == (opt.niter + 20):
                model.update_learning_rate()
            elif epoch == (opt.niter + 70):
                model.update_learning_rate()
            elif epoch == (opt.niter + 90):
                model.update_learning_rate()
                model.update_learning_rate()
                model.update_learning_rate()
                model.update_learning_rate()
        else:
            if epoch > opt.niter:
                model.update_learning_rate()

if __name__ == '__main__':
    main()

2、predect.py

python 复制代码
import time
import os

import torch

from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
from util.visualizer import Visualizer
from pdb import set_trace as st
from util import html

def main():
    opt = TestOptions().parse()
    opt.nThreads = 1   # test code only supports nThreads = 1
    opt.batchSize = 1  # test code only supports batchSize = 1
    opt.serial_batches = True  # no shuffle
    opt.no_flip = True  # no flip

    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()
    if(len(dataset)==0):
        raise ValueError("Dataset is empty. Please check your data path and data loader configuration.")
    model = create_model(opt)
    visualizer = Visualizer(opt)
    # create website
    web_dir = os.path.join("./ablation/", opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
    webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch))
    # test
    print(len(dataset))
    with torch.no_grad():
        for i, data in enumerate(dataset):
            model.set_input(data)
            visuals = model.predict()
            img_path = model.get_image_paths()
            print('process image... %s' % img_path)
            visualizer.save_images(webpage, visuals, img_path)
        webpage.save()

if __name__ == '__main__':
    main()

四、推理结果

可见效果还是不错的。

五、调试好的源码

由于直接下载GitHub上的源码需要进行调试才可以运行,调试过程比较麻烦,我这边提供给大家一个我调试好的源码,方便新手进行训练、测试。

大家扫码关注公众号,回复关键字EnlightenGAN源码即可获取。

相关推荐
喜欢吃豆20 分钟前
llama.cpp 全方位技术指南:从底层原理到实战部署
人工智能·语言模型·大模型·llama·量化·llama.cpp
e6zzseo1 小时前
独立站的优势和劣势和运营技巧
大数据·人工智能
富唯智能2 小时前
移动+协作+视觉:开箱即用的下一代复合机器人如何重塑智能工厂
人工智能·工业机器人·复合机器人
Antonio9153 小时前
【图像处理】图像的基础几何变换
图像处理·人工智能·计算机视觉
新加坡内哥谈技术4 小时前
Perplexity AI 的 RAG 架构全解析:幕后技术详解
人工智能
大大dxy大大4 小时前
机器学习实现逻辑回归-癌症分类预测
机器学习·分类·逻辑回归
武子康4 小时前
AI研究-119 DeepSeek-OCR PyTorch FlashAttn 2.7.3 推理与部署 模型规模与资源详细分析
人工智能·深度学习·机器学习·ai·ocr·deepseek·deepseek-ocr
Sirius Wu5 小时前
深入浅出:Tongyi DeepResearch技术解读
人工智能·语言模型·langchain·aigc
忙碌5446 小时前
AI大模型时代下的全栈技术架构:从深度学习到云原生部署实战
人工智能·深度学习·架构
LZ_Keep_Running6 小时前
智能变电巡检:AI检测新突破
人工智能