掌握空间注意力 STN 模型结构——让神经网络学会自动“看准位置”

目录

一、前言

[二、为什么需要 STN](#二、为什么需要 STN)

[(一)传统 CNN 的局限性](#(一)传统 CNN 的局限性)

(二)目标位置变化问题

(三)数据增强的局限

1、增加训练成本

2、无法覆盖所有情况

[三、STN 的核心思想](#三、STN 的核心思想)

[四、STN 整体结构](#四、STN 整体结构)

[五、Localization Network 详解](#五、Localization Network 详解)

[(一)什么是 Localization Network](#(一)什么是 Localization Network)

(二)工作原理

(三)仿射变换矩阵

[六、Grid Generator 详解](#六、Grid Generator 详解)

[(一)什么是 Grid Generator](#(一)什么是 Grid Generator)

(二)工作过程

(三)作用

[七、Sampler 详解](#七、Sampler 详解)

[(一)什么是 Sampler](#(一)什么是 Sampler)

(二)为什么需要插值

(三)双线性插值

[八、STN 工作流程分析](#八、STN 工作流程分析)

[九、STN 能学习哪些变换](#九、STN 能学习哪些变换)

(一)平移变换

(二)旋转变换

(三)缩放变换

(四)仿射变换

[十、STN 的优势](#十、STN 的优势)

(一)自动学习空间变换

(二)增强模型鲁棒性

(三)可嵌入任意网络

(四)端到端训练

[十一、PyTorch 实现 STN](#十一、PyTorch 实现 STN)

[十二、STN 与注意力机制的关系](#十二、STN 与注意力机制的关系)

(一)SE注意力

(二)CBAM空间注意力

(三)STN

[十三、STN 的应用场景](#十三、STN 的应用场景)

(一)OCR文字识别

(二)人脸识别

(三)目标检测

(四)医学影像

(五)自动驾驶

[十四、STN 的不足](#十四、STN 的不足)

(一)只能学习较简单变换

(二)增加训练难度

(三)逐渐被Transformer替代

[十五、STN 对深度学习发展的意义](#十五、STN 对深度学习发展的意义)

十六、总结


在计算机视觉任务中,我们经常会遇到这样的问题:

同一个目标在不同图片中可能存在:

  • 平移(Translation)

  • 旋转(Rotation)

  • 缩放(Scaling)

  • 透视变换(Perspective Transformation)

例如:

在手写数字识别任务中:

数字"8"可能出现在图片中央,也可能偏左偏右。

在人脸识别任务中:

同一个人的照片可能存在:

  • 侧脸

  • 倾斜

  • 放大

  • 缩小

对于传统卷积神经网络(CNN)来说,这些变化都会影响模型识别效果。

虽然卷积神经网络具有一定的平移不变性(Translation Invariance),但面对较大的几何变换时,其性能仍然会明显下降。

为了解决这一问题,Google DeepMind 团队于 2015 年提出了:

Spatial Transformer Network

简称:

STN

论文名称:

Spatial Transformer Networks

STN 的核心思想非常简单:

让神经网络自动学习如何调整输入图片的位置和形状,从而更加容易识别目标。

因此:

STN 被认为是深度学习领域最早的空间注意力(Spatial Attention)模型之一。


二、为什么需要 STN

(一)传统 CNN 的局限性

传统卷积网络:

复制代码
输入图片

↓

卷积

↓

池化

↓

分类

默认假设:

目标位置基本固定。

然而现实情况并非如此。


(二)目标位置变化问题

例如:

数字识别任务:

复制代码
数字 5

可能出现:

复制代码
左上角

右下角

旋转30°

放大2倍

此时:

即便是同一个数字。

神经网络也需要重新学习。


(三)数据增强的局限

传统解决方案:

数据增强。

例如:

复制代码
旋转

翻转

裁剪

缩放

虽然有效。

但存在两个问题:

1、增加训练成本

需要生成大量样本。

2、无法覆盖所有情况

现实中的变换无限多。

数据增强无法完全解决。


三、STN 的核心思想

STN 提出一个全新的思路:

不要人为调整图片。

而是:

复制代码
让网络自己学习如何调整

例如:

原始图片:

复制代码
倾斜数字

经过 STN:

复制代码
自动旋正

然后再送入分类网络。

整个过程:

复制代码
Input

↓

STN

↓

Feature Map

↓

CNN

↓

Output

STN 相当于:

复制代码
自动图像校正器

四、STN 整体结构

STN 主要由三个部分组成:

复制代码
Localization Network

↓

Grid Generator

↓

Sampler

即:

1、定位网络

2、网格生成器

3、采样器


五、Localization Network 详解

(一)什么是 Localization Network

Localization Network:

定位网络。

作用:

预测图像应该如何变换。

输入:

复制代码
Feature Map

输出:

复制代码
变换参数 θ

(二)工作原理

通常采用:

复制代码
CNN

+

FC

结构。

例如:

复制代码
Image

↓

Conv

↓

Conv

↓

FC

↓

θ

输出:

仿射变换矩阵。


(三)仿射变换矩阵

二维空间:

通常采用:

复制代码
2 × 3

矩阵。

形式:

复制代码
[a11 a12 tx]

[a21 a22 ty]

其中:

  • tx:水平平移

  • ty:垂直平移

同时还可以表示:

  • 旋转

  • 缩放

  • 错切


六、Grid Generator 详解

(一)什么是 Grid Generator

得到变换参数后。

需要计算:

复制代码
输出图像对应输入图像哪里

这项工作由:

Grid Generator 完成。


(二)工作过程

首先生成:

复制代码
标准坐标网格

例如:

复制代码
(-1,-1)

(0,0)

(1,1)

然后:

利用 θ 进行变换。

得到:

新的采样位置。


(三)作用

本质上:

Grid Generator 完成:

复制代码
坐标映射

工作。


七、Sampler 详解

(一)什么是 Sampler

Sampler:

采样器。

作用:

根据 Grid Generator 计算出的坐标。

从原图中取值。


(二)为什么需要插值

变换后:

坐标通常不是整数。

例如:

复制代码
(15.3 , 26.8)

无法直接取像素。

因此需要:

插值计算。


(三)双线性插值

STN 默认采用:

复制代码
Bilinear Interpolation

即:

双线性插值。

优点:

  • 平滑

  • 可微分

  • 支持反向传播


八、STN 工作流程分析

完整流程:

复制代码
Input Image

↓

Localization Network

↓

Transformation Parameter

↓

Grid Generator

↓

Sampling Grid

↓

Sampler

↓

Transformed Feature

↓

CNN

↓

Prediction

整个过程:

完全自动学习。

无需人工干预。


九、STN 能学习哪些变换

(一)平移变换

例如:

复制代码
向左移动

向右移动

(二)旋转变换

例如:

复制代码
30°

45°

90°

(三)缩放变换

例如:

复制代码
放大

缩小

(四)仿射变换

例如:

复制代码
旋转

平移

缩放

错切

同时进行。


十、STN 的优势

(一)自动学习空间变换

传统方法:

复制代码
人工设计

STN:

复制代码
自动学习

(二)增强模型鲁棒性

面对:

  • 旋转

  • 平移

  • 尺度变化

表现更加稳定。


(三)可嵌入任意网络

可以插入:

  • LeNet

  • AlexNet

  • VGG

  • ResNet

等各种模型。


(四)端到端训练

无需额外标注。

直接反向传播。


十一、PyTorch 实现 STN

PyTorch 官方提供了 STN 支持。

核心代码如下:

python 复制代码
import torch
import torch.nn.functional as F

theta = torch.tensor([
    [[1,0,0],
     [0,1,0]]
], dtype=torch.float)

feature_map = torch.randn(
    1,3,28,28
)

grid = F.affine_grid(
    theta,
    feature_map.size()
)

output = F.grid_sample(
    feature_map,
    grid
)

print(output.shape)

其中:

复制代码
F.affine_grid()

负责:

Grid Generator。

而:

复制代码
F.grid_sample()

负责:

Sampler。

这两个函数基本实现了 STN 的核心功能。


十二、STN 与注意力机制的关系

很多同学会疑惑:

STN 算不算 Attention?

答案是:

算。

但属于:

复制代码
Spatial Attention

空间注意力。


(一)SE注意力

关注:

复制代码
哪个通道重要

属于:

Channel Attention。


(二)CBAM空间注意力

关注:

复制代码
哪里重要

属于:

Spatial Attention。


(三)STN

更进一步:

不仅关注哪里重要。

还能够:

复制代码
主动移动和调整目标位置

因此:

STN 是更早期、更经典的空间注意力模型。


十三、STN 的应用场景

(一)OCR文字识别

自动校正:

  • 倾斜文字

  • 弯曲文字


(二)人脸识别

自动对齐:

  • 眼睛

  • 鼻子

  • 嘴巴

位置。


(三)目标检测

提高:

目标定位精度。


(四)医学影像

自动聚焦:

病灶区域。


(五)自动驾驶

增强:

车辆

行人

交通标志

识别能力。


十四、STN 的不足

(一)只能学习较简单变换

主要针对:

  • 平移

  • 旋转

  • 缩放

复杂形变效果有限。


(二)增加训练难度

额外引入:

Localization Network。


(三)逐渐被Transformer替代

近年来:

  • ViT

  • DETR

  • Swin Transformer

发展迅速。

部分场景已经不再依赖 STN。


十五、STN 对深度学习发展的意义

STN 的最大贡献:

首次让神经网络具备:

复制代码
自动空间校正能力

其思想深刻影响了:

  • Attention机制

  • Vision Transformer

  • Deformable Conv

  • DETR

等后续研究。

很多现代视觉模型:

本质上都在解决:

复制代码
让模型关注正确位置

的问题。

而 STN 正是这一思想的重要起点。


十六、总结

Spatial Transformer Network(STN)是深度学习视觉领域的重要里程碑模型,也是最经典的空间注意力机制之一。

本文重点掌握了:

1、STN 提出的背景;

2、传统 CNN 的局限性;

3、STN 核心思想;

4、Localization Network 原理;

5、Grid Generator 原理;

6、Sampler 原理;

7、PyTorch 实现方式;

8、STN 与 Attention 的关系;

9、STN 的应用场景;

10、STN 的优势与不足。

可以将 STN 理解为:

"让神经网络拥有自动调整图片位置和形状能力的空间注意力模块。"

虽然如今 Transformer 系列模型更加流行,但 STN 在深度学习发展史上具有重要意义。掌握 STN,不仅有助于理解空间注意力机制,也为学习 Deformable Attention、Vision Transformer、DETR 等先进视觉模型打下坚实基础。

相关推荐
Together_CZ1 小时前
OpenCV 5.0 重磅发布:全面技术深度解析
图像处理·人工智能·opencv·计算机视觉·llm·dnn·推理
AI玫瑰助手1 小时前
Python函数:函数的文档字符串(docstring)编写
android·java·python
ABCDEEE71 小时前
RAG优化
人工智能
ITxiaobing20231 小时前
Neel Somani 解读加州 AB 205 能源可靠性框架的长期市场影响
大数据·人工智能·能源
雪碧聊技术1 小时前
python核心语法:模块
python·模块·
浊酒南街1 小时前
列表和元组知识总结
linux·python
小当家.1051 小时前
Excel AI Converter:用 大模型 自动转换excel表格格式
人工智能·excel·工具
MartinYeung51 小时前
[论文学习]透过增强式 Few-Shot Learning 实现高效 PII 从大型语言模型中提取
人工智能·学习·语言模型
恋恋风尘hhh1 小时前
从 Function Calling 到 MCP:Agent 工具调用的协议演进与架构实践
ai·agent