【DA-CLIP】test.py解读,调用DA-CLIP和IRSDE模型复原计算复原图与GT图SSIM、PSNR、LPIPS

文件路径daclip-uir-main/universal-image-restoration/config/daclip-sde/test.py

代码有部分修改

导包

python 复制代码
import argparse
import logging
import os.path
import sys
import time
from collections import OrderedDict
import torchvision.utils as tvutils

import numpy as np
import torch
from IPython import embed
import lpips

import options as option
from models import create_model

sys.path.insert(0, "../../")
import open_clip
import utils as util
from data import create_dataloader, create_dataset
from data.util import bgr2ycbcr

注意open_clip使用的是项目里的代码,而非环境里装的那个。data、util、option同样是项目里有的包

声明

python 复制代码
#### options
parser = argparse.ArgumentParser()
parser.add_argument("-opt", type=str, default='options/test.yml', help="Path to options YMAL file.")
opt = option.parse(parser.parse_args().opt, is_train=False)

opt = option.dict_to_nonedict(opt)

配置文件

设置配置文件相对地址options/test.yml

在该配置文件中配置GT和LQ图像文件地址

python 复制代码
datasets:
  test1:
   name: Test
   mode: LQGT
   dataroot_GT: C:\Users\86136\Desktop\LQ_test\shadow\GT
   dataroot_LQ: C:\Users\86136\Desktop\LQ_test\shadow\LQ

设置results_root结果地址,每次计算结束这个地址保存要求记录的计算结果

该目录下Test文件夹将保存一张GT一张LQ一张复原图像 。

不设置也会默认在项目内 daclip-uir-main\results\daclip-sde\universal-ir

python 复制代码
#### path
path:
  pretrain_model_G: E:\daclip\pretrained\universal-ir.pth
  daclip: E:\daclip\pretrained\daclip_ViT-B-32.pt
  results_root: C:\Users\86136\Desktop\daclip-uir-main\results\daclip-sde\universal-ir
  log: 
python 复制代码
#### mkdir and logger
util.mkdirs(
    (
        path
        for key, path in opt["path"].items()
        if not key == "experiments_root"
        and "pretrain_model" not in key
        and "resume" not in key
    )
)

# os.system("rm ./result")
# os.symlink(os.path.join(opt["path"]["results_root"], ".."), "./result")

报错执行代码没有删除再创建权限?我把相关os操作注释了,全部保存到result对我影响不大

加载创建数据对

python 复制代码
#### Create test dataset and dataloader
test_loaders = []
for phase, dataset_opt in sorted(opt["datasets"].items()):
    test_set = create_dataset(dataset_opt)
    test_loader = create_dataloader(test_set, dataset_opt)
    logger.info(
        "Number of test images in [{:s}]: {:d}".format(
            dataset_opt["name"], len(test_set)
        )
    )
    test_loaders.append(test_loader)

自定义包含复原IR-SDE模型的外层类model,参考app.py

python 复制代码
# load pretrained model by default
model = create_model(opt)
device = model.device

加载DA-CLIP、IR-SDE

python 复制代码
# clip_model, _preprocess = clip.load("ViT-B/32", device=device)
if opt['path']['daclip'] is not None:
    clip_model, preprocess = open_clip.create_model_from_pretrained('daclip_ViT-B-32', pretrained=opt['path']['daclip'])
else:
    clip_model, _, preprocess = open_clip.create_model_and_transforms('ViT-B-32', pretrained='laion2b_s34b_b79k')
tokenizer = open_clip.get_tokenizer('ViT-B-32')
clip_model = clip_model.to(device)

else是直接使用CLIP的ViT-B-32模型进行测试的代码。与我测DA-CLIP无关。

想使用的话 目测要预先下载对应模型权重并手动修改pretrained为文件地址,否则报错hf无法连接

Scala 复制代码
sde = util.IRSDE(max_sigma=opt["sde"]["max_sigma"], T=opt["sde"]["T"], schedule=opt["sde"]["schedule"], eps=opt["sde"]["eps"], device=device)
sde.set_model(model.model)
lpips_fn = lpips.LPIPS(net='alex').to(device)

scale = opt['degradation']['scale']

加载IR-SDE、LPIPS

如果不指定crop_border后续crop_border=scale

处理并计算

python 复制代码
for test_loader in test_loaders:
    test_set_name = test_loader.dataset.opt["name"]  # path opt['']
    logger.info("\nTesting [{:s}]...".format(test_set_name))
    test_start_time = time.time()
    dataset_dir = os.path.join(opt["path"]["results_root"], test_set_name)
    util.mkdir(dataset_dir)

    test_results = OrderedDict()
    test_results["psnr"] = []
    test_results["ssim"] = []
    test_results["psnr_y"] = []
    test_results["ssim_y"] = []
    test_results["lpips"] = []
    test_times = []

    for i, test_data in enumerate(test_loader):
        single_img_psnr = []
        single_img_ssim = []
        single_img_psnr_y = []
        single_img_ssim_y = []
        need_GT = False if test_loader.dataset.opt["dataroot_GT"] is None else True
        img_path = test_data["GT_path"][0] if need_GT else test_data["LQ_path"][0]
        img_name = os.path.splitext(os.path.basename(img_path))[0]

        #### input dataset_LQ
        LQ, GT = test_data["LQ"], test_data["GT"]
        img4clip = test_data["LQ_clip"].to(device)
        with torch.no_grad(), torch.cuda.amp.autocast():
            image_context, degra_context = clip_model.encode_image(img4clip, control=True)
            image_context = image_context.float()
            degra_context = degra_context.float()

        noisy_state = sde.noise_state(LQ)

        model.feed_data(noisy_state, LQ, GT, text_context=degra_context, image_context=image_context)
        tic = time.time()
        model.test(sde, save_states=False)
        toc = time.time()
        test_times.append(toc - tic)

        visuals = model.get_current_visuals()
        SR_img = visuals["Output"]
        output = util.tensor2img(SR_img.squeeze())  # uint8
        LQ_ = util.tensor2img(visuals["Input"].squeeze())  # uint8
        GT_ = util.tensor2img(visuals["GT"].squeeze())  # uint8
        
        suffix = opt["suffix"]
        if suffix:
            save_img_path = os.path.join(dataset_dir, img_name + suffix + ".png")
        else:
            save_img_path = os.path.join(dataset_dir, img_name + ".png")
        util.save_img(output, save_img_path)

        # remove it if you only want to save output images
        LQ_img_path = os.path.join(dataset_dir, img_name + "_LQ.png")
        GT_img_path = os.path.join(dataset_dir, img_name + "_HQ.png")
        util.save_img(LQ_, LQ_img_path)
        util.save_img(GT_, GT_img_path)

        if need_GT:
            gt_img = GT_ / 255.0
            sr_img = output / 255.0

            crop_border = opt["crop_border"] if opt["crop_border"] else scale
            if crop_border == 0:
                cropped_sr_img = sr_img
                cropped_gt_img = gt_img
            else:
                cropped_sr_img = sr_img[
                    crop_border:-crop_border, crop_border:-crop_border
                ]
                cropped_gt_img = gt_img[
                    crop_border:-crop_border, crop_border:-crop_border
                ]

            psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
            ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255)
            lp_score = lpips_fn(
                GT.to(device) * 2 - 1, SR_img.to(device) * 2 - 1).squeeze().item()

            test_results["psnr"].append(psnr)
            test_results["ssim"].append(ssim)
            test_results["lpips"].append(lp_score)

            if len(gt_img.shape) == 3:
                if gt_img.shape[2] == 3:  # RGB image
                    sr_img_y = bgr2ycbcr(sr_img, only_y=True)
                    gt_img_y = bgr2ycbcr(gt_img, only_y=True)
                    if crop_border == 0:
                        cropped_sr_img_y = sr_img_y
                        cropped_gt_img_y = gt_img_y
                    else:
                        cropped_sr_img_y = sr_img_y[
                            crop_border:-crop_border, crop_border:-crop_border
                        ]
                        cropped_gt_img_y = gt_img_y[
                            crop_border:-crop_border, crop_border:-crop_border
                        ]
                    psnr_y = util.calculate_psnr(
                        cropped_sr_img_y * 255, cropped_gt_img_y * 255
                    )
                    ssim_y = util.calculate_ssim(
                        cropped_sr_img_y * 255, cropped_gt_img_y * 255
                    )

                    test_results["psnr_y"].append(psnr_y)
                    test_results["ssim_y"].append(ssim_y)

                    logger.info(
                        "img{:3d}:{:15s} - PSNR: {:.6f} dB; SSIM: {:.6f}; LPIPS: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.".format(
                            i, img_name, psnr, ssim, lp_score, psnr_y, ssim_y
                        )
                    )
            else:
                logger.info(
                    "img:{:15s} - PSNR: {:.6f} dB; SSIM: {:.6f}.".format(
                        img_name, psnr, ssim
                    )
                )

                test_results["psnr_y"].append(psnr)
                test_results["ssim_y"].append(ssim)
        else:
            logger.info(img_name)


    ave_lpips = sum(test_results["lpips"]) / len(test_results["lpips"])
    ave_psnr = sum(test_results["psnr"]) / len(test_results["psnr"])
    ave_ssim = sum(test_results["ssim"]) / len(test_results["ssim"])
    logger.info(
        "----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n".format(
            test_set_name, ave_psnr, ave_ssim
        )
    )
    if test_results["psnr_y"] and test_results["ssim_y"]:
        ave_psnr_y = sum(test_results["psnr_y"]) / len(test_results["psnr_y"])
        ave_ssim_y = sum(test_results["ssim_y"]) / len(test_results["ssim_y"])
        logger.info(
            "----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n".format(
                ave_psnr_y, ave_ssim_y
            )
        )

    logger.info(
            "----average LPIPS\t: {:.6f}\n".format(ave_lpips)
        )

    print(f"average test time: {np.mean(test_times):.4f}")

开头往log记录了相应配置文件内容,不需要可以注释。

遍历测试数据集(test_loaders)计算各种评价指标,如峰值信噪比(PSNR)、结构相似性(SSIM)和感知损失(LPIPS)。

在处理过程中,代码首先会创建一个目录来保存测试结果。

然后,对于每个测试图像,代码会加载对应的图像(如果可用),并使用一个名为clip_model的模型对图像进行编码。

接下来,代码会使用一个名为sde的随机微分方程模型和名为model的深度学习模型来处理带有噪声的图像,并生成复原图像(SR_img)。额可能作者拿了以前做超分的代码没改变量名

在这个过程中,text_contextimage_context被用作模型的输入,

图像都会被保存到之前创建的目录中。

此外,代码还会计算并记录每个图像的PSNR、SSIM和LPIPS分数,并在最后打印出这些分数的平均值。 代码中还包含了一些用于图像处理的实用函数,如util.tensor2img用于将张量转换为图像,util.save_img用于保存图像,以及util.calculate_psnrutil.calculate_ssim用于计算PSNR和SSIM分数。psnr_y和ssim_y 不用可以把相关代码注释。

最后,代码还计算了平均测试时间,并将其打印出来。

结果

log处理的单张图像报错的信息 0是该处理的图像排序序号,即正在处理第0张图

24-04-03 17:28:24.697 - INFO: img 0:_MG_2374_no_shadow - PSNR: 27.779773 dB; SSIM: 0.863140; LPIPS: 0.078669; PSNR_Y: 29.135256 dB; SSIM_Y: 0.869278.

可以给复原结果图加个后缀方便区分。

相关推荐
xiandong201 小时前
240929-CGAN条件生成对抗网络
图像处理·人工智能·深度学习·神经网络·生成对抗网络·计算机视觉
innutritious2 小时前
车辆重识别(2020NIPS去噪扩散概率模型)论文阅读2024/9/27
人工智能·深度学习·计算机视觉
PythonFun2 小时前
Python批量下载PPT模块并实现自动解压
开发语言·python·powerpoint
醒了就刷牙3 小时前
56 门控循环单元(GRU)_by《李沐:动手学深度学习v2》pytorch版
pytorch·深度学习·gru
炼丹师小米3 小时前
Ubuntu24.04.1系统下VideoMamba环境配置
python·环境配置·videomamba
橙子小哥的代码世界3 小时前
【深度学习】05-RNN循环神经网络-02- RNN循环神经网络的发展历史与演化趋势/LSTM/GRU/Transformer
人工智能·pytorch·rnn·深度学习·神经网络·lstm·transformer
GFCGUO3 小时前
ubuntu18.04运行OpenPCDet出现的问题
linux·python·学习·ubuntu·conda·pip
985小水博一枚呀4 小时前
【深度学习基础模型】神经图灵机(Neural Turing Machines, NTM)详细理解并附实现代码。
人工智能·python·rnn·深度学习·lstm·ntm
SEU-WYL5 小时前
基于深度学习的任务序列中的快速适应
人工智能·深度学习
OCR_wintone4215 小时前
中安未来 OCR—— 开启高效驾驶证识别新时代
人工智能·汽车·ocr