pytorch实现DCP暗通道先验去雾算法及其onnx导出

pytorch实现DCP暗通道先验去雾算法及其onnx导出

简介

最近在做图像去雾,于是在Pytorch上复现了一下dcp算法。暗通道先验去雾算法是大神何恺明2009年发表在CVPR上的一篇论文,还获得了当年的CVPR最佳论文。

实现

具体原理就不阐述了,网上的解析多的是,这里直接把用pytorch复现的代码贴出来:

python 复制代码
import torch


def dcp(img, omega=0.75):
    h, w = img.shape[2:]
    imsz = h * w
    # 要查找的是暗通道中前0.1%的值
    numpx = torch.clamp_min(imsz // 1000, 1)
    # 找到暗通道的索引,弄成[batch, 3, numpx],因为要匹配三个通道,所以需要expand
    dark = torch.min(img, dim=1, keepdim=True)[0]
    indices = torch.topk(dark.view(-1, imsz), k=numpx, dim=1)[1].view(-1, 1, numpx).expand(-1, 3, -1)
    # 用上述索引匹配原图中的3个通道,并求其平均值
    a = (torch.gather(img.view(-1, 3, imsz), 2, indices).sum(2) / numpx).view(-1, 3, 1, 1)
    
    # 代公式算tx
    tx =  1 - omega * torch.min(img / a.view(-1, 3, 1, 1), dim=1, keepdim=True)[0]
    # 代公式算jx
    return (img - a) / torch.clamp_min(tx, 0.1) + a

函数有两个参数:

  1. img:经归一化后的(N,C,H,W)布局的图像
  2. omega:DCP算法的一个参数ω,数值越大效果越强

如果想在模型训练时引入dcp算法,可以用nn.Module封装一下:

python 复制代码
class DCP(torch.nn.Module):
    def __init__(self, omega):
        self._omega = omega

    def forward(self, x):
        return dcp(x, self._omega)

ONNX导出

导出

既然能封装成Module,那么就顺便试了一下导出ONNX。

导出onnx需要安装onnx和onnxsim:

bash 复制代码
pip install onnx onnxsim

导出代码如下:

python 复制代码
import torch
import onnx
from onnxsim import simplify 


def dcp(img, omega=0.75):
    h, w = img.shape[2:]
    imsz = h * w
    # 要查找的是暗通道中前0.1%的值
    numpx = torch.clamp_min(imsz // 1000, 1)
    # 找到暗通道的索引,弄成[batch, 3, numpx],因为要匹配三个通道,所以需要expand
    dark = torch.min(img, dim=1, keepdim=True)[0]
    indices = torch.topk(dark.view(-1, imsz), k=numpx, dim=1)[1].view(-1, 1, numpx).expand(-1, 3, -1)
    # 用上述索引匹配原图中的3个通道,并求其平均值
    a = (torch.gather(img.view(-1, 3, imsz), 2, indices).sum(2) / numpx).view(-1, 3, 1, 1)
    
    # 代公式算tx
    tx =  1 - omega * torch.min(img / a.view(-1, 3, 1, 1), dim=1, keepdim=True)[0]
    # 代公式算jx
    return (img - a) / torch.clamp_min(tx, 0.1) + a

class DCPExport(torch.nn.Module):
    def forward(self, x, omega):
        return dcp(x, omega)

def export(output='dcp.onnx'):
    torch.onnx.export(
        DCPExport(), 
        (torch.randn(1, 3, 255, 255, dtype=torch.float32), torch.tensor(0.75, dtype=torch.float32)), 
        'dcp.onnx', 
        input_names=['fog_image', 'omega'], 
        output_names=['clear_image'], 
        dynamic_axes={
            'fog_image': {0: 'batch', 2: 'height', 3: 'width'},
            'clear_image': {0: 'batch', 2: 'height', 3: 'width'},
        }
    )
    onnx_model = onnx.load(output) 
    model_simp, check = simplify(onnx_model) 
    assert check, "简化模型失败" 
    onnx.save(model_simp, output) 

if __name__ == '__main__':
	export()

导出结果如下:

导出后的onnx输入输出如下:

  • 输入:
    1. fog_image[float32]:形状为NCHW,且归一化的有雾图像,其中通道数C必须为3
    2. omega[float32]:dcp的参数,类型为浮点数
  • 输出:
    1. clear_image[float32]:形状为NCHW,且归一化的无雾图像,其中通道数C为3

下载链接:https://pan.baidu.com/s/1A1jSJQBFCGTeM8vbHOrysQ?pwd=tl6p

测试

用cv2和pil都可以:

python 复制代码
import numpy as np
import cv2
from PIL import Image
from onnxruntime import InferenceSession


model = InferenceSession('dcp.onnx')

# CV2读图
image = cv2.imread('dehaze/dehaze/input/images/indoor1.jpg')
# 这里说明一下,因为dcp对所有通道进行同等变换,所以不用bgr和rgb互转了,出来的结果都是一样的
# x = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
x = np.transpose(image, (2, 0, 1))[None].astype(np.float32) / 255.
res = model.run(['clear_image'], {'fog_image': x, 'omega': np.array(0.75, dtype=np.float32)})[0][0]
res = np.transpose(res, (1, 2, 0))
res = np.clip(res*255+0.5, 0, 255).astype(np.uint8)
# res = cv2.cvtColor(res, cv2.COLOR_RGB2BGR)
cv2.imwrite('onnx-cv.png', np.concatenate((image, res), 1))

# PIL读图
image = Image.open('dehaze/dehaze/input/images/indoor1.jpg')
x = np.transpose(image, (2, 0, 1))[None].astype(np.float32) / 255.
res = model.run(None, {'fog_image': x, 'omega': np.array(0.75, dtype=np.float32)})[0][0]
res = np.transpose(res, (1, 2, 0))
res = np.clip(res*255+0.5, 0, 255).astype(np.uint8)
Image.fromarray(np.concatenate((image, res), 1)).save('onnx-pil.png')

效果:

相关推荐
zm-v-159304339862 分钟前
ChatGPT 与 DeepSeek:学术科研的智能 “双引擎”
人工智能·chatgpt
果冻人工智能4 分钟前
美国狂奔,中国稳走,AI赛道上的龟兔之争?
人工智能
牙牙要健康5 分钟前
【目标检测】【深度学习】【Pytorch版本】YOLOV2模型算法详解
pytorch·深度学习·目标检测
果冻人工智能14 分钟前
再谈AI与程序员: AI 写的代码越来越多,那我们还需要开发者吗?
人工智能
大脑探路者18 分钟前
【PyTorch】继承 nn.Module 创建简单神经网络
人工智能·pytorch·神经网络
梭七y21 分钟前
【力扣hot100题】(033)合并K个升序链表
算法·leetcode·链表
月亮被咬碎成星星26 分钟前
LeetCode[383]赎金信
算法·leetcode
无代码Dev39 分钟前
如何使用AI去水印(ChatGPT去除图片水印)
人工智能·ai·ai-native
无难事者若执1 小时前
新手村:逻辑回归-理解03:逻辑回归中的最大似然函数
算法·机器学习·逻辑回归
达柳斯·绍达华·宁1 小时前
自动驾驶04:点云预处理03
人工智能·机器学习·自动驾驶