【深度学习】论文笔记:空间变换网络(Spatial Transformer Networks)

本文是对Google DeepMind 团队2015年发表的空间变换网络STN的详细讲解,作为初学者也是参考了很多博客,都在本文末尾给出,感谢前辈们的努力。

空间变换网络 (Spatial Transformer Networks,简称STN)是一种深度学习 模型,旨在增强网络对几何变换的适应能力 。STN是由Max Jaderberg等人在2015年提出的,其核心思想是在传统的卷积神经网络(CNN)中嵌入一个可学习的模块,该模块能够显式地对输入图像进行空间变换,从而使得网络能够对输入图像的几何变形具有更好的适应性。STN的引入使得网络能够自动进行图像的校正,例如旋转、缩放、剪切等,这在很多视觉任务中是非常有用的,如图像识别、目标检测和图像分割等。

一、为什么提出(Why)

  1. 一个理想中的模型:我们希望鲁棒的图像处理模型具有空间不变性,当目标发生某种转化后,模型依然能给出同样的正确的结果

  2. 什么是空间不变性 :举例来说,如下图所示,假设一个模型能准确把左图中的人物分类为凉宫春日,当这个目标做了放大、旋转、平移后,模型仍然能够正确分类,我们就说这个模型在这个任务上具有尺度不变性,旋转不变性,平移不变性

  3. CNN在这方面的能力是不足的:maxpooling的机制给了CNN一点点这样的能力,当目标在池化单元内任意变换的话,激活的值可能是相同的,这就带来了一点点的不变性。但是池化单元一般都很小(一般是2*2),只有在深层的时候特征被处理成很小的feature map的时候这种情况才会发生

  4. Spatial Transformer:本文提出的空间变换网络STN(Spatial Transformer Networks)STN可以使模型学习平移、缩放、旋转和更通用的扭曲的不变性。(二维空间变换网络)

二、STN是什么(What)

  1. STN对feature map(包括输入图像)进行空间变换,输出一张新的图像。
  2. 我们希望STN对feature map进行变换后能把图像纠正到成理想的图像,然后丢进NN去识别,举例来说,如下图所示,输入模型的图像可能是摆着各种姿势,摆在不同位置的凉宫春日,我们希望STN把它纠正到图像的正中央,放大,占满整个屏幕,然后再丢进CNN去识别。
  3. 这个网络可以作为单独的模块,可以在CNN的任何地方插入(即插即用),所以STN的输入不止是输入图像,可以是CNN中间层的feature map

三、STN是怎么做的(How)

STN可以通过为每个输入样本生成适当的变换 来主动对图像(或特征图)进行空间变换。然后在整个特征图上(非局部)执行变换,并且可以包括缩放、裁剪、旋转以及非刚性变形。这使得包含空间变换器的网络不仅可以选择图像中最相关(注意力)的区域,还可以将这些区域转换为规范的预期姿势,以简化后续层中的推理。

如上图所示,STN的输入为 U U U,输出为 V V V,因为输入可能是中间层的feature map,所以画成了立方体(多channel),STN主要分为下述三个步骤

  1. 定位网络(Localization Network):这一部分是STN的核心,其任务是学习输入图像的空间变换参数 。定位网络可以是任意的网络结构,它接受输入图像,并输出空间变换所需的参数 。这些参数定义了一个变换矩阵 ,用于调整图像的空间位置。(是一个自己定义的网络,它输入 U U U,输出变化参数 Θ \Theta Θ,这个参数用来映射 U U U和 V V V的坐标关系)。
  2. 网格生成器(Grid Generator):接收定位网络输出的变换参数,并生成一个对应于输出图像的坐标网格 。这个坐标网格对应于输入图像中的每一个像素位置 。根据 V V V中的坐标点和变化参数 Θ \Theta Θ,计算出 U U U中的坐标点。这里是因为 V V V的大小是自己先定义好的,当然可以得到 V V V的所有坐标点,而填充 V V V 中每个坐标点的像素值的时候,要从 U U U中去取,所以根据 V V V中每个坐标点和变化参数 Θ \Theta Θ进行运算,得到一个坐标。在sampler中就是根据这个坐标去 U U U中找到像素值,这样子来填充 V V V
  3. Sampler:要做的是填充 V V V,根据Grid generator得到的一系列坐标和原图 U U U(因为像素值要从 U U U中取)来填充,因为计算出来的坐标可能为小数,要用另外的方法来填充,比如双线性插值。从输入图像中采样像素来产生变换后的输出图像。这一步骤确保了图像的空间变换是可微分的,从而可以通过反向传播算法进行训练。

下面针对每个模块阐述一下

1、Localisation net

这个模块就是输入 U U U,输出一个变换参数 Θ \Theta Θ,那么这个 Θ \Theta Θ具体是指什么呢?

我们知道线性代数里,图像的平移,旋转和缩放都可以用矩阵运算来做

举例来说,如果想放大图像中的目标,可以这么运算,把(x,y)中的像素值填充到(x',y')上去,比如把原来(2,2)上的像素点,填充到(4,4)上去。

x ′ y ′ \] = \[ 2 0 0 2 \] \[ x y \] + \[ 0 0 \] \\begin{bmatrix}x\^{'}\\\\y\^{'}\\end{bmatrix}=\\begin{bmatrix}2\&0\\\\0\&2\\end{bmatrix}\\begin{bmatrix}x\\\\y\\end{bmatrix}+\\begin{bmatrix}0\\\\0\\end{bmatrix} \[x′y′\]=\[2002\]\[xy\]+\[00

如果想旋转图像中的目标,可以这么运算(可以在极坐标系中推出来,证明放到最后的附录)

x ′ y ′ \] = \[ c o s Θ − s i n Θ s i n Θ c o s Θ \] \[ x y \] + \[ 0 0 \] \\begin{bmatrix}x\^{'}\\\\y\^{'}\\end{bmatrix}=\\begin{bmatrix}cos\\Theta\&-sin\\Theta\\\\sin\\Theta\&cos\\Theta\\end{bmatrix}\\begin{bmatrix}x\\\\y\\end{bmatrix}+\\begin{bmatrix}0\\\\0\\end{bmatrix} \[x′y′\]=\[cosΘsinΘ−sinΘcosΘ\]\[xy\]+\[00

这些都是属于仿射变换(affine transformation)

x ′ y ′ \] = \[ a b c d \] \[ x y \] + \[ e f \] \\begin{bmatrix}x\^{\^{\\prime}}\\\\y\^{\^{\\prime}}\\end{bmatrix}=\\begin{bmatrix}a\&b\\\\c\&d\\end{bmatrix}\\begin{bmatrix}x\\\\y\\end{bmatrix}+\\begin{bmatrix}e\\\\f\\end{bmatrix} \[x′y′\]=\[acbd\]\[xy\]+\[ef

在仿射变化中,变化参数就是这6个变量, Θ = { a , b , c , d , e , f } (此 Θ 跟上述旋转变化里的角度 Θ 无关) \Theta=\{a,b,c,d,e,f\}\text{(此}\Theta\text{跟上述旋转变化里的角度}\Theta\text{无关)} Θ={a,b,c,d,e,f}(此Θ跟上述旋转变化里的角度Θ无关)

这6个变量就是用来映射输入图和输出图之间的坐标点的关系 的,我们在第二步grid generator就要根据这个变化参数,来获取原图的坐标点

总结如下:

  1. 功能 :定位网络的主要任务是预测空间变换的参数。根据输入图像,这个网络会输出一组参数,这些参数定义了一个空间变换,可以是平移、旋转、缩放等或者更复杂的仿射变换或者非线性变换。
  2. 结构 :定位网络通常是一个小型的卷积神经网络全连接网络,其具体结构可以根据任务的复杂度和输入数据的特性来定制。网络的输出大小是固定的,对应于特定变换所需的参数数量。

2、Grid generator

有了第一步的变化参数,这一步是做个矩阵运算,这个运算是以目标图 V V V的所有坐标点为自变量,以为参数做一个矩阵运算,得到输入图 U U U的坐标点

( x i s y i s ) = Θ ( x i t y i t 1 ) = [ Θ 11 Θ 12 Θ 13 Θ 21 Θ 22 Θ 23 ] ( x i t y i t 1 ) \begin{pmatrix}x_i^s\\y_i^s\end{pmatrix}=\Theta\begin{pmatrix}x_i^t\\y_i^t\\1\end{pmatrix}=\begin{bmatrix}\Theta_{11}&\Theta_{12}&\Theta_{13}\\\Theta_{21}&\Theta_{22}&\Theta_{23}\end{bmatrix}\begin{pmatrix}x_i^t\\y_i^t\\1\end{pmatrix} (xisyis)=Θ xityit1 =[Θ11Θ21Θ12Θ22Θ13Θ23] xityit1

其中 ( x i t , y i ) 记为输出图 V 中的第 i 个坐标点, V 中的长宽可以和 U 不一样,自己定义的,所以这里用 i 来标识第几个坐标点 ( x i s , y i ) {(x_{i}{t},y_{i})} 记 为 输 出 图 V 中 的 第 i 个 坐 标 点 , V 中 的 长 宽 可 以 和 U 不 一 样 , 自 己 定 义 的 , 所 以 这 里 用 i 来 标 识 第 几 个 坐 标 点 {(x_{i}{s},y_{i})} (xit,yi)记为输出图V中的第i个坐标点,V中的长宽可以和U不一样,自己定义的,所以这里用i来标识第几个坐标点(xis,yi)

  • 功能 :网格生成器接收定位网络预测的变换参数,并生成一个坐标网格 ,该网格代表了输入图像中每个像素映射到输出图像中的新位置
  • 原理 :对于每个输出图像的像素位置,网格生成器使用变换参数来计算对应的输入图像中的坐标。这一过程通常涉及到矩阵运算,用于实现平移、旋转、缩放等仿射变换。

3、Sampler

由于在第二步计算出了V中每个点对应到U的坐标点,在这一步就可以直接根据V的坐标点取得对应到U中坐标点的像素值来进行填充,而不需要经过矩阵运算。需要注意的是,填充并不是直接填充,首先计算出来的坐标可能是小数,要处理一下,其次填充的时候往往要考虑周围的其它像素值。填充根据的公式如下。

V i = ∑ n ∑ m U n m ∗ k ( x i s − m ; ϕ x ) ∗ k ( y i s − n ; ϕ y ) V_i=\sum_n\sum_mU_{nm}*k(x_i^s-m;\phi_x)*k(y_i^s-n;\phi_y) Vi=n∑m∑Unm∗k(xis−m;ϕx)∗k(yis−n;ϕy)

举例来说,我要填充目标图V中的(2,2)这个点的像素值,经过以下计算得到(1.6,2.4)

( x i s y i s ) = [ Θ 11 Θ 12 Θ 13 Θ 21 Θ 22 Θ 23 ] ( x i t y i t 1 ) ( 1.6 2.4 ) = [ 0 0.5 0.6 1 0 0.4 ] ( 2 2 1 ) \begin{gathered}\begin{pmatrix}x_i^s\\y_i^s\end{pmatrix}=\begin{bmatrix}\Theta_{11}&\Theta_{12}&\Theta_{13}\\\Theta_{21}&\Theta_{22}&\Theta_{23}\end{bmatrix}\begin{pmatrix}x_i^t\\y_i^t\\1\end{pmatrix}\\\begin{pmatrix}1.6\\2.4\end{pmatrix}=\begin{bmatrix}0&0.5&0.6\\1&0&0.4\end{bmatrix}\begin{pmatrix}2\\2\\1\end{pmatrix}\end{gathered} (xisyis)=[Θ11Θ21Θ12Θ22Θ13Θ23] xityit1 (1.62.4)=[010.500.60.4] 221

如果四舍五入后直接填充,则难以做梯度下降。

我们知道做梯度下降时,梯度的表现就是权重发生一点点变化的时候,输出的变化会如何。

如果用四舍五入后直接填充,那么(1.6,2.4)四舍五入后变成(2,2)当 Θ \Theta Θ(我们求导的时候是需要对 Θ \Theta Θ求导的)有一点点变化的时候,(1.6,2.4)可能变成了(1.9,2.1)四舍五入后还是变成(2,2),输出并没有变化,对 Θ \Theta Θ的梯度没有改变,这个时候没法用梯度下降来优化 Θ \Theta Θ

如果采用上面双线性插值的公式来填充,在这个例子里就会考虑(2,2)周围的四个点来填充,这样子,当 Θ \Theta Θ有一点点变化的时,式子的输出就会有变化,因为 ( x i s , y i ) (x_{i}{s},y_{i}) (xis,yi)的变化会引起V的变化。注意下式中U的下标,第一个下标是纵坐标,第二个下标才是横坐标。

V = U 21 ( 1 − 0.6 ) ( 1 − 0.4 ) + U 22 ( 1 − 0.4 ) ( 1 − 0.4 ) + U 31 ( 1 − 0.6 ) ( 1 − 0.6 ) + U 32 ( 1 − 0.4 ) ( 1 − 0.6 ) V=U_{21}(1-0.6)(1-0.4)+U_{22}(1-0.4)(1-0.4)+U_{31}(1-0.6)(1-0.6)+U_{32}(1-0.4)(1-0.6) V=U21(1−0.6)(1−0.4)+U22(1−0.4)(1−0.4)+U31(1−0.6)(1−0.6)+U32(1−0.4)(1−0.6)

4、STN小结

简单总结一下,如下图所示

  1. Localization net根据输入图,计算得到一个Θ
  2. Grid generator根据输出图的坐标点和Θ,计算出输入图的坐标点,举例来说想知道输出图上(2,2)应该填充什么坐标点,则跟Θ 运算,得到(1.6,2.4)
  3. Sampler根据自己定义的填充规则(一般用双线性插值)来填充,比如(2,2)坐标对应到输入图上的坐标为(1.6,2.4),那么就要根据输入图上(1.6,2.4)周围的四个坐标点(1,2),(1,3),(2,2),(2,3)的像素值来填充。

四、STN模块的pytorch实现

这里我们假设Mnist数据集作为网络输入:

(1)首先定义Localisation net特征提取部分,为两个Conv层后接Maxpool和Relu操作:

(2)定义Localisation net的变换参数θ回归 部分,为两层全连接层内接Relu:

(3)在nn.module的继承类中定义完整的STN模块操作:

五、空间变换网络的实际应用

1、STN作为网络的第一层

2、STN插入CNN 的中间层

六、评价

思想非常巧妙,因为卷积神经网络中的池化层(pooling layer)直接用一些max pooling 或者average pooling 的方法,将图片信息压缩,减少运算量提升准确率。

作者认为之前pooling的方法太过于暴力,直接将信息合并会导致关键信息无法识别出来,所以提出了一个叫空间转换器(spatial transformer)的模块,将图片中的的空间域信息做对应的空间变换 ,从而能将关键的信息提取出来。

Unlike pooling layers, where the receptive fields are fixed and local, the spatial transformer module is a dynamic mechanism that can actively spatially transform an image (or a feature map) by producing an appropriate transformation for each input sample.

空间转换器模型直观的实验图:

(a)列是原始的图片信息,其中第一个手写数字7没有做任何变换,第二个手写数字5,做了一定的旋转变化,而第三个手写数字6,加上了一些噪声信号;这些变化都是随机的

(b)列中的彩色边框是学习到的spatial transformer的框盒(bounding box),每一个框盒其实就是对应图片学习出来的一个spatial transformer

🪧©列中是通过spatial transformer转换之后的特征图,可以看出7的关键区域被选择出来,5被旋转成为了正向的图片,6的噪声信息没有被识别进入。

(d)列最终可以通过这些转换后的特征图来预测出中手写数字的数值。

|----------------------------------------------------------------------------------------------------------------------------------------------|
| 🌱spatial transformer其实就是注意力机制的实现,因为训练出的spatial transformer能够找出图片信息中需要被关注的区域,同时这个transformer又能够具有旋转、缩放变换的功能,这样图片局部的重要信息能够通过变换而被框盒提取出来。🌱 |

参考


本人能力有限,上述内容如有理解不当的地方,欢迎与我讨论!

相关推荐
风铃喵游42 分钟前
让大模型调用MCP服务变得超级简单
前端·人工智能
旷世奇才李先生1 小时前
Pillow 安装使用教程
深度学习·microsoft·pillow
booooooty1 小时前
基于Spring AI Alibaba的多智能体RAG应用
java·人工智能·spring·多智能体·rag·spring ai·ai alibaba
PyAIExplorer1 小时前
基于 OpenCV 的图像 ROI 切割实现
人工智能·opencv·计算机视觉
风口猪炒股指标2 小时前
技术分析、超短线打板模式与情绪周期理论,在市场共识的形成、分歧、瓦解过程中缘起性空的理解
人工智能·博弈论·群体博弈·人生哲学·自我引导觉醒
ai_xiaogui2 小时前
一键部署AI工具!用AIStarter快速安装ComfyUI与Stable Diffusion
人工智能·stable diffusion·部署ai工具·ai应用市场教程·sd快速部署·comfyui一键安装
聚客AI3 小时前
Embedding进化论:从Word2Vec到OpenAI三代模型技术跃迁
人工智能·llm·掘金·日新计划
weixin_387545643 小时前
深入解析 AI Gateway:新一代智能流量控制中枢
人工智能·gateway
聽雨2373 小时前
03每日简报20250705
人工智能·社交电子·娱乐·传媒·媒体
二川bro4 小时前
飞算智造JavaAI:智能编程革命——AI重构Java开发新范式
java·人工智能·重构