目录
[二、为什么需要 STN](#二、为什么需要 STN)
[(一)传统 CNN 的局限性](#(一)传统 CNN 的局限性)
[三、STN 的核心思想](#三、STN 的核心思想)
[四、STN 整体结构](#四、STN 整体结构)
[五、Localization Network 详解](#五、Localization Network 详解)
[(一)什么是 Localization Network](#(一)什么是 Localization Network)
[六、Grid Generator 详解](#六、Grid Generator 详解)
[(一)什么是 Grid Generator](#(一)什么是 Grid Generator)
[七、Sampler 详解](#七、Sampler 详解)
[(一)什么是 Sampler](#(一)什么是 Sampler)
[八、STN 工作流程分析](#八、STN 工作流程分析)
[九、STN 能学习哪些变换](#九、STN 能学习哪些变换)
[十、STN 的优势](#十、STN 的优势)
[十一、PyTorch 实现 STN](#十一、PyTorch 实现 STN)
[十二、STN 与注意力机制的关系](#十二、STN 与注意力机制的关系)
[十三、STN 的应用场景](#十三、STN 的应用场景)
[十四、STN 的不足](#十四、STN 的不足)
[十五、STN 对深度学习发展的意义](#十五、STN 对深度学习发展的意义)
在计算机视觉任务中,我们经常会遇到这样的问题:
同一个目标在不同图片中可能存在:
-
平移(Translation)
-
旋转(Rotation)
-
缩放(Scaling)
-
透视变换(Perspective Transformation)
例如:
在手写数字识别任务中:
数字"8"可能出现在图片中央,也可能偏左偏右。
在人脸识别任务中:
同一个人的照片可能存在:
-
侧脸
-
倾斜
-
放大
-
缩小
对于传统卷积神经网络(CNN)来说,这些变化都会影响模型识别效果。
虽然卷积神经网络具有一定的平移不变性(Translation Invariance),但面对较大的几何变换时,其性能仍然会明显下降。
为了解决这一问题,Google DeepMind 团队于 2015 年提出了:
Spatial Transformer Network
简称:
STN
论文名称:
Spatial Transformer Networks
STN 的核心思想非常简单:
让神经网络自动学习如何调整输入图片的位置和形状,从而更加容易识别目标。
因此:
STN 被认为是深度学习领域最早的空间注意力(Spatial Attention)模型之一。
二、为什么需要 STN
(一)传统 CNN 的局限性
传统卷积网络:
输入图片
↓
卷积
↓
池化
↓
分类
默认假设:
目标位置基本固定。
然而现实情况并非如此。
(二)目标位置变化问题
例如:
数字识别任务:
数字 5
可能出现:
左上角
右下角
旋转30°
放大2倍
此时:
即便是同一个数字。
神经网络也需要重新学习。
(三)数据增强的局限
传统解决方案:
数据增强。
例如:
旋转
翻转
裁剪
缩放
虽然有效。
但存在两个问题:
1、增加训练成本
需要生成大量样本。
2、无法覆盖所有情况
现实中的变换无限多。
数据增强无法完全解决。
三、STN 的核心思想
STN 提出一个全新的思路:
不要人为调整图片。
而是:
让网络自己学习如何调整
例如:
原始图片:
倾斜数字
经过 STN:
自动旋正
然后再送入分类网络。
整个过程:
Input
↓
STN
↓
Feature Map
↓
CNN
↓
Output
STN 相当于:
自动图像校正器
四、STN 整体结构
STN 主要由三个部分组成:
Localization Network
↓
Grid Generator
↓
Sampler
即:
1、定位网络
2、网格生成器
3、采样器
五、Localization Network 详解
(一)什么是 Localization Network
Localization Network:
定位网络。
作用:
预测图像应该如何变换。
输入:
Feature Map
输出:
变换参数 θ
(二)工作原理
通常采用:
CNN
+
FC
结构。
例如:
Image
↓
Conv
↓
Conv
↓
FC
↓
θ
输出:
仿射变换矩阵。
(三)仿射变换矩阵
二维空间:
通常采用:
2 × 3
矩阵。
形式:
[a11 a12 tx]
[a21 a22 ty]
其中:
-
tx:水平平移
-
ty:垂直平移
同时还可以表示:
-
旋转
-
缩放
-
错切
六、Grid Generator 详解
(一)什么是 Grid Generator
得到变换参数后。
需要计算:
输出图像对应输入图像哪里
这项工作由:
Grid Generator 完成。
(二)工作过程
首先生成:
标准坐标网格
例如:
(-1,-1)
(0,0)
(1,1)
然后:
利用 θ 进行变换。
得到:
新的采样位置。
(三)作用
本质上:
Grid Generator 完成:
坐标映射
工作。
七、Sampler 详解
(一)什么是 Sampler
Sampler:
采样器。
作用:
根据 Grid Generator 计算出的坐标。
从原图中取值。
(二)为什么需要插值
变换后:
坐标通常不是整数。
例如:
(15.3 , 26.8)
无法直接取像素。
因此需要:
插值计算。
(三)双线性插值
STN 默认采用:
Bilinear Interpolation
即:
双线性插值。
优点:
-
平滑
-
可微分
-
支持反向传播
八、STN 工作流程分析
完整流程:
Input Image
↓
Localization Network
↓
Transformation Parameter
↓
Grid Generator
↓
Sampling Grid
↓
Sampler
↓
Transformed Feature
↓
CNN
↓
Prediction
整个过程:
完全自动学习。
无需人工干预。
九、STN 能学习哪些变换
(一)平移变换
例如:
向左移动
向右移动
(二)旋转变换
例如:
30°
45°
90°
(三)缩放变换
例如:
放大
缩小
(四)仿射变换
例如:
旋转
平移
缩放
错切
同时进行。
十、STN 的优势
(一)自动学习空间变换
传统方法:
人工设计
STN:
自动学习
(二)增强模型鲁棒性
面对:
-
旋转
-
平移
-
尺度变化
表现更加稳定。
(三)可嵌入任意网络
可以插入:
-
LeNet
-
AlexNet
-
VGG
-
ResNet
等各种模型。
(四)端到端训练
无需额外标注。
直接反向传播。
十一、PyTorch 实现 STN
PyTorch 官方提供了 STN 支持。
核心代码如下:
python
import torch
import torch.nn.functional as F
theta = torch.tensor([
[[1,0,0],
[0,1,0]]
], dtype=torch.float)
feature_map = torch.randn(
1,3,28,28
)
grid = F.affine_grid(
theta,
feature_map.size()
)
output = F.grid_sample(
feature_map,
grid
)
print(output.shape)
其中:
F.affine_grid()
负责:
Grid Generator。
而:
F.grid_sample()
负责:
Sampler。
这两个函数基本实现了 STN 的核心功能。
十二、STN 与注意力机制的关系
很多同学会疑惑:
STN 算不算 Attention?
答案是:
算。
但属于:
Spatial Attention
空间注意力。
(一)SE注意力
关注:
哪个通道重要
属于:
Channel Attention。
(二)CBAM空间注意力
关注:
哪里重要
属于:
Spatial Attention。
(三)STN
更进一步:
不仅关注哪里重要。
还能够:
主动移动和调整目标位置
因此:
STN 是更早期、更经典的空间注意力模型。
十三、STN 的应用场景
(一)OCR文字识别
自动校正:
-
倾斜文字
-
弯曲文字
(二)人脸识别
自动对齐:
-
眼睛
-
鼻子
-
嘴巴
位置。
(三)目标检测
提高:
目标定位精度。
(四)医学影像
自动聚焦:
病灶区域。
(五)自动驾驶
增强:
车辆
行人
交通标志
识别能力。
十四、STN 的不足
(一)只能学习较简单变换
主要针对:
-
平移
-
旋转
-
缩放
复杂形变效果有限。
(二)增加训练难度
额外引入:
Localization Network。
(三)逐渐被Transformer替代
近年来:
-
ViT
-
DETR
-
Swin Transformer
发展迅速。
部分场景已经不再依赖 STN。
十五、STN 对深度学习发展的意义
STN 的最大贡献:
首次让神经网络具备:
自动空间校正能力
其思想深刻影响了:
-
Attention机制
-
Vision Transformer
-
Deformable Conv
-
DETR
等后续研究。
很多现代视觉模型:
本质上都在解决:
让模型关注正确位置
的问题。
而 STN 正是这一思想的重要起点。
十六、总结
Spatial Transformer Network(STN)是深度学习视觉领域的重要里程碑模型,也是最经典的空间注意力机制之一。
本文重点掌握了:
1、STN 提出的背景;
2、传统 CNN 的局限性;
3、STN 核心思想;
4、Localization Network 原理;
5、Grid Generator 原理;
6、Sampler 原理;
7、PyTorch 实现方式;
8、STN 与 Attention 的关系;
9、STN 的应用场景;
10、STN 的优势与不足。
可以将 STN 理解为:
"让神经网络拥有自动调整图片位置和形状能力的空间注意力模块。"
虽然如今 Transformer 系列模型更加流行,但 STN 在深度学习发展史上具有重要意义。掌握 STN,不仅有助于理解空间注意力机制,也为学习 Deformable Attention、Vision Transformer、DETR 等先进视觉模型打下坚实基础。