pytorch实现DCP暗通道先验去雾算法及其onnx导出

pytorch实现DCP暗通道先验去雾算法及其onnx导出

简介

最近在做图像去雾,于是在Pytorch上复现了一下dcp算法。暗通道先验去雾算法是大神何恺明2009年发表在CVPR上的一篇论文,还获得了当年的CVPR最佳论文。

实现

具体原理就不阐述了,网上的解析多的是,这里直接把用pytorch复现的代码贴出来:

python 复制代码
import torch


def dcp(img, omega=0.75):
    h, w = img.shape[2:]
    imsz = h * w
    # 要查找的是暗通道中前0.1%的值
    numpx = torch.clamp_min(imsz // 1000, 1)
    # 找到暗通道的索引,弄成[batch, 3, numpx],因为要匹配三个通道,所以需要expand
    dark = torch.min(img, dim=1, keepdim=True)[0]
    indices = torch.topk(dark.view(-1, imsz), k=numpx, dim=1)[1].view(-1, 1, numpx).expand(-1, 3, -1)
    # 用上述索引匹配原图中的3个通道,并求其平均值
    a = (torch.gather(img.view(-1, 3, imsz), 2, indices).sum(2) / numpx).view(-1, 3, 1, 1)
    
    # 代公式算tx
    tx =  1 - omega * torch.min(img / a.view(-1, 3, 1, 1), dim=1, keepdim=True)[0]
    # 代公式算jx
    return (img - a) / torch.clamp_min(tx, 0.1) + a

函数有两个参数:

  1. img:经归一化后的(N,C,H,W)布局的图像
  2. omega:DCP算法的一个参数ω,数值越大效果越强

如果想在模型训练时引入dcp算法,可以用nn.Module封装一下:

python 复制代码
class DCP(torch.nn.Module):
    def __init__(self, omega):
        self._omega = omega

    def forward(self, x):
        return dcp(x, self._omega)

ONNX导出

导出

既然能封装成Module,那么就顺便试了一下导出ONNX。

导出onnx需要安装onnx和onnxsim:

bash 复制代码
pip install onnx onnxsim

导出代码如下:

python 复制代码
import torch
import onnx
from onnxsim import simplify 


def dcp(img, omega=0.75):
    h, w = img.shape[2:]
    imsz = h * w
    # 要查找的是暗通道中前0.1%的值
    numpx = torch.clamp_min(imsz // 1000, 1)
    # 找到暗通道的索引,弄成[batch, 3, numpx],因为要匹配三个通道,所以需要expand
    dark = torch.min(img, dim=1, keepdim=True)[0]
    indices = torch.topk(dark.view(-1, imsz), k=numpx, dim=1)[1].view(-1, 1, numpx).expand(-1, 3, -1)
    # 用上述索引匹配原图中的3个通道,并求其平均值
    a = (torch.gather(img.view(-1, 3, imsz), 2, indices).sum(2) / numpx).view(-1, 3, 1, 1)
    
    # 代公式算tx
    tx =  1 - omega * torch.min(img / a.view(-1, 3, 1, 1), dim=1, keepdim=True)[0]
    # 代公式算jx
    return (img - a) / torch.clamp_min(tx, 0.1) + a

class DCPExport(torch.nn.Module):
    def forward(self, x, omega):
        return dcp(x, omega)

def export(output='dcp.onnx'):
    torch.onnx.export(
        DCPExport(), 
        (torch.randn(1, 3, 255, 255, dtype=torch.float32), torch.tensor(0.75, dtype=torch.float32)), 
        'dcp.onnx', 
        input_names=['fog_image', 'omega'], 
        output_names=['clear_image'], 
        dynamic_axes={
            'fog_image': {0: 'batch', 2: 'height', 3: 'width'},
            'clear_image': {0: 'batch', 2: 'height', 3: 'width'},
        }
    )
    onnx_model = onnx.load(output) 
    model_simp, check = simplify(onnx_model) 
    assert check, "简化模型失败" 
    onnx.save(model_simp, output) 

if __name__ == '__main__':
	export()

导出结果如下:

导出后的onnx输入输出如下:

  • 输入:
    1. fog_image[float32]:形状为NCHW,且归一化的有雾图像,其中通道数C必须为3
    2. omega[float32]:dcp的参数,类型为浮点数
  • 输出:
    1. clear_image[float32]:形状为NCHW,且归一化的无雾图像,其中通道数C为3

下载链接:https://pan.baidu.com/s/1A1jSJQBFCGTeM8vbHOrysQ?pwd=tl6p

测试

用cv2和pil都可以:

python 复制代码
import numpy as np
import cv2
from PIL import Image
from onnxruntime import InferenceSession


model = InferenceSession('dcp.onnx')

# CV2读图
image = cv2.imread('dehaze/dehaze/input/images/indoor1.jpg')
# 这里说明一下,因为dcp对所有通道进行同等变换,所以不用bgr和rgb互转了,出来的结果都是一样的
# x = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
x = np.transpose(image, (2, 0, 1))[None].astype(np.float32) / 255.
res = model.run(['clear_image'], {'fog_image': x, 'omega': np.array(0.75, dtype=np.float32)})[0][0]
res = np.transpose(res, (1, 2, 0))
res = np.clip(res*255+0.5, 0, 255).astype(np.uint8)
# res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)
cv2.imwrite('onnx-cv.png', np.concatenate((image, res), 1))

# PIL读图
image = Image.open('dehaze/dehaze/input/images/indoor1.jpg')
x = np.transpose(image, (2, 0, 1))[None].astype(np.float32) / 255.
res = model.run(None, {'fog_image': x, 'omega': np.array(0.75, dtype=np.float32)})[0][0]
res = np.transpose(res, (1, 2, 0))
res = np.clip(res*255+0.5, 0, 255).astype(np.uint8)
Image.fromarray(np.concatenate((image, res), 1)).save('onnx-pil.png')

效果:

相关推荐
鹅毛在路上了3 分钟前
昇思25天学习打卡营第5天|GAN图像生成
人工智能·生成对抗网络·mindspore
傻啦嘿哟4 分钟前
为什么写Python脚本时要加上if __name__ == ‘__main__‘?
开发语言·python
硅纪元4 分钟前
硅纪元视角 | AI纳米机器人突破癌症治疗,精准打击肿瘤细胞
大数据·人工智能·机器人
bigbigli_大李8 分钟前
C++基础21 二维数组及相关问题详解
数据结构·c++·算法
vosokcc@yuyinjiqiren11 分钟前
ai智能语音机器人电销系统:让销售更快速高效
大数据·服务器·网络·人工智能·机器人
FL162386312916 分钟前
[数据集][目标检测]睡岗检测数据集VOC+YOLO格式3290张4类别
人工智能·yolo·目标检测
今日信息差30 分钟前
7月04日,每日信息差
大数据·人工智能·科技·阿里云·云计算
O zil38 分钟前
资料分析题目类型分类
人工智能·分类·数据挖掘
华为云PaaS服务小智1 小时前
HDC Cloud 2024 | CodeArts加速软件智能化开发,携手HarmonyOS重塑企业应用创新体验
人工智能·华为·harmonyos
子龙烜1 小时前
数据分析三剑客-Matplotlib
python·数据挖掘·数据分析·matplotlib