论文阅读(二十五):PVTv2: Improved Baselines with Pyramid Vision Transformer

文章目录


论文:PVTv2: Improved Baselines with Pyramid Vision Transformer

代码:Github

1.回顾PVT

在提出PVT之前,Vision Transformer领域输出的特征图和输入大小基本保持一致,不发生尺度的变化。在PVT中提出将多个Transformer模块进行叠加,同时在每个模块内部 Attention 机构进行特征提取的大小变化。整体结构如下:

共采用了四个Transformer编码器进行叠加,每个阶段只有参数不同结构都一样。并且为了不盲目叠加编码器、适应尺度变化,作者提出可更改大小的注意力模块(SRA:Spatial-Reduction Attention,空间缩减注意力机制):

可理解为将重塑K、V的形状,由原先的 ( H W ) × C (HW)×C (HW)×C缩减为 ( H W ) R 2 × ( R 2 C ) \frac{(HW)}{R^2}×(R^2C) R2(HW)×(R2C)。经过此操作后特征图的大小不断变化,形成多尺度 PVT。各PVT版本的结构如下:

但PVT也存在如下问题:

  • (1)与ViT类似,在处理高分辨率输入时,计算复杂度高。
  • (2)PVT将图像视为不重叠的小块序列,在一定程度上失去了图像的局部连续性。
  • (3)PVT中位置编码大小固定,这对任意大小的图像处理不灵活。

为解决这些问题而提出了PVT-v2。

2.重叠块嵌入

在原始PVT及ViT中均采用刚好切分的方式,这使得图像边界部分的信息无法得到完整解读。同时,也丧失了这些分块的局部连续性。因此,PVT2中将原始图像进行零填充,并使用 s t r i d e stride stride小于 k e r n e l _ s i z e kernel\_size kernel_size的重叠分块操作完成嵌入(上图中红色边框内的虚线即为重叠部分)。代码如下:

py 复制代码
#原始PVT直接进行分块
self.proj = nn.Conv2d(in_chans=3, embed_dim=64, kernel_size=4, stride=4)

#PVTv2包含零填充+重叠分块
self.proj = nn.Conv2d(in_chans=3, embed_dim=64, kernel_size=7,stride=4,padding=(3, 3))

此时PVT2与PVT在 p a t c h _ e m b e d i n g patch\_embeding patch_embeding部分的输入、输出大小相同,即,输入尺寸为 ( 1 , 3 , 224 , 2240 ) (1,3,224,2240) (1,3,224,2240),输出尺寸仍为 ( b a t c h _ s i z e , c h a n n a l , 56 , 56 ) (batch\_size,channal,56,56) (batch_size,channal,56,56)。但PVT2的嵌入包含了每个图像patch和周围相邻图像patch的相关信息。实际代码:

py 复制代码
class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        patch_size = to_2tuple(patch_size)
        assert max(patch_size) > stride, "Set larger patch_size than stride"
        self.patch_size = patch_size
        self.proj = nn.Conv2d(
            in_chans, embed_dim, patch_size,
            stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)
 
    def forward(self, x):  # (1,3,224,224)
        x = self.proj(x)  # (1,64,56,56)
        x = x.permute(0, 2, 3, 1)  # (1,56,56,64)
        x = self.norm(x)
        return x

其中, p a t c h _ s i z e = k e r n e l _ s i z e = 2 ∗ s t r i d e − 1 , p a d d i n g _ s i z e = s t r i d e − 1 patch\_size=kernel\_size=2*stride-1,padding\_size=stride-1 patch_size=kernel_size=2∗stride−1,padding_size=stride−1。

3.移除固定位置编码

在MLP层中两个FC之间加入了3×3的卷积以移除固定大小的位置编码,意思是,使用零填充给特征图外面加一圈padding,而卷积层可以根据特征图外圈的0学习到特征图的轮廓信息,换句话说可以学习到一些绝对位置信息,因此可以用DW卷积来建模位置信息并且去掉位置编码以减少计算量。实现代码:

py 复制代码
class DWConv(nn.Module):
    def __init__(self, dim=768):
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x

class Mlp(nn.Module):
    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W) #这里这里
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

4.线性SRA

PVT中使用SRA结构降低分辨率,而PVTv2中使用池化+卷积操作实现,以降低计算量。

py 复制代码
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False):
        if not linear:
            if sr_ratio > 1:
                self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
                self.norm = nn.LayerNorm(dim)
        else: #使用线性 SRA
            self.pool = nn.AdaptiveAvgPool2d(7) #加了一层池化
            self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1) #卷积核为 1
            self.norm = nn.LayerNorm(dim)
            self.act = nn.GELU() #激活函数
        self.apply(self._init_weights)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if not self.linear: #这里是原版 SRA
            if self.sr_ratio > 1:
                x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
                x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
                x_ = self.norm(x_)
                kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            else:
                kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        else: #是线性 SRA
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            x_ = self.act(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

5.版本结构

PVTv2各版本结构如下:

相关推荐
云卓科技6 分钟前
无人机之自动控制原理篇
科技·算法·目标检测·机器人·无人机
yyfhq7 小时前
dcgan
深度学习·机器学习·生成对抗网络
这个男人是小帅7 小时前
【图神经网络】 AM-GCN论文精讲(全网最细致篇)
人工智能·pytorch·深度学习·神经网络·分类
零澪灵9 小时前
[论文阅读] GPT-4 Technical Report
论文阅读
谁怕平生太急9 小时前
论文阅读:三星-TinyClick
论文阅读·mobile_ui_agent
zhilanguifang9 小时前
ERC论文阅读(02)--SAC;-LSTM论文阅读笔记
论文阅读·笔记·lstm
loong_XL9 小时前
论文阅读工具:arXiv、papers.cool、txyz
论文阅读
YRr YRr9 小时前
深度学习:正则化(Regularization)详细解释
人工智能·深度学习
yyfhq9 小时前
rescorediff
python·深度学习·机器学习
思通数据9 小时前
AI助力医疗数据自动化:诊断报告识别与管理
大数据·人工智能·目标检测·机器学习·计算机视觉·目标跟踪·自动化