pytorch实现DCP暗通道先验去雾算法及其onnx导出

pytorch实现DCP暗通道先验去雾算法及其onnx导出

简介

最近在做图像去雾,于是在Pytorch上复现了一下dcp算法。暗通道先验去雾算法是大神何恺明2009年发表在CVPR上的一篇论文,还获得了当年的CVPR最佳论文。

实现

具体原理就不阐述了,网上的解析多的是,这里直接把用pytorch复现的代码贴出来:

python 复制代码
import torch


def dcp(img, omega=0.75):
    h, w = img.shape[2:]
    imsz = h * w
    # 要查找的是暗通道中前0.1%的值
    numpx = torch.clamp_min(imsz // 1000, 1)
    # 找到暗通道的索引,弄成[batch, 3, numpx],因为要匹配三个通道,所以需要expand
    dark = torch.min(img, dim=1, keepdim=True)[0]
    indices = torch.topk(dark.view(-1, imsz), k=numpx, dim=1)[1].view(-1, 1, numpx).expand(-1, 3, -1)
    # 用上述索引匹配原图中的3个通道,并求其平均值
    a = (torch.gather(img.view(-1, 3, imsz), 2, indices).sum(2) / numpx).view(-1, 3, 1, 1)
    
    # 代公式算tx
    tx =  1 - omega * torch.min(img / a.view(-1, 3, 1, 1), dim=1, keepdim=True)[0]
    # 代公式算jx
    return (img - a) / torch.clamp_min(tx, 0.1) + a

函数有两个参数:

  1. img:经归一化后的(N,C,H,W)布局的图像
  2. omega:DCP算法的一个参数ω,数值越大效果越强

如果想在模型训练时引入dcp算法,可以用nn.Module封装一下:

python 复制代码
class DCP(torch.nn.Module):
    def __init__(self, omega):
        self._omega = omega

    def forward(self, x):
        return dcp(x, self._omega)

ONNX导出

导出

既然能封装成Module,那么就顺便试了一下导出ONNX。

导出onnx需要安装onnx和onnxsim:

bash 复制代码
pip install onnx onnxsim

导出代码如下:

python 复制代码
import torch
import onnx
from onnxsim import simplify 


def dcp(img, omega=0.75):
    h, w = img.shape[2:]
    imsz = h * w
    # 要查找的是暗通道中前0.1%的值
    numpx = torch.clamp_min(imsz // 1000, 1)
    # 找到暗通道的索引,弄成[batch, 3, numpx],因为要匹配三个通道,所以需要expand
    dark = torch.min(img, dim=1, keepdim=True)[0]
    indices = torch.topk(dark.view(-1, imsz), k=numpx, dim=1)[1].view(-1, 1, numpx).expand(-1, 3, -1)
    # 用上述索引匹配原图中的3个通道,并求其平均值
    a = (torch.gather(img.view(-1, 3, imsz), 2, indices).sum(2) / numpx).view(-1, 3, 1, 1)
    
    # 代公式算tx
    tx =  1 - omega * torch.min(img / a.view(-1, 3, 1, 1), dim=1, keepdim=True)[0]
    # 代公式算jx
    return (img - a) / torch.clamp_min(tx, 0.1) + a

class DCPExport(torch.nn.Module):
    def forward(self, x, omega):
        return dcp(x, omega)

def export(output='dcp.onnx'):
    torch.onnx.export(
        DCPExport(), 
        (torch.randn(1, 3, 255, 255, dtype=torch.float32), torch.tensor(0.75, dtype=torch.float32)), 
        'dcp.onnx', 
        input_names=['fog_image', 'omega'], 
        output_names=['clear_image'], 
        dynamic_axes={
            'fog_image': {0: 'batch', 2: 'height', 3: 'width'},
            'clear_image': {0: 'batch', 2: 'height', 3: 'width'},
        }
    )
    onnx_model = onnx.load(output) 
    model_simp, check = simplify(onnx_model) 
    assert check, "简化模型失败" 
    onnx.save(model_simp, output) 

if __name__ == '__main__':
	export()

导出结果如下:

导出后的onnx输入输出如下:

  • 输入:
    1. fog_image[float32]:形状为NCHW,且归一化的有雾图像,其中通道数C必须为3
    2. omega[float32]:dcp的参数,类型为浮点数
  • 输出:
    1. clear_image[float32]:形状为NCHW,且归一化的无雾图像,其中通道数C为3

下载链接:https://pan.baidu.com/s/1A1jSJQBFCGTeM8vbHOrysQ?pwd=tl6p

测试

用cv2和pil都可以:

python 复制代码
import numpy as np
import cv2
from PIL import Image
from onnxruntime import InferenceSession


model = InferenceSession('dcp.onnx')

# CV2读图
image = cv2.imread('dehaze/dehaze/input/images/indoor1.jpg')
# 这里说明一下,因为dcp对所有通道进行同等变换,所以不用bgr和rgb互转了,出来的结果都是一样的
# x = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
x = np.transpose(image, (2, 0, 1))[None].astype(np.float32) / 255.
res = model.run(['clear_image'], {'fog_image': x, 'omega': np.array(0.75, dtype=np.float32)})[0][0]
res = np.transpose(res, (1, 2, 0))
res = np.clip(res*255+0.5, 0, 255).astype(np.uint8)
# res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)
cv2.imwrite('onnx-cv.png', np.concatenate((image, res), 1))

# PIL读图
image = Image.open('dehaze/dehaze/input/images/indoor1.jpg')
x = np.transpose(image, (2, 0, 1))[None].astype(np.float32) / 255.
res = model.run(None, {'fog_image': x, 'omega': np.array(0.75, dtype=np.float32)})[0][0]
res = np.transpose(res, (1, 2, 0))
res = np.clip(res*255+0.5, 0, 255).astype(np.uint8)
Image.fromarray(np.concatenate((image, res), 1)).save('onnx-pil.png')

效果:

相关推荐
sp_fyf_202433 分钟前
【大语言模型】ACL2024论文-19 SportsMetrics: 融合文本和数值数据以理解大型语言模型中的信息融合
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理
CoderIsArt36 分钟前
基于 BP 神经网络整定的 PID 控制
人工智能·深度学习·神经网络
程序猿小柒40 分钟前
leetcode hot100【LeetCode 4.寻找两个正序数组的中位数】java实现
java·算法·leetcode
编程修仙1 小时前
Collections工具类
linux·windows·python
开源社1 小时前
一场开源视角的AI会议即将在南京举办
人工智能·开源
FreeIPCC1 小时前
谈一下开源生态对 AI人工智能大模型的促进作用
大数据·人工智能·机器人·开源
芝麻团坚果1 小时前
对subprocess启动的子进程使用VSCode python debugger
linux·ide·python·subprocess·vscode debugger
机器之心1 小时前
全球十亿级轨迹点驱动,首个轨迹基础大模型来了
人工智能·后端
z千鑫1 小时前
【人工智能】PyTorch、TensorFlow 和 Keras 全面解析与对比:深度学习框架的终极指南
人工智能·pytorch·深度学习·aigc·tensorflow·keras·codemoss
EterNity_TiMe_1 小时前
【论文复现】神经网络的公式推导与代码实现
人工智能·python·深度学习·神经网络·数据分析·特征分析