详解pix2pix图像转译模型——原理+效果图

本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!

beginning

上次给大家详细介绍了CycleGAN,并动手实现了风格迁移,还没看过的小伙伴赶紧康康叭➡基于MMGeneration实现CycleGAN图像风格迁移。一提起CycleGAN,那必会想起的就是pix2pix。作为图像生成和图像转译领域的两个必学经典,同时也是让人拍案叫绝的两个算法(具体有多绝,往后看就明白辽)。pix2pix在语义标签图转真实图片、简笔画转真图、黑白图像上色、卫星航拍图转地图等图像转译任务上表现很是优秀。话不多说,如果你也对此感兴趣,想一睹它的风采,让我们一起愉快的学习叭🎈🎈🎈


1.pix2pix简介

pix2pix是一种基于条件式生成对抗网络(CGAN)的图像转译模型,而条件式生成抵抗网络是生成对抗网络的一种扩展,它通过在生成器和判别器中引入条件信息来实现有条件的图像生成。生成器采用U-Net网络结构,融合底层细粒度特征和高层抽象;判别器采用patchGAN网络结构,在图块尺度提取纹理等高频信息。(什么?看不懂?没关系滴,往后看都有对应部分的详细解释说明嗷)🌴🌴🌴

pix2pix不仅论文写得好,主页做的漂亮,最令人难忘的是它还有一个交互式的demo,其中pix2pix的标志性成果------简笔画的猫转成真猫,建筑外立面语义设计图转真实图,简笔画的鞋转真鞋等,你都可以在这里轻松上手玩转,这么好玩好用,赶紧放个链接:pix2pix交互式趣味demo🧸🧸🧸

那么简笔画猫转成真猫到底是一个什么原理腻,可以这样理解:你可以获取很多真猫的图片,用opencv的边缘提取,把每一张图片的边缘都给提取出来,构建一个像素到像素的映射数据集,也就是数据集包含两类图片,一类是边缘轮廓简笔画,另一类是真猫的图片,它们俩是一一对应的关系,所以pix2pix解决的是一个像素配对的图像转译问题,那么我们上次介绍的cyclegan呢解决的是一个非配对的图像转译问题。同样,这里也能用cyclegan来解决这些问题。image translation领域非常的好玩,既可以用配对的数据集去训练,也可以用不配对的数据集。🎉🎉🎉

pix2pix是2017年的论文,现在看来比较老了,如果你现在还想做跟图像转译相关的项目的话,可以用更好更新的算法,比如UGATIT、StarGAN等。当然用pix2pix也是完全可以滴,但是要注意pix2pix使用起来可能会容易模式崩溃,训练不太稳定喔🤔🤔🤔


2.pix2pix原理详解

pix2pix使用的是条件式生成对抗网络,条件式指的是对生成器不光要输入一个噪声,还要输入一张原始图像,因为你就要对这张图像做转译嘛;对于判别器也是一样的,不仅要把真假图像输入给它,让它判别,还要把输入的这张原始图片也输入给它,让判别器去判断这一对图像是真是假。不管是生成器还是判别器,都是由这样的模块组成的:***convolution-BatchNorm-ReLu(CBR模块)***🌈🌈🌈

2.1损失函数

条件式GAN的目标函数(也叫损失函数)表达如下: <math xmlns="http://www.w3.org/1998/Math/MathML"> L c G A N ( G , D ) = E x , y [ l o g D ( x , y ) ] + E x , y [ l o g ( 1 − D ( x , G ( x , z ) ) ) ] L_{cGAN}(G,D)=E_{x,y}[logD(x,y)]+E_{x,y}[log(1-D(x,G(x,z)))] </math>LcGAN(G,D)=Ex,y[logD(x,y)]+Ex,y[log(1−D(x,G(x,z)))]

先看判别器,当给判别器真图的时候,判别器认为真图是真图的概率要足够的大,也就是式子的第一项 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( x , y ) D(x,y) </math>D(x,y)要足够的大;如果是假图的话,判别器认为假图是真图的概率应该要足够的小,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( x , G ( x , z ) ) D(x,G(x,z)) </math>D(x,G(x,z))要尽可能的小。所以判别器的优化目标是最大化第一项和第二项,也就是最大化整个函数。注意!这里的式子表达的是括号里的期望加上括号里的期望,而不是这两项相乘喔!🌻🌻🌻

而生成器呢,其实就是使得 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( x , G ( x , z ) ) D(x,G(x,z)) </math>D(x,G(x,z))尽可能的大,即让判别器以为假图是真图的概率越大越好;生成器无法控制第一项,因为第一项 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( x , y ) D(x,y) </math>D(x,y)是判别器认为真图是真图的概率,这对于生成器来说只能算得上是一个常数。所以生成器的目标就是要努力使得这个函数最小化

判别器在努力最大化这个函数,生成器在努力最小化这个函数(严谨的说应该是最小化第二项)。最终的结果就是一个双人的极大极小零和博弈。(讲到这里,这个损失函数大家应该都理解了叭,很多朋友第一次学条件GAN的时候对这个函数就整不明白)🌟🌟🌟

2.2生成器

生成器是一个U-Net网络结构,也是图像分割领域常用的一个网络,特点是结构像U型一样。U-Net的结构先是一个编码器Encoder,然后是一个解码器decoder。编码器的话先逐层下采样,中间是一个维度很小瓶颈层,再逐层的上采样,上采样和下采样是对称的,输入一张图再生成另外一张大小一样的图。这样的话所有的信息流都会被逼得流过瓶颈层,而瓶颈层的维度又很小,所以就不可避免地会带来信息的丢失。为了避免这一问题呢,U-Net在对称的层都引入了一个跳转连接skip connection,通俗理解skip connection就是抄近道,能够把底层的像素级别的特征融合到高层,这样就可以充分利用好底层特征(底层特征是极其重要滴)。如何对信息进行融合腻,直接沿通道方向摞起来就可以啦,摞起来再输入下一层进行卷积处理🍭🍭🍭

对称的镜像的这个U-Net层呢是可以底层和高层直接连接的,所以特征可以从底层汇聚到高层,这样既有底层的边缘、转角、轮廓、颜色等信息,又有高层语义特征的这种纹理等高级语义特征。所以它的生成效果就好。生成器它的使命就是努力生成让判别器误以为真的假图像。

2.3判别器

判别器呢在努力地分辨照片的真假。生成器生成的假照片和这个输入会组成一对图像喂给判别器,判别器如果足够优秀的话,应该会给这一对图像判成假;若再把真实的照片和这个输入一起喂给判别器的话,好的判别器应该会给这一对判别为真。所以发现了没,判别器其实也是一个二分类模型,只不过它输入的是图像对,而不是单张图像,它判断的是这一对是真是假🤔🤔🤔

判别器的名字叫马尔科夫判别器,马尔科夫的意思是说patch之间是互相独立的,就符合马尔科夫假设嘛。为了让判别器去捕捉高级特征,就使用到了局部小图块的尺度进行判别,这就叫做PatchGAN。也就是把整个图分成n×n的网格,然后对每一个网格来做二分类,PatchGAN对每一小块的纹理、颜色辨别真伪,来判定里面的图是真图还是假图。这样的好处就是可以全卷积的运行每一张图像,把n×n的每个结果加起来做平均,就可以得到对整张图像的判别结果。

2.4优化方法

交替的训练生成器和判别器,在训练刚开始的时候,生成的图像非常假,也就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( x , G ( x , z ) ) D(x,G(x,z)) </math>D(x,G(x,z))非常小(因为判别器一眼就能看出来生成的图像是假图像),这样就导致了 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g ( 1 − D ( x , G ( x , z ) ) ) log(1-D(x,G(x,z))) </math>log(1−D(x,G(x,z)))非常接近于0,这时候函数是没有梯度的,不便于它学习。所以我们不是最小化 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g ( 1 − D ( x , G ( x , z ) ) ) log(1-D(x,G(x,z))) </math>log(1−D(x,G(x,z)))这个函数了,而是改成直接最大化这个函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g D ( x , G ( x , z ) ) logD(x,G(x,z)) </math>logD(x,G(x,z)),这样的话梯度就很大了(想法是不是很不戳)🌟🌟🌟

训练过程中使用了Adam优化器(第一动量取0.5,第二动量取0.999),随机梯度下降。


学完上面的知识后,让我们来理解一下这张判断鞋子真假的流程图。小伙伴们能不能自己对着图捋一下工作流程腻🧐🧐🧐我来说了喔:生成器G输入的是噪声和一个输入图像,那么它生成的是一个假图像,把这个输入和假图像一起喂给判别器,如果判别器足够好,那么就应该把这一对图像判定为假;另外呢如果是真图像和对应的输入图像,输入给判别器,那么优秀的判别器应该给这一对图像判定为真。在这里条件式的作用就体现出来了,它会使得生成的图像既像真鞋,也得和输入的线稿相吻合,像素到像素得配对,所以叫pix2pix(揭秘了哈)


3.实战代码

在上次的CycleGAN中,给大家介绍了两种调用预训练模型的方式,一种是命令行方式,一种是Python API方式调用。这次就复习一下,用命令行的方式调用一下pix2pix。

<math xmlns="http://www.w3.org/1998/Math/MathML"> p i x 2 p i x 设计涂鸦转建筑立面图 \color{blue}{pix2pix 设计涂鸦转建筑立面图} </math>pix2pix设计涂鸦转建筑立面图:

python 复制代码
!python demo/translation_demo.py -h
!python demo/translation_demo.py \
        configs/pix2pix/pix2pix_vanilla_unet_bn_facades_b1x1_80k.py \
        https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_1x1_80k_facades_20210902_170442-c0958d50.pth \
        tests/data/paired/test/3.jpg \
        --save-path outputs/E/E1_facades.jpg \
        --device cuda:0
show_img_from_path('tests/data/paired/test/3.jpg')

运行之后就会看到以下的图片:

python 复制代码
show_img_from_path('outputs/E/E1_facades.jpg')

得到:

<math xmlns="http://www.w3.org/1998/Math/MathML"> p i x 2 p i x 线描转鞋画 \color{blue}{pix2pix 线描转鞋画} </math>pix2pix线描转鞋画:

python 复制代码
!python demo/translation_demo.py \
        configs/pix2pix/pix2pix_vanilla_unet_bn_wo_jitter_flip_edges2shoes_b1x4_190k.py \
        https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_wo_jitter_flip_1x4_186840_edges2shoes_convert-bgr_20210902_170902-0c828552.pth \
        tests/data/paired/test/33_AB.jpg \
        --save-path outputs/E/E2_shoes.jpg \
        --device cuda:0

如果报错_pickle.UnpicklingError: invalid load key, '<'.,说明 checkpoints 权重文件没有下载完整,可手动下载https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_wo_jitter_flip_1x4_186840_edges2shoes_convert-bgr_20210902_170902-0c828552.pth至 checkpoints 目录

再运行下面的代码块:

python 复制代码
!python demo/translation_demo.py \
        configs/pix2pix/pix2pix_vanilla_unet_bn_wo_jitter_flip_edges2shoes_b1x4_190k.py \
        checkpoints/pix2pix_vanilla_unet_bn_wo_jitter_flip_1x4_186840_edges2shoes_convert-bgr_20210902_170902-0c828552.pth \
        tests/data/paired/test/33_AB.jpg \
        --save-path outputs/E/E2_shoes.jpg \
        --device cuda:0
python 复制代码
show_img_from_path('tests/data/paired/test/33_AB.jpg')

运行得到:

python 复制代码
show_img_from_path('outputs/E/E2_shoes.jpg')

上面的代码很简单,主要流程都是先指定configs配置文件,再指定一个pix2pix的checkpoint模型权重文件就可以啦🎉🎉🎉

ending

看到这里相信盆友们都对pix2pix有了一个全面深入的了解啦🌴🌴🌴很开心能把学到的知识以文章的形式分享给大家。如果你也觉得我的分享对你有所帮助,please一键三连嗷!!!下期见

相关推荐
IT古董6 分钟前
【漫话机器学习系列】017.大O算法(Big-O Notation)
人工智能·机器学习
凯哥是个大帅比6 分钟前
人工智能ACA(五)--深度学习基础
人工智能·深度学习
m0_7482329226 分钟前
DALL-M:基于大语言模型的上下文感知临床数据增强方法 ,补充
人工智能·语言模型·自然语言处理
szxinmai主板定制专家32 分钟前
【国产NI替代】基于FPGA的32通道(24bits)高精度终端采集核心板卡
大数据·人工智能·fpga开发
海棠AI实验室35 分钟前
AI的进阶之路:从机器学习到深度学习的演变(三)
人工智能·深度学习·机器学习
机器懒得学习1 小时前
基于YOLOv5的智能水域监测系统:从目标检测到自动报告生成
人工智能·yolo·目标检测
QQ同步助手1 小时前
如何正确使用人工智能:开启智慧学习与创新之旅
人工智能·学习·百度
AIGC大时代1 小时前
如何使用ChatGPT辅助文献综述,以及如何进行优化?一篇说清楚
人工智能·深度学习·chatgpt·prompt·aigc
流浪的小新1 小时前
【AI】人工智能、LLM学习资源汇总
人工智能·学习