文章目录
论文:PVTv2: Improved Baselines with Pyramid Vision Transformer
代码:Github
1.回顾PVT
在提出PVT之前,Vision Transformer领域输出的特征图和输入大小基本保持一致,不发生尺度的变化。在PVT中提出将多个Transformer模块进行叠加,同时在每个模块内部 Attention 机构进行特征提取的大小变化。整体结构如下:
共采用了四个Transformer编码器进行叠加,每个阶段只有参数不同结构都一样。并且为了不盲目叠加编码器、适应尺度变化,作者提出可更改大小的注意力模块(SRA:Spatial-Reduction Attention,空间缩减注意力机制):
可理解为将重塑K、V的形状,由原先的 ( H W ) × C (HW)×C (HW)×C缩减为 ( H W ) R 2 × ( R 2 C ) \frac{(HW)}{R^2}×(R^2C) R2(HW)×(R2C)。经过此操作后特征图的大小不断变化,形成多尺度 PVT。各PVT版本的结构如下:
但PVT也存在如下问题:
- (1)与ViT类似,在处理高分辨率输入时,计算复杂度高。
- (2)PVT将图像视为不重叠的小块序列,在一定程度上失去了图像的局部连续性。
- (3)PVT中位置编码大小固定,这对任意大小的图像处理不灵活。
为解决这些问题而提出了PVT-v2。
2.重叠块嵌入
在原始PVT及ViT中均采用刚好切分的方式,这使得图像边界部分的信息无法得到完整解读。同时,也丧失了这些分块的局部连续性。因此,PVT2中将原始图像进行零填充,并使用 s t r i d e stride stride小于 k e r n e l _ s i z e kernel\_size kernel_size的重叠分块操作完成嵌入(上图中红色边框内的虚线即为重叠部分)。代码如下:
py
#原始PVT直接进行分块
self.proj = nn.Conv2d(in_chans=3, embed_dim=64, kernel_size=4, stride=4)
#PVTv2包含零填充+重叠分块
self.proj = nn.Conv2d(in_chans=3, embed_dim=64, kernel_size=7,stride=4,padding=(3, 3))
此时PVT2与PVT在 p a t c h _ e m b e d i n g patch\_embeding patch_embeding部分的输入、输出大小相同,即,输入尺寸为 ( 1 , 3 , 224 , 2240 ) (1,3,224,2240) (1,3,224,2240),输出尺寸仍为 ( b a t c h _ s i z e , c h a n n a l , 56 , 56 ) (batch\_size,channal,56,56) (batch_size,channal,56,56)。但PVT2的嵌入包含了每个图像patch和周围相邻图像patch的相关信息。实际代码:
py
class OverlapPatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
patch_size = to_2tuple(patch_size)
assert max(patch_size) > stride, "Set larger patch_size than stride"
self.patch_size = patch_size
self.proj = nn.Conv2d(
in_chans, embed_dim, patch_size,
stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = nn.LayerNorm(embed_dim)
def forward(self, x): # (1,3,224,224)
x = self.proj(x) # (1,64,56,56)
x = x.permute(0, 2, 3, 1) # (1,56,56,64)
x = self.norm(x)
return x
其中, p a t c h _ s i z e = k e r n e l _ s i z e = 2 ∗ s t r i d e − 1 , p a d d i n g _ s i z e = s t r i d e − 1 patch\_size=kernel\_size=2*stride-1,padding\_size=stride-1 patch_size=kernel_size=2∗stride−1,padding_size=stride−1。
3.移除固定位置编码
在MLP层中两个FC之间加入了3×3的卷积以移除固定大小的位置编码,意思是,使用零填充给特征图外面加一圈padding,而卷积层可以根据特征图外圈的0学习到特征图的轮廓信息,换句话说可以学习到一些绝对位置信息,因此可以用DW卷积来建模位置信息并且去掉位置编码以减少计算量。实现代码:
py
class DWConv(nn.Module):
def __init__(self, dim=768):
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x
class Mlp(nn.Module):
def forward(self, x, H, W):
x = self.fc1(x)
x = self.dwconv(x, H, W) #这里这里
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
4.线性SRA
PVT中使用SRA结构降低分辨率,而PVTv2中使用池化+卷积操作实现,以降低计算量。
py
class Attention(nn.Module):
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False):
if not linear:
if sr_ratio > 1:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)
else: #使用线性 SRA
self.pool = nn.AdaptiveAvgPool2d(7) #加了一层池化
self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) #卷积核为 1
self.norm = nn.LayerNorm(dim)
self.act = nn.GELU() #激活函数
self.apply(self._init_weights)
def forward(self, x, H, W):
B, N, C = x.shape
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
if not self.linear: #这里是原版 SRA
if self.sr_ratio > 1:
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
else:
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
else: #是线性 SRA
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
x_ = self.act(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
k, v = kv[0], kv[1]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
5.版本结构
PVTv2各版本结构如下: