如何去除图片马赛克?

本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!

一、前言

对图像不了解的人时常妄想去除马赛克是可以实现的,严格意义来说这确实是无法实现的。而深度学习是出现,让去除马赛克成为可能。

为了理解去除马赛克有多难,我们需要知道马赛克是什么。观感上,马赛克就是方块感。当我们观察图像像素时, 马赛克表现为下图的情况:

原图右下角有十字,而添加马赛克后右下角一片都变成了同一像素,如果我们没保留原图,那么我们无法还原,也不知道是否还原了原图。因为原图已经被破坏了,这也是为什么马赛克是不可修复的。

那神经网络又是如何让修复成为可能呢?其实无论什么方式的修复,都是一种估计,而不是真正的修复。神经网络去除马赛克的操作其实是生成马赛克那部分内容,然后替代马赛克,从而达到修复的效果。

这种修复并不是还原,而是想象。假如我们对一张人脸打了马赛克,神经网络可以去除马赛克,但是去除后的人脸不再是原来那个人了。

二、实现原理

2.1 自编码器

图像修复的方法有很多,比如自编码器。自编码器是一种自监督模型,结构简单,不需要人为打标,收敛迅速。其结构如图:

编码器部分就是用于下采样的卷积网络,编码器会把图片编码成一个向量,而解码器则利用转置卷积把编码向量上采样成和原图大小一致的图片,最后我们把原图和生成结果的MSE作为损失函数进行优化。当模型训练好后,就可以用编码器对图片进行编码。

2.2 自编码器去除马赛克

那自编码器和去除马赛克有什么联系呢?其实非常简单,就是原本我们是输入原图,期望解码器能输出原图。这是出于我们希望模型学习如何编码图片的原图。而现在我们想要模型去除马赛克,此时我们要做的就是把马赛克图片作为输入,而原图作为输出,这样来训练就可以达到去除马赛克的效果了:

关于关于这种实现可以参考:juejin.cn/post/721068...

2.3 自编码器的问题

自编码器有个很明显的问题,就是图片经过编码器后会损失信息,而解码器的结果自然也会存在一些问题。这样既达不到去除马赛克的功能,连还原的原图都有一些模糊。

这里可以利用FPN的思想来改进,当自编码器加入FPN后,就得到了UNet网络结构。

2.4 UNet网络

UNet结构和自编码器类似,是一个先下再上的结构。和自编码器不同的时,UNet会利用编码器的每个输出,将各个输出与解码器的输入进行concatenate,这样就能更好地保留原图信息。其结构如下图:

UNet原本是用于图像分割的网络,这里我们用它来去除马赛克。

在UNet中,有几个部分我们分别来看看。

2.4.1 ConvBlock

在UNet中,有大量连续卷积的操作,这里我们作为一个Block(蓝色箭头),它可以实现为一个层,用PyTorch实现如下:

scss 复制代码
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, 3, 1, 1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

    def forward(self, inputs):
        return self.model(inputs)

这里其实就是两次卷积操作,这里的目的是提取当前感受野的特征。

2.4.2 ConvDown

经过连续卷积后,会使用卷积网络对图片进行下采样,这里把stride设置为2即可让图片缩小为原来的1/2。我们同样可以实现为层:

ruby 复制代码
class ConvDown(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(channels, channels, 3, 2, 1),
            nn.BatchNorm2d(channels),
            nn.ReLU()
        )

    def forward(self, inputs):
        return self.model(inputs)

这里只有一个卷积,而且stride被设置为了2。

2.4.3 ConvUp

接下来是解码器部分,这里多了一个上采用的操作,我们可以用转置卷积完成,代码如下:

scss 复制代码
class ConvUp(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(channels, channels // 2, 2, 2),
            nn.BatchNorm2d(channels // 2),
            nn.ReLU()
        )

    def forward(self, inputs):
        return self.model(inputs)

上面是层可以把图片尺寸扩大为2倍,同时把特征图数量缩小到1/2。这里缩小特征图的操作是为了concatenate操作,后面详细说。

三、完整实现

首先,导入需要用的模块:

javascript 复制代码
import os
import random
import torch
from torch import nn
from torch import optim
from torch.utils import data
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision.transforms import ToTensor
from PIL import Image, ImageDraw, ImageFilter
from torchvision.utils import make_grid

下面开始具体实现。

3.1 创建Dataset

首先创建本次任务需要的数据集,分布大致相同的图片即可,代码如下:

python 复制代码
class ReConstructionDataset(data.Dataset):
    def __init__(self, data_dir=r"G:/datasets/lbxx", image_size=64):
        self.image_size = image_size
        # 图像预处理
        self.trans = transforms.Compose([
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
            # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])
        # 保持所有图片的路径
        self.image_paths = []
        # 读取根目录,把所有图片路径放入image_paths
        for root, dirs, files in os.walk(data_dir):
            for file in files:
                self.image_paths.append(os.path.join(root, file))

    def __getitem__(self, item):
        # 读取图片,并预处理
        image = Image.open(self.image_paths[item])
        return self.trans(self.create_blur(image)), self.trans(image)

    def __len__(self):
        return len(self.image_paths)


    @staticmethod
    def create_blur(image, return_mask=False, box_size=200):
        mask = Image.new('L', image.size, 255)
        draw = ImageDraw.Draw(mask)
        upper_left_corner = (random.randint(0, image.size[0] - box_size), random.randint(0, image.size[1] - box_size))
        lower_right_corner = (upper_left_corner[0] + box_size, upper_left_corner[1] + box_size)
        draw.rectangle([lower_right_corner, upper_left_corner], fill=0)
        masked_image = Image.composite(image, image.filter(ImageFilter.GaussianBlur(15)), mask)
        if return_mask:
            return masked_image, mask
        else:
            return masked_image

Dataset的实现与以往基本一致,实现init、getitem、len方法,这里我们还实现了一个create_blur方法,该方法用于生成矩形马赛克(实际上是高斯模糊)。下面是create_blur方法生成的图片:

3.2 网络构建

这里我们需要使用前面的几个子单元,先实现编码器,代码如下:

ini 复制代码
class UNetEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.blk0 = ConvBlock(3, 64)
        self.down0 = ConvDown(64)
        self.blk1 = ConvBlock(64, 128)
        self.down1 = ConvDown(128)
        self.blk2 = ConvBlock(128, 256)
        self.down2 = ConvDown(256)
        self.blk3 = ConvBlock(256, 512)
        self.down3 = ConvDown(512)
        self.blk4 = ConvBlock(512, 1024)

    def forward(self, inputs):
        f0 = self.blk0(inputs)
        d0 = self.down0(f0)
        f1 = self.blk1(d0)
        d1 = self.down1(f1)
        f2 = self.blk2(d1)
        d2 = self.down2(f2)
        f3 = self.blk3(d2)
        d3 = self.down3(f3)
        f4 = self.blk4(d3)
        return f0, f1, f2, f3, f4

这里就是ConvBlok和ConvDown的n次组合,最终会得到一个1024×4×4的特征图。在forward中,我们返回了5个ConvBlok返回的结果,因为在解码器中我们需要全部使用。

接下来是解码器部分,这里与编码器相反,代码如下:

ini 复制代码
class UNetDecoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.up3 = ConvUp(1024)
        self.blk3 = ConvBlock(1024, 512)
        self.up2 = ConvUp(512)
        self.blk2 = ConvBlock(512, 256)
        self.up1 = ConvUp(256)
        self.blk1 = ConvBlock(256, 128)
        self.up0 = ConvUp(128)
        self.blk0 = ConvBlock(128, 64)
        self.last_conv = nn.Conv2d(64, 3, 3, 1, 1)

    def forward(self, inputs):
        f0, f1, f2, f3, f4 = inputs
        u3 = self.up3(f4)
        df2 = self.blk3(torch.concat((f3, u3), dim=1))
        u2 = self.up2(df2)
        df1 = self.blk2(torch.concat((f2, u2), dim=1))
        u1 = self.up1(df1)
        df0 = self.blk1(torch.concat((f1, u1), dim=1))
        u0 = self.up0(df0)
        f = self.blk0(torch.concat((f0, u0), dim=1))
        return torch.tanh(self.last_conv(f))

解码器的inputs为编码器的5组特征图,在forward时需要与上采样结果concatenate。

最后,整个网络组合起来,代码如下:

ruby 复制代码
class ReConstructionNetwork(nn.Module):

    def __init__(self):
        super().__init__()
        self.encoder = UNetEncoder()
        self.decoder = UNetDecoder()

    def forward(self, inputs):
        fs = self.encoder(inputs)
        return self.decoder(fs)

3.3 网络训练

现在各个部分都完成了,可以开始训练网络:

scss 复制代码
device = "cuda" if torch.cuda.is_available() else "cpu"


def train(model, dataloader, optimizer, criterion, epochs):
    model = model.to(device)
    for epoch in range(epochs):
        for iter, (masked_images, images) in enumerate(dataloader):
            masked_images, images = masked_images.to(device), images.to(device)
            outputs = model(masked_images)
            loss = criterion(outputs, images)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            if (iter + 1) % 100 == 1:
                print("epoch: %s, iter: %s, loss: %s" % (epoch + 1, iter + 1, loss.item()))
                with torch.no_grad():
                    outputs = make_grid(outputs)
                    img = outputs.cpu().numpy().transpose(1, 2, 0)
                    plt.imshow(img)
                    plt.show()
        torch.save(model.state_dict(), '../outputs/reconstruction.pth')


if __name__ == '__main__':
    dataloader = data.DataLoader(ReConstructionDataset(r"G:\datasets\lbxx"), 64)
    unet = ReConstructionNetwork()
    optimizer = optim.Adam(auto_encoder.parameters(), lr=0.0002)
    criterion = nn.MSELoss()
    train(unet, dataloader, optimizer, criterion, 20)

训练完成后,就可以用来去除马赛克了,代码如下:

ini 复制代码
dataloader = data.DataLoader(ReConstructionDataset(r"G:\datasets\lbxx"), 64, shuffle=True)
unet = ReConstructionNetwork().to(device)
unet.load_state_dict(torch.load('../outputs/reconstruction.pth'))
for masked_images, images in dataloader:
    masked_images, images = masked_images.to(device), images.to(device)
    with torch.no_grad():
        outputs = unet(masked_images)
        outputs = torch.concatenate([images, masked_images, outputs], dim=-1)
        outputs = make_grid(outputs)
        img = outputs.cpu().numpy().transpose(1, 2, 0)
        img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        Image.fromarray(img).show()
       

下面是生成结果。左侧为原图,中间为添加马赛克后的图片,右侧则是去除马赛克后的结果:

整体来说效果比较不错。本文的方法不只可以用来去除马赛克,还可以完成图像重构。比如老化的图片、被墨汁污染的图片等,都可以用本文的方法完成重构。另外,本文的数据有限,实现效果并不通用,有需求的读者可以移步CodeFormer项目:github.com/sczhou/Code...

相关推荐
Swift社区13 分钟前
LeetCode - #139 单词拆分
算法·leetcode·职场和发展
Kent_J_Truman44 分钟前
greater<>() 、less<>()及运算符 < 重载在排序和堆中的使用
算法
IT 青年1 小时前
数据结构 (1)基本概念和术语
数据结构·算法
Dong雨2 小时前
力扣hot100-->栈/单调栈
算法·leetcode·职场和发展
SoraLuna2 小时前
「Mac玩转仓颉内测版24」基础篇4 - 浮点类型详解
开发语言·算法·macos·cangjie
liujjjiyun2 小时前
小R的随机播放顺序
数据结构·c++·算法
¥ 多多¥2 小时前
c++中mystring运算符重载
开发语言·c++·算法
trueEve3 小时前
SQL,力扣题目1369,获取最近第二次的活动
算法·leetcode·职场和发展
天若有情6733 小时前
c++框架设计展示---提高开发效率!
java·c++·算法
ahadee3 小时前
蓝桥杯每日真题 - 第19天
c语言·vscode·算法·蓝桥杯