详解pix2pix图像转译模型——原理+效果图

本文为稀土掘金技术社区首发签约文章,30天内禁止转载,30天后未获授权禁止转载,侵权必究!

beginning

上次给大家详细介绍了CycleGAN,并动手实现了风格迁移,还没看过的小伙伴赶紧康康叭➡基于MMGeneration实现CycleGAN图像风格迁移。一提起CycleGAN,那必会想起的就是pix2pix。作为图像生成和图像转译领域的两个必学经典,同时也是让人拍案叫绝的两个算法(具体有多绝,往后看就明白辽)。pix2pix在语义标签图转真实图片、简笔画转真图、黑白图像上色、卫星航拍图转地图等图像转译任务上表现很是优秀。话不多说,如果你也对此感兴趣,想一睹它的风采,让我们一起愉快的学习叭🎈🎈🎈


1.pix2pix简介

pix2pix是一种基于条件式生成对抗网络(CGAN)的图像转译模型,而条件式生成抵抗网络是生成对抗网络的一种扩展,它通过在生成器和判别器中引入条件信息来实现有条件的图像生成。生成器采用U-Net网络结构,融合底层细粒度特征和高层抽象;判别器采用patchGAN网络结构,在图块尺度提取纹理等高频信息。(什么?看不懂?没关系滴,往后看都有对应部分的详细解释说明嗷)🌴🌴🌴

pix2pix不仅论文写得好,主页做的漂亮,最令人难忘的是它还有一个交互式的demo,其中pix2pix的标志性成果------简笔画的猫转成真猫,建筑外立面语义设计图转真实图,简笔画的鞋转真鞋等,你都可以在这里轻松上手玩转,这么好玩好用,赶紧放个链接:pix2pix交互式趣味demo🧸🧸🧸

那么简笔画猫转成真猫到底是一个什么原理腻,可以这样理解:你可以获取很多真猫的图片,用opencv的边缘提取,把每一张图片的边缘都给提取出来,构建一个像素到像素的映射数据集,也就是数据集包含两类图片,一类是边缘轮廓简笔画,另一类是真猫的图片,它们俩是一一对应的关系,所以pix2pix解决的是一个像素配对的图像转译问题,那么我们上次介绍的cyclegan呢解决的是一个非配对的图像转译问题。同样,这里也能用cyclegan来解决这些问题。image translation领域非常的好玩,既可以用配对的数据集去训练,也可以用不配对的数据集。🎉🎉🎉

pix2pix是2017年的论文,现在看来比较老了,如果你现在还想做跟图像转译相关的项目的话,可以用更好更新的算法,比如UGATIT、StarGAN等。当然用pix2pix也是完全可以滴,但是要注意pix2pix使用起来可能会容易模式崩溃,训练不太稳定喔🤔🤔🤔


2.pix2pix原理详解

pix2pix使用的是条件式生成对抗网络,条件式指的是对生成器不光要输入一个噪声,还要输入一张原始图像,因为你就要对这张图像做转译嘛;对于判别器也是一样的,不仅要把真假图像输入给它,让它判别,还要把输入的这张原始图片也输入给它,让判别器去判断这一对图像是真是假。不管是生成器还是判别器,都是由这样的模块组成的:***convolution-BatchNorm-ReLu(CBR模块)***🌈🌈🌈

2.1损失函数

条件式GAN的目标函数(也叫损失函数)表达如下: <math xmlns="http://www.w3.org/1998/Math/MathML"> L c G A N ( G , D ) = E x , y [ l o g D ( x , y ) ] + E x , y [ l o g ( 1 − D ( x , G ( x , z ) ) ) ] L_{cGAN}(G,D)=E_{x,y}[logD(x,y)]+E_{x,y}[log(1-D(x,G(x,z)))] </math>LcGAN(G,D)=Ex,y[logD(x,y)]+Ex,y[log(1−D(x,G(x,z)))]

先看判别器,当给判别器真图的时候,判别器认为真图是真图的概率要足够的大,也就是式子的第一项 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( x , y ) D(x,y) </math>D(x,y)要足够的大;如果是假图的话,判别器认为假图是真图的概率应该要足够的小,即 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( x , G ( x , z ) ) D(x,G(x,z)) </math>D(x,G(x,z))要尽可能的小。所以判别器的优化目标是最大化第一项和第二项,也就是最大化整个函数。注意!这里的式子表达的是括号里的期望加上括号里的期望,而不是这两项相乘喔!🌻🌻🌻

而生成器呢,其实就是使得 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( x , G ( x , z ) ) D(x,G(x,z)) </math>D(x,G(x,z))尽可能的大,即让判别器以为假图是真图的概率越大越好;生成器无法控制第一项,因为第一项 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( x , y ) D(x,y) </math>D(x,y)是判别器认为真图是真图的概率,这对于生成器来说只能算得上是一个常数。所以生成器的目标就是要努力使得这个函数最小化

判别器在努力最大化这个函数,生成器在努力最小化这个函数(严谨的说应该是最小化第二项)。最终的结果就是一个双人的极大极小零和博弈。(讲到这里,这个损失函数大家应该都理解了叭,很多朋友第一次学条件GAN的时候对这个函数就整不明白)🌟🌟🌟

2.2生成器

生成器是一个U-Net网络结构,也是图像分割领域常用的一个网络,特点是结构像U型一样。U-Net的结构先是一个编码器Encoder,然后是一个解码器decoder。编码器的话先逐层下采样,中间是一个维度很小瓶颈层,再逐层的上采样,上采样和下采样是对称的,输入一张图再生成另外一张大小一样的图。这样的话所有的信息流都会被逼得流过瓶颈层,而瓶颈层的维度又很小,所以就不可避免地会带来信息的丢失。为了避免这一问题呢,U-Net在对称的层都引入了一个跳转连接skip connection,通俗理解skip connection就是抄近道,能够把底层的像素级别的特征融合到高层,这样就可以充分利用好底层特征(底层特征是极其重要滴)。如何对信息进行融合腻,直接沿通道方向摞起来就可以啦,摞起来再输入下一层进行卷积处理🍭🍭🍭

对称的镜像的这个U-Net层呢是可以底层和高层直接连接的,所以特征可以从底层汇聚到高层,这样既有底层的边缘、转角、轮廓、颜色等信息,又有高层语义特征的这种纹理等高级语义特征。所以它的生成效果就好。生成器它的使命就是努力生成让判别器误以为真的假图像。

2.3判别器

判别器呢在努力地分辨照片的真假。生成器生成的假照片和这个输入会组成一对图像喂给判别器,判别器如果足够优秀的话,应该会给这一对图像判成假;若再把真实的照片和这个输入一起喂给判别器的话,好的判别器应该会给这一对判别为真。所以发现了没,判别器其实也是一个二分类模型,只不过它输入的是图像对,而不是单张图像,它判断的是这一对是真是假🤔🤔🤔

判别器的名字叫马尔科夫判别器,马尔科夫的意思是说patch之间是互相独立的,就符合马尔科夫假设嘛。为了让判别器去捕捉高级特征,就使用到了局部小图块的尺度进行判别,这就叫做PatchGAN。也就是把整个图分成n×n的网格,然后对每一个网格来做二分类,PatchGAN对每一小块的纹理、颜色辨别真伪,来判定里面的图是真图还是假图。这样的好处就是可以全卷积的运行每一张图像,把n×n的每个结果加起来做平均,就可以得到对整张图像的判别结果。

2.4优化方法

交替的训练生成器和判别器,在训练刚开始的时候,生成的图像非常假,也就是 <math xmlns="http://www.w3.org/1998/Math/MathML"> D ( x , G ( x , z ) ) D(x,G(x,z)) </math>D(x,G(x,z))非常小(因为判别器一眼就能看出来生成的图像是假图像),这样就导致了 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g ( 1 − D ( x , G ( x , z ) ) ) log(1-D(x,G(x,z))) </math>log(1−D(x,G(x,z)))非常接近于0,这时候函数是没有梯度的,不便于它学习。所以我们不是最小化 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g ( 1 − D ( x , G ( x , z ) ) ) log(1-D(x,G(x,z))) </math>log(1−D(x,G(x,z)))这个函数了,而是改成直接最大化这个函数 <math xmlns="http://www.w3.org/1998/Math/MathML"> l o g D ( x , G ( x , z ) ) logD(x,G(x,z)) </math>logD(x,G(x,z)),这样的话梯度就很大了(想法是不是很不戳)🌟🌟🌟

训练过程中使用了Adam优化器(第一动量取0.5,第二动量取0.999),随机梯度下降。


学完上面的知识后,让我们来理解一下这张判断鞋子真假的流程图。小伙伴们能不能自己对着图捋一下工作流程腻🧐🧐🧐我来说了喔:生成器G输入的是噪声和一个输入图像,那么它生成的是一个假图像,把这个输入和假图像一起喂给判别器,如果判别器足够好,那么就应该把这一对图像判定为假;另外呢如果是真图像和对应的输入图像,输入给判别器,那么优秀的判别器应该给这一对图像判定为真。在这里条件式的作用就体现出来了,它会使得生成的图像既像真鞋,也得和输入的线稿相吻合,像素到像素得配对,所以叫pix2pix(揭秘了哈)


3.实战代码

在上次的CycleGAN中,给大家介绍了两种调用预训练模型的方式,一种是命令行方式,一种是Python API方式调用。这次就复习一下,用命令行的方式调用一下pix2pix。

<math xmlns="http://www.w3.org/1998/Math/MathML"> p i x 2 p i x 设计涂鸦转建筑立面图 \color{blue}{pix2pix 设计涂鸦转建筑立面图} </math>pix2pix设计涂鸦转建筑立面图:

python 复制代码
!python demo/translation_demo.py -h
!python demo/translation_demo.py \
        configs/pix2pix/pix2pix_vanilla_unet_bn_facades_b1x1_80k.py \
        https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_1x1_80k_facades_20210902_170442-c0958d50.pth \
        tests/data/paired/test/3.jpg \
        --save-path outputs/E/E1_facades.jpg \
        --device cuda:0
show_img_from_path('tests/data/paired/test/3.jpg')

运行之后就会看到以下的图片:

python 复制代码
show_img_from_path('outputs/E/E1_facades.jpg')

得到:

<math xmlns="http://www.w3.org/1998/Math/MathML"> p i x 2 p i x 线描转鞋画 \color{blue}{pix2pix 线描转鞋画} </math>pix2pix线描转鞋画:

python 复制代码
!python demo/translation_demo.py \
        configs/pix2pix/pix2pix_vanilla_unet_bn_wo_jitter_flip_edges2shoes_b1x4_190k.py \
        https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_wo_jitter_flip_1x4_186840_edges2shoes_convert-bgr_20210902_170902-0c828552.pth \
        tests/data/paired/test/33_AB.jpg \
        --save-path outputs/E/E2_shoes.jpg \
        --device cuda:0

如果报错_pickle.UnpicklingError: invalid load key, '<'.,说明 checkpoints 权重文件没有下载完整,可手动下载https://download.openmmlab.com/mmgen/pix2pix/refactor/pix2pix_vanilla_unet_bn_wo_jitter_flip_1x4_186840_edges2shoes_convert-bgr_20210902_170902-0c828552.pth至 checkpoints 目录

再运行下面的代码块:

python 复制代码
!python demo/translation_demo.py \
        configs/pix2pix/pix2pix_vanilla_unet_bn_wo_jitter_flip_edges2shoes_b1x4_190k.py \
        checkpoints/pix2pix_vanilla_unet_bn_wo_jitter_flip_1x4_186840_edges2shoes_convert-bgr_20210902_170902-0c828552.pth \
        tests/data/paired/test/33_AB.jpg \
        --save-path outputs/E/E2_shoes.jpg \
        --device cuda:0
python 复制代码
show_img_from_path('tests/data/paired/test/33_AB.jpg')

运行得到:

python 复制代码
show_img_from_path('outputs/E/E2_shoes.jpg')

上面的代码很简单,主要流程都是先指定configs配置文件,再指定一个pix2pix的checkpoint模型权重文件就可以啦🎉🎉🎉

ending

看到这里相信盆友们都对pix2pix有了一个全面深入的了解啦🌴🌴🌴很开心能把学到的知识以文章的形式分享给大家。如果你也觉得我的分享对你有所帮助,please一键三连嗷!!!下期见

相关推荐
Eliauk &5 分钟前
【机器学习】分类算法-KNN算法实现
人工智能·python·算法·机器学习·分类
littlesujin6 分钟前
昇思25天打卡营-mindspore-ML- Day14-VisionTransformer图像分类
人工智能·分类·数据挖掘
大舍传媒13 分钟前
欧美海外媒体发稿,国外新闻发布,外媒发布
大数据·人工智能·游戏引擎·信息与通信·用户运营
RamendeusStudio16 分钟前
绝区肆--2024 年AI安全状况
人工智能·安全
内容营销专家刘鑫炜21 分钟前
蚂蚁全媒体总编刘鑫炜谈新媒体时代艺术家如何创建及提升个人品牌
人工智能·媒体
PhyliciaFelicia21 分钟前
空间数据采集与管理:为什么选择ArcGISPro和Python?
开发语言·python·深度学习·机器学习·arcgis·数据分析
coolkidlan33 分钟前
【AI原理解析】-目标检测概述
人工智能·目标检测
LDR—0071 小时前
LDR6020-VR串流线:开启虚拟现实新纪元的钥匙
人工智能·vr
只是有点小怂1 小时前
【PYG】处理Cora数据集分类任务使用的几个函数log_softmax,nll_loss和argmax
人工智能·分类·数据挖掘
TechLead KrisChang1 小时前
合合信息大模型“加速器”重磅上线
人工智能·深度学习·机器学习