深度学习:基于MindSpore实现CycleGAN壁画修复

关于CycleGAN的基础知识可参考:

深度学习:CycleGAN图像风格迁移转换-CSDN博客

以及MindSpore官方的教学视频:

CycleGAN图像风格迁移转换_哔哩哔哩_bilibili

本案例将基于CycleGAN实现破损草图到线稿图的转换

数据集

本案例使用的数据集里面的图片为经图线稿图数据。图像被统一缩放为256×256像素大小,其中用于训练的线稿图片25654张、草图图片25654张,用于测试的线稿图片100张、草图图片116张。

这里对数据进行了随机裁剪、水平随机翻转和归一化的预处理。

DexiNed

DexiNed(Dense Extreme Inception Network for Edge Detection)是一个为边缘检测任务设计的深度卷积神经网络模型。它由两个主要部分组成:Dexi和上采样网络(USNet)。

Dexi

这是模型的主要部分,包含六个编码块,每个块由多个子块组成,子块中包含卷积层、批量归一化和ReLU激活函数。从第二个块开始,引入了跳跃连接(skip connections),以保留不同层次的边缘特征。这些块的输出特征图被送入上采样网络以生成中间边缘图。

上采样网络(USNet)

这个部分由多个上采样块组成,每个块包括卷积层和反卷积层(也称为转置卷积层)。USNet的作用是将Dexi输出的低分辨率特征图上采样到更高的分辨率,以生成清晰的边缘图。

卷积层用于提取特征,反卷积层(或转置卷积层)用于将特征图的空间尺寸增大。

损失函数

DexiNed模型使用的损失函数是专门为边缘检测任务设计的,它在一定程度上受到了BDCN(Bi-directional Cascade Network)损失函数的启发,并进行了一些修改和优化。这个损失函数的目的是在训练过程中平衡正面(正样本)和负面(负样本)的边缘样本比例,从而提高边缘检测的准确性。

损失函数定义为:

其中,

Li 是第 i 个输出的损失,λi 是对应的权重,用于平衡正负样本的比例。具体的 Li 计算方式为:

DexiNed数据集

DexiNed模型的训练数据集主要是为边缘检测任务设计的高质量数据集。在论文中提到了两个主要的数据集:

  1. BIPED (Barcelona Images for Perceptual Edge Detection):这是一个特别为边缘检测设计的大规模数据集,包含详细的边缘标注信息。它由250张真实世界的图像组成,图像分辨率为1280×720像素,主要描绘城市环境场景。这些图像的边缘通过手动标注生成,以确保边缘检测的准确性。

  2. MDBD (Multicue Dataset for Boundary Detection):这是一个用于边界检测的数据集,也适用于边缘检测任务。它由100个高清图像组成,每个图像有多个参与者的标注,适用于训练和评估边缘检测算法。

DexiNed模型需要成对的数据来进行训练,即每张输入图像都需要有一个对应的标注图像(Ground Truth, GT)。这些标注图像详细地标出了图像中边缘的位置,模型通过比较预测边缘和这些标注来学习如何准确地检测边缘。

DexiNed在本例中主要用于将彩色图片转化为线稿图,随后将线稿图输入CycleGAN,得到输出。

基于MindSpore的壁画修复

加载数据集

python 复制代码
#下载数据集
from download import download

url = "https://6169fb4615b14dbcb6b2cb1c4eb78bb2.obs.cn-north-4.myhuaweicloud.com/Cyc_line.zip"

download(url, "./localdata", kind="zip", replace=True)
python 复制代码
from __future__ import division
import math
import numpy as np

import os
import multiprocessing

import mindspore.dataset as de
import mindspore.dataset.vision as vision

"""数据集分布式采样器"""
class DistributedSampler:
    """Distributed sampler."""
    def __init__(self, dataset_size, num_replicas=None, rank=None, shuffle=True):
        if num_replicas is None:
            print("***********Setting world_size to 1 since it is not passed in ******************")
            num_replicas = 1
        if rank is None:
            print("***********Setting rank to 0 since it is not passed in ******************")
            rank = 0
        self.dataset_size = dataset_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(dataset_size * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas
        self.shuffle = shuffle

    def __iter__(self):
        # deterministically shuffle based on epoch
        if self.shuffle:
            indices = np.random.RandomState(seed=self.epoch).permutation(self.dataset_size)
            # np.array type. number from 0 to len(dataset_size)-1, used as index of dataset
            indices = indices.tolist()
            self.epoch += 1
            # change to list type
        else:
            indices = list(range(self.dataset_size))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples
python 复制代码
# 加载CycleGAN数据集
IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.tif', '.tiff']

# 判断当前文件是否为图片
def is_image_file(filename):
    return any(filename.lower().endswith(extension) for extension in IMG_EXTENSIONS)

# 定义一个函数用于从指定目录中创建数据集列表
def make_dataset(dir_path, max_dataset_size=float("inf")):
    # 初始化一个空列表用来存储图片路径
    images = []
    # 确保提供的dir_path是一个有效的目录
    assert os.path.isdir(dir_path), '%s is not a valid directory' % dir_path
    # 遍历目录下的所有文件,将图片的文件路径存入images列表
    for root, _, fnames in sorted(os.walk(dir_path)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    # 返回指定长度的图片路径列表
    return images[:min(max_dataset_size, len(images))]

# CycleGAN中没有成对出现,但是分属两个领域的图片数据
class UnalignedDataset:
    '''
    此数据集类能够加载未对齐或未配对的数据集。
    需要两个目录来存放来自领域A和B的训练图片。
    可以使用'--dataroot /path/to/data'这样的标志来训练模型。
    同样,在测试时也需要准备两个目录。
    返回:两个领域的图片路径列表。
    '''
    def __init__(self, dataroot, max_dataset_size=float("inf"), use_random=True):
        # 根据指定根路径生成A\B领域数据的文件夹路径
        self.dir_A = os.path.join(dataroot, 'trainA')
        self.dir_B = os.path.join(dataroot, 'trainB')
        
        # 领域A图片数据的路径
        self.A_paths = sorted(make_dataset(self.dir_A, max_dataset_size))
        # 领域B图片数据的路径
        self.B_paths = sorted(make_dataset(self.dir_B, max_dataset_size))
        # 领域A的数据长度
        self.A_size = len(self.A_paths)
        # 领域B的数据长度
        self.B_size = len(self.B_paths)
        # 根据参数决定是否随机化
        self.use_random = use_random
    # 从数据集中根据给定的索引 index 获取一个样本对(分别来自领域A和领域B的一张图片)
    def __getitem__(self, index):
        # 数据A的索引
        index_A = index % self.A_size
        # 数据B的索引
        index_B = index % self.B_size
        # 每遍历完所有图片后会重新随机排序领域A中的图片路径列表,并且从领域B中随机选取图片。
        if index % max(self.A_size, self.B_size) == 0 and self.use_random:
            random.shuffle(self.A_paths)
            index_B = random.randint(0, self.B_size - 1)
        # 获取指定下标的图片路径
        A_path = self.A_paths[index_A]
        B_path = self.B_paths[index_B]
        # 获取图片对象
        A_img = np.array(Image.open(A_path).convert('RGB'))
        B_img = np.array(Image.open(B_path).convert('RGB'))
        # 返回领域A和B的图片
        return A_img, B_img
    
    def __len__(self):
        return max(self.A_size, self.B_size)
python 复制代码
def create_dataset(dataroot, batch_size=1, use_random=True, device_num=1, rank=0, max_dataset_size=float("inf"), image_size=256):
    """
    创建数据集
    该数据集类可以加载用于训练或测试的图像。
    参数:
        dataroot (str): 图像根目录。
        batch_size (int): 批处理大小,默认为1。
        use_random (bool): 是否使用随机化,默认为True。
        device_num (int): 设备数量,默认为1。
        rank (int): 当前设备的排名,默认为0。
        max_dataset_size (float): 数据集的最大大小,默认为无穷大。
        image_size (int): 图像的尺寸,默认为256x256。
    返回:
        RGB图像列表。
    """
    shuffle = use_random # 是否打乱数据集
    # 获取系统可用的CPU核心数
    cores = multiprocessing.cpu_count()
    # 计算并行工作的线程数,根据设备数量分配
    num_parallel_workers = min(1, int(cores / device_num))
    # 定义归一化时使用的均值和标准差
    # 三个通道的均值和房擦汗都是127.5
    mean = [0.5 * 255] * 3
    std = [0.5 * 255] * 3
    # 创建数据集(未对齐)
    dataset = UnalignedDataset(dataroot, max_dataset_size=max_dataset_size, use_random=use_random)
    # 使用DistributedSampler来实现分布式采样
    distributed_sampler = DistributedSampler(len(dataset), device_num, rank, shuffle=shuffle)
    # 创建GeneratorDataset,指定列名,并使用之前创建的sampler和并行工作线程数
    ds = de.GeneratorDataset(dataset, column_names=["image_A", "image_B"],
                             sampler=distributed_sampler, num_parallel_workers=num_parallel_workers)
    # 指定数据增强操作
    if use_random:
        trans = [
            # 图片随机裁剪变比例
            vision.RandomResizedCrop(image_size, scale=(0.5, 1.0), ratio=(0.75, 1.333)),
            # 水平翻转,概率为0.5
            vision.RandomHorizontalFlip(prob=0.5),
            # 图片数据归一化
            vision.Normalize(mean=mean, std=std),
            vision.HWC2CHW()
        ]
    else:  # 如果不启用随机化,则只进行简单的缩放和归一化
        trans = [
            C.Resize((image_size, image_size)),  # 固定大小缩放
            C.Normalize(mean=mean, std=std),  # 归一化
            C.HWC2CHW()  # 将HWC格式转换为CHW格式
        ]
    # 将数据增强操作映射到数据中
    ds = ds.map(operations=trans, input_columns=["image_A"], num_parallel_workers=num_parallel_workers)
    ds = ds.map(operations=trans, input_columns=["image_B"], num_parallel_workers=num_parallel_workers)
    # 设置批处理大小,并且丢弃不足一批的数据
    ds = ds.batch(batch_size, drop_remainder=True)
    
    return ds
python 复制代码
#根据设备情况调整训练参数

dataroot = "./localdata"
batch_size = 12
device_num = 1
rank = 0
use_random = True
max_dataset_size = 24000
image_size = 256

cyclegan_ds = create_dataset(dataroot=dataroot,max_dataset_size=max_dataset_size,batch_size=batch_size,device_num=device_num,rank = rank,use_random=use_random,image_size=image_size)
datasize = cyclegan_ds.get_dataset_size()
print("Datasize: ", datasize)

构建和训练CycleGAN的代码不再重复贴出,可参考上述提到的博文

构建DexiNed

python 复制代码
import os
import cv2
import numpy as np
import time

import mindspore as ms
from mindspore import nn, ops
from mindspore import dataset as ds
from mindspore.amp import auto_mixed_precision
from mindspore.common import initializer as init
python 复制代码
# DexiNed边缘检测数据集
class Test_Dataset():
    def __init__(self, data_root, mean_bgr, image_size):
        self.data = []
        # 列出根路径下的所有文件名
        imgs_ = os.listdir(data_root)
        # 初始化两个空列表,分别用于存储图像路径和对应的文件名
        self.names = []
        self.filenames = []
        for img in imgs_:
            if img.endswith(".png") or img.endswith(".jpg"):
                # 构建文件的完整路径
                dir = os.path.join(data_root, img)
                self.names.append(dir)
                self.filenames.append(img)
        self.mean_bgr = mean_bgr
        self.image_size = image_size
        
    def __len__(self):
        return len(self.names)
    
    def __getitem__(self, idx):
        # 使用OpenCV读取指定索引位置的图像,读取模式为彩色
        image = cv2.imread(self.names[idx], cv2.IMREAD_COLOR)
        # 读取图片的长宽
        im_shape = (image.shape[0], image.shape[1])
        # 对图像进行变换处理
        image = self.transform(img=image)
        # 返回图像、图像名、图像形状
        return image, self.filenames[idx], im_shape
    
    def transform(self, img):
        # 裁切
        img = cv2.resize(img, (self.image_size, self.image_size))
        # 将图像转换为浮点数类型
        img = np.array(img, dtype=np.float32)
        # 归一化
        img -= self.mean_bgr
        # 将图像从(H, W, C)格式转换为(C, H, W)格式
        img = img.transpose((2, 0, 1))
        return img
python 复制代码
# DexiNed网络结构

# 初始化权重函数
def weight_init(net):
    for name, param in net.parameters_and_names():
        # 使用Xavier分布初始化权重
        if 'weight' in name:
            param.set_data(
                init.initializer(
                    init.XavierNormal(),
                    param.shape,
                    param.dtype))
        # 偏置初始化为0
        if 'bias' in name:
            param.set_data(init.initializer('zeros', param.shape, param.dtype))
python 复制代码
# 表示DexiNed中的一个基础的密集连接层,实现具有批量归一化和ReLU激活的卷积层
class _DenseLayer(nn.Cell):
    def __init__(self, input_features, out_features):
        super(_DenseLayer, self).__init__()
        # 两个ConvNormReLU块
        self.conv1 = nn.Conv2d(
            input_features, out_features, kernel_size=3,
            stride=1, padding=2, pad_mode="pad",
            has_bias=True, weight_init=init.XavierNormal())
        self.norm1 = nn.BatchNorm2d(out_features)
        self.relu1 = nn.ReLU()
        
        self.conv2 = nn.Conv2d(
            out_features, out_features, kernel_size=3,
            stride=1, pad_mode="pad", has_bias=True,
            weight_init=init.XavierNormal())
        self.norm2 = nn.BatchNorm2d(out_features)
        self.relu = ops.ReLU()
    
    def construct(self, x):
        x1, x2 = x
        x1 = self.conv1(self.relu(x1))
        x1 = self.norm1(x1)
        x1 = self.relu1(x1)
        x1 = self.conv2(x1)
        new_features = self.norm2(x1)
        # 每一层的输出都会被传递给所有后续层作为输入
        return 0.5 * (new_features + x2), x2
python 复制代码
# 基于DenseLayer定义DenseBlock
class _DenseBlock(nn.Cell):
    def __init__(self, num_layers, input_features, out_features):
        super(_DenseBlock, self).__init__()
        
        self.denselayer1 = _DenseLayer(input_features, out_features)
        
        input_features = out_features
        self.denselayer2 = _DenseLayer(input_features, out_features)
        
        if num_layers == 3:
            self.denselayer3 = _DenseLayer(input_features, out_features)
            self.layers = nn.SequentialCell(
                [self.denselayer1, self.denselayer2, self.denselayer3])
        else:
            self.layers = nn.SequentialCell(
                [self.denselayer1, self.denselayer2])

    def construct(self, x):
        x = self.layers(x)
        return x
python 复制代码
# 表示上采样块,这是USNet的一部分,用于将特征图的尺寸增大。
class UpConvBlock(nn.Cell):
    def __init__(self, in_features, up_scale):
        super(UpConvBlock, self).__init__()
         # 定义上采样的因子,默认为2
        self.up_factor = 2
        # 定义一个常量特征数,通常用于控制输出通道的数量
        self.constant_features = 16
        
        layers = self.make_deconv_layers(in_features, up_scale)
        
        assert layers is not None, layers
        self.features = nn.SequentialCell(*layers)
    # # 构建上采样层
    def make_deconv_layers(self, in_features, up_scale):
        layers = []
        all_pads = [0, 0, 1, 3, 7]
        # 根据up_scale循环创建相应的层
        # 逐步放大
        for i in range(up_scale):
            # 定义卷积核大小
            kernel_size = 2 ** up_scale
            # 获取填充大小
            pad = all_pads[up_scale]
            # 计算输出维度
            out_features = self.compute_out_features(i, up_scale)
            # 创建反卷积层
            layers.append(nn.Conv2d(in_features, out_features, 1, has_bias=True))
            layers.append(nn.ReLU())
            layers.append(nn.Conv2dTranspose(
                out_features, out_features, kernel_size,
                stride=2, padding=pad, pad_mode="pad",
                has_bias=True, weight_init=init.XavierNormal()))
            # 更新通道数
            in_features = out_features
        return layers
    # 计算当前输出通道数
    def compute_out_features(self, idx, up_scale):
        return 1 if idx == up_scale - 1 else self.constant_features
    def construct(self, x):
        return self.features(x)
python 复制代码
# 单个卷积块,包含Conv和BatchNorm
class SingleConvBlock(nn.Cell):
    def __init__(self, in_features, out_features, stride, use_bs=True):
        super().__init__()
        self.use_batch_norm = use_bs
        self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride, pad_mode='pad', has_bias=True, weight_init=init.XavierNormal())
        self.bn = nn.BatchNorm2d(out_features)
        
    def construct(self, x):
        x = self.conv(x)
        if self.use_batch_norm:
            x = self.bn(x)
        return x
python 复制代码
# 双卷积块
class DoubleConvBlock(nn.Cell):
    def __init__(self, in_features, mid_features,
                 out_features=None,
                 stride=1,
                 use_act=True):
        super(DoubleConvBlock, self).__init__()

        self.use_act = use_act
        if out_features is None:
            out_features = mid_features
        self.conv1 = nn.Conv2d(
            in_features,
            mid_features,
            3,
            padding=1,
            stride=stride,
            pad_mode="pad",
            has_bias=True,
            weight_init=init.XavierNormal())
        self.bn1 = nn.BatchNorm2d(mid_features)
        self.conv2 = nn.Conv2d(
            mid_features,
            out_features,
            3,
            padding=1,
            pad_mode="pad",
            has_bias=True,
            weight_init=init.XavierNormal())
        self.bn2 = nn.BatchNorm2d(out_features)
        self.relu = nn.ReLU()

    def construct(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        if self.use_act:
            x = self.relu(x)
        return x
python 复制代码
# 自定义最大汇聚层
class maxpooling(nn.Cell):
    def __init__(self):
        super(maxpooling, self).__init__()
        self.pad = nn.Pad(((0,0),(0,0),(1,1),(1,1)), mode="SYMMETRIC")
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid')

    def construct(self, x):
        x = self.pad(x)
        x = self.maxpool(x)
        return x
python 复制代码
# 组件DexiNed网络
class DexiNed(nn.Cell):

    def __init__(self):
        super(DexiNed, self).__init__()
        #  DoubleConvBlock(双卷积块)实例,用于构建DexiNed的编码部分。
        # 处理输入图像和第一层的输出。
        self.block_1 = DoubleConvBlock(3, 32, 64, stride=2,)
        self.block_2 = DoubleConvBlock(64, 128, use_act=False)
        # 于实现DexiNed中的密集连接层
        # 用于提取多尺度特征
        self.dblock_3 = _DenseBlock(2, 128, 256)  # [128,256,100,100]
        self.dblock_4 = _DenseBlock(3, 256, 512)
        self.dblock_5 = _DenseBlock(3, 512, 512)
        self.dblock_6 = _DenseBlock(3, 512, 256)
        # 最大池化层,用于下采样操作。
        self.maxpool = maxpooling()
        # 用于从不同层次提取特征。
        self.side_1 = SingleConvBlock(64, 128, 2)
        self.side_2 = SingleConvBlock(128, 256, 2)
        self.side_3 = SingleConvBlock(256, 512, 2)
        self.side_4 = SingleConvBlock(512, 512, 1)
        self.side_5 = SingleConvBlock(512, 256, 1)  
        # 用于准备输入到密集块的数据。
        # right skip connections, figure in Journal paper
        self.pre_dense_2 = SingleConvBlock(128, 256, 2)
        self.pre_dense_3 = SingleConvBlock(128, 256, 1)
        self.pre_dense_4 = SingleConvBlock(256, 512, 1)
        self.pre_dense_5 = SingleConvBlock(512, 512, 1)
        self.pre_dense_6 = SingleConvBlock(512, 256, 1)
        # 上采样块,用于恢复特征图的尺寸。
        self.up_block_1 = UpConvBlock(64, 1)
        self.up_block_2 = UpConvBlock(128, 1)
        self.up_block_3 = UpConvBlock(256, 2)
        self.up_block_4 = UpConvBlock(512, 3)
        self.up_block_5 = UpConvBlock(512, 4)
        self.up_block_6 = UpConvBlock(256, 4)
        # 单卷积块,用于将多个尺度的特征图融合成一个单一的输出。
        self.block_cat = SingleConvBlock(6, 1, stride=1, use_bs=False)
    # 如果张量的形状与目标形状不匹配,则使用双线性插值进行调整;否则直接返回张量。
    def slice(self, tensor, slice_shape):
        t_shape = tensor.shape
        height, width = slice_shape
        if t_shape[-1] != slice_shape[-1]:
            new_tensor = ops.interpolate(
                tensor,
                sizes=(height, width),
                mode='bilinear',
                coordinate_transformation_mode="half_pixel")
        else:
            new_tensor = tensor
        return new_tensor

    def construct(self, x):
        # 确保输入张量 x 是四维的。
        assert x.ndim == 4, x.shape
        # 通过 block_1 处理输入 x,并通过 side_1 提取特征。
        # Block 1
        block_1 = self.block_1(x)
        block_1_side = self.side_1(block_1)
        # 通过 block_2 处理 block_1 的输出,然后通过 maxpool 降采样,并与 block_1_side 相加。再通过 side_2 提取特征。
        # Block 2
        block_2 = self.block_2(block_1)
        block_2_down = self.maxpool(block_2)
        block_2_add = block_2_down + block_1_side
        block_2_side = self.side_2(block_2_add)
        # 通过 pre_dense_3 处理 block_2_down,并将其与 block_2_add 一起传递给 dblock_3。
        # 然后通过 maxpool 降采样,并与 block_2_side 相加。再通过 side_3 提取特征。
        # Block 3
        block_3_pre_dense = self.pre_dense_3(block_2_down)
        block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense])
        block_3_down = self.maxpool(block_3)  # [128,256,50,50]
        block_3_add = block_3_down + block_2_side
        block_3_side = self.side_3(block_3_add)
        # 通过 pre_dense_2 和 pre_dense_4 处理 block_2_down 和 block_3_down,并将它们相加后传递给 dblock_4。
        # 然后通过 maxpool 降采样,并与 block_3_side 相加。再通过 side_4 提取特征。
        # Block 4
        block_2_resize_half = self.pre_dense_2(block_2_down)
        block_4_pre_dense = self.pre_dense_4(
            block_3_down + block_2_resize_half)
        block_4, _ = self.dblock_4([block_3_add, block_4_pre_dense])
        block_4_down = self.maxpool(block_4)
        block_4_add = block_4_down + block_3_side
        block_4_side = self.side_4(block_4_add)
        # 通过 pre_dense_5 处理 block_4_down,并将其与 block_4_add 一起传递给 dblock_5。然后与 block_4_side 相加。
        # Block 5
        block_5_pre_dense = self.pre_dense_5(
            block_4_down)  # block_5_pre_dense_512 +block_4_down
        block_5, _ = self.dblock_5([block_4_add, block_5_pre_dense])
        block_5_add = block_5 + block_4_side
        # 通过 pre_dense_6 处理 block_5,并将其与 block_5_add 一起传递给 dblock_6。
        # Block 6
        block_6_pre_dense = self.pre_dense_6(block_5)
        block_6, _ = self.dblock_6([block_5_add, block_6_pre_dense])
        # upsampling blocks
        # 对每个块的输出进行上采样,恢复特征图的尺寸。
        out_1 = self.up_block_1(block_1)
        out_2 = self.up_block_2(block_2)
        out_3 = self.up_block_3(block_3)
        out_4 = self.up_block_4(block_4)
        out_5 = self.up_block_5(block_5)
        out_6 = self.up_block_6(block_6)
        results = [out_1, out_2, out_3, out_4, out_5, out_6]
        # 将所有上采样的输出拼接在一起,并通过 block_cat 进行最后的融合,生成最终的边缘图。
        # concatenate multiscale outputs
        op = ops.Concat(1)
        block_cat = op(results)

        block_cat = self.block_cat(block_cat)  # Bx1xHxW
        results.append(block_cat)
        # 返回包含多个尺度的边缘图和最终融合后的边缘图的结果列表。
        return results

使用DexiNed进行推理

python 复制代码
'''将输入图像规格化到指定范围'''
def image_normalization(img, img_min=0, img_max=255, epsilon=1e-12):
    img = np.float32(img)
    img = (img - np.min(img)) * (img_max - img_min) / \
        ((np.max(img) - np.min(img)) + epsilon) + img_min
    return img
python 复制代码
'''对DexiNed模型的输出数据进行后处理'''
# DexiNed 模型会输出多个尺度的边缘图(results列表中包含了多个上采样的结果),
# 这个函数将这些边缘图进行融合,并应用一些图像处理技术来生成最终的边缘检测结果。
def fuse_DNoutput(img):
    edge_maps = []
    tensor = img
    for i in tensor:
        sigmoid = ops.Sigmoid()
        output = sigmoid(i).numpy()
        edge_maps.append(output)
    tensor = np.array(edge_maps)
    idx = 0
    tmp = tensor[:, idx, ...]
    tmp = np.squeeze(tmp)
    preds = []
    for i in range(tmp.shape[0]):
        tmp_img = tmp[i]
        tmp_img = np.uint8(image_normalization(tmp_img))
        tmp_img = cv2.bitwise_not(tmp_img)
        preds.append(tmp_img)
        if i == 6:
            fuse = tmp_img
            fuse = fuse.astype(np.uint8)
    idx += 1
    return fuse
python 复制代码
"""DexiNed 检测."""

def test(imgs,dexined_ckpt):
    
    if not os.path.isfile(dexined_ckpt):
        raise FileNotFoundError(
            f"Checkpoint file not found: {dexined_ckpt}")
    print(f"DexiNed ckpt path : {dexined_ckpt}")
    # os.makedirs(dexined_output_dir, exist_ok=True)
    model = DexiNed()
    # model = auto_mixed_precision(model, 'O2')
    ms.load_checkpoint(dexined_ckpt, model)
    model.set_train(False)
    preds = []
    origin = []
    total_duration = []
    print('Start dexined testing....')
    for img in imgs.create_dict_iterator():
        filename = str(img["names"])[2:-2]
        # print(filename)
        # output_dir_f = os.path.join(dexined_output_dir, filename)
        image = img["data"]
        origin.append(filename)
        end = time.perf_counter()
        # 使用DexiNed进行预测
        pred = model(image)
        # 获取图片宽高
        img_h = img["img_shape"][0, 0]
        img_w = img["img_shape"][0, 1]
        # 调用 fuse_DNoutput 函数对模型的输出进行后处理。
        pred = fuse_DNoutput(pred)
        # 将处理后的边缘图调整为原始图像的尺寸。
        dexi_img = cv2.resize(
            pred, (int(img_w.asnumpy()), int(img_h.asnumpy())))
        # cv2.imwrite("output.jpg", dexi_img)
        tmp_duration = time.perf_counter() - end
        total_duration.append(tmp_duration)
        preds.append(pred)
    total_duration_f = np.sum(np.array(total_duration))
    print("FPS: %f.4" % (len(total_duration) / total_duration_f))
    # 返回处理后的边缘图列表 preds 和原始图像文件名列表 origin。
    return preds,origin

DexiNed结合CycleGAN对壁画进行修复

python 复制代码
import os
import numpy as np
from PIL import Image
import mindspore.dataset as ds
import matplotlib.pyplot as plt
import mindspore.dataset.vision as vision
from mindspore.dataset import transforms
from mindspore import load_checkpoint, load_param_into_net

# 加载权重文件
def load_ckpt(net, ckpt_dir):
    param_GA = load_checkpoint(ckpt_dir)
    load_param_into_net(net, param_GA)

    
#模型参数地址
g_a_ckpt = './ckpt/G_A_120.ckpt'
dexined_ckpt = "./ckpt/dexined.ckpt"

#图片输入地址
img_path='./ckpt/jt'
#输出地址
save_path='./result'

load_ckpt(net_rg_a, g_a_ckpt)

os.makedirs(save_path, exist_ok=True)
# 图片推理
fig = plt.figure(figsize=(16, 4), dpi=64)
def eval_data(dir_path, net, a):
    my_dataset = Test_Dataset(
        dir_path, mean_bgr=[167.15, 146.07, 124.62], image_size=512)
    
    dataset = ds.GeneratorDataset(
        my_dataset, column_names=[
            "data", "names", "img_shape"])
    dataset = dataset.batch(1, drop_remainder=True)
    # 使用DexiNed将原图转为线稿图
    preds ,origin= test(dataset,dexined_ckpt)
    for i, data in enumerate(preds):
        # 读取线稿图
        img =ms.Tensor((np.array([data,data,data])/255-0.5)*2).unsqueeze(0)
        # 将线稿图放入生成器,生成假图
        fake = net(img.to(ms.float32))
        fake = (fake[0] * 0.5 * 255 + 0.5 * 255).astype(np.uint8).transpose((1, 2, 0))
        # 找到线稿图对应的原图,用作后续输出
        img = (Image.open(os.path.join(img_path,origin[i])).convert('RGB'))
        # 保存CycleGAN生成的结果
        fake_pil=Image.fromarray(fake.asnumpy())
        fake_pil.save(f"{save_path}/{i}.jpg")
        # 展示推理结果
        if i<8:
            fig.add_subplot(2, 8, min(i+1+a, 16))
            plt.axis("off")
            plt.imshow(np.array(img))

            fig.add_subplot(2, 8, min(i+9+a, 16))
            plt.axis("off")
            plt.imshow(fake.asnumpy())

eval_data(img_path,net_rg_a, 0)

plt.show()

输出结果如下:

详细可以参考MindSpore官方教学视频:

基于MindSpore实现CycleGAN壁画修复_哔哩哔哩_bilibili

相关推荐
斯多葛的信徒2 分钟前
看看你的电脑可以跑 AI 模型吗?
人工智能·语言模型·电脑·llama
正在走向自律3 分钟前
AI 写作(六):核心技术与多元应用(6/10)
人工智能·aigc·ai写作
AI科技大本营3 分钟前
Anthropic四大专家“会诊”:实现深度思考不一定需要多智能体,AI完美对齐比失控更可怕!...
人工智能·深度学习
Cc不爱吃洋葱3 分钟前
如何本地部署AI智能体平台,带你手搓一个AI Agent
人工智能·大语言模型·agent·ai大模型·ai agent·智能体·ai智能体
网安打工仔4 分钟前
斯坦福李飞飞最新巨著《AI Agent综述》
人工智能·自然语言处理·大模型·llm·agent·ai大模型·大模型入门
AGI学习社4 分钟前
2024中国排名前十AI大模型进展、应用案例与发展趋势
linux·服务器·人工智能·华为·llama
AI_Tool4 分钟前
纳米AI搜索官网 - 新一代智能答案引擎
人工智能·搜索引擎
Damon小智5 分钟前
合合信息DocFlow产品解析与体验:人人可搭建的AI自动化单据处理工作流
图像处理·人工智能·深度学习·机器学习·ai·自动化·docflow
小虚竹5 分钟前
用AI辅导侄女大学物理的质点运动学问题
人工智能·chatgpt
猿类崛起@6 分钟前
百度千帆大模型实战:AI大模型开发的调用指南
人工智能·学习·百度·大模型·产品经理·大模型学习·大模型教程