Segment-anything学习到微调系列_SAM细节理解和部分代码

前言

本系列文章是博主在工作中使用SAM模型时的学习笔记,包含三部分:

  1. SAM初步理解,简单介绍模型框架,不涉及细节和代码
  2. SAM细节理解,对各模块结合代码进一步分析
  3. SAM微调实例,原始代码涉及隐私,此部分使用公开的VOC2007数据集,Point和Box作为提示进行mask decoder微调讲解

此篇为第二部分,如果已看过第一部分的,可以跳过下文的模型总览中介绍输入输出和流程及最后的数据引擎part,和第一篇一致。本篇很多图和部分内容参考自【大模型系列】一文看懂SAM大模型,感谢原作者。

模型总览

SAM论文: arxiv.org/abs/2304.02...

SAM Github:github.com/facebookres...

SAM在线demo: segment-anything.com/demo

SAM的一部分灵感是来源于NLP中的基座模型(Foundation Model),Foundation Model是OpenAI提出的一个概念,它指的是在超大量数据集上预训练过的大模型(如GPT系列、BERT),这些模型具有非常强大的 zero-shot 和 few-shot能力,结合prompt engineering和fine tuning等技术可以将基座模型应用在各种下游任务中并实现惊人的效果。

SAM就是想构建一个这样的图像分割基座模型,即使是一个未见过的数据集,模型也能自动或半自动(基于prompt)地完成下游的分割任务。为了实现这个目标,SAM定义了一种可提示化的分割任务(promptable segmentation task),这个提示可以是点、框、掩码、文本 (代码中未实现)等形式,基于这个提示模型就能分割出提示处所在物体的masks。同时这种提示可以是模糊的,比如以下图剪刀握手那的黄色部分点为提示,分割掩码可以是下图最右边三种情况中任意一种,从上到下分别代表whole, part, subpart 三种层级的分割,这也是SAM兼容的。要达到这种效果就需要足够的高质量分割数据,SAM团队用他们提出的Data Engine策略成功使用人工加模型自动标注的方式制作除了一个有10亿个masks的分割数据集**SA-1B**,这也是他们核心的贡献之一,本文尾部会介绍相关流程。模型架构来说相对比较常规,主要是借鉴了ViT和DETR,本身创新不大。

如上图,SAM模型架构主要包括image encoder,prompt encoder和mask decoder三部分:

  • image encoder,使用了ViT模型将图像编码得到image embedding
  • prompt encoder,将point、box、mask、txt等提示信息进行编码,后续会和image embedding一起用于生成masks
  • mask decoder,将上述两个模块得到的embeddings整合,然后结合两个可学习的tokens生成不同层级的masks和对应的置信度值

值得一提的是,prompt encoder和mask decoder都是非常轻量的,主要的计算开销都在image encoder上,这点从模型权重上也能看出来。以ViT_B为基础的SAM权重是375M,其中prompt encoder只有32.8k,mask decoder是16.3M(4.35%),剩余则是image encoder,可想而知图像编码这块是非常耗时的。因此在实际推理中,一般单张图的image embedding只计算一次,然后将结果缓存起来,需要的时候直接调用。在image embedding已经计算好的情况下,论文中说给定一个prompt,生成mask时prompt encoder和mask decoder在浏览器中的计算耗时也仅需50ms。下面会具体介绍下各模块的输入输出和流程,均只考虑batch size为1的情况,代码讲解在下一篇。

Image encoder

输入:

默认是1024x1024的图像,如尺寸不一致会将原图按最长边resize

输出:

单张图的1x256x64x64的image embedding,即编码后的图像特征

流程

上图是ViTViT论文中的结构图,image encoder整体流程和ViT是一样的,区别在于不需要[class]token做分类,只输出最终的图像编码张量

  • 输入1024的图,拆分成64x64的768维patchs
  • 经过attention block(window和global的MSA,相对位置编码)和MLP得到同样大小64x64x768embbeding特征
  • 再经过neck得到1x256x64x64的图片embedding

这块有一篇文字介绍的更详细,如果想了解更多细节可以看这篇:Image encoder模块Vision Transformer网络解析

Image encoder主要由attention block和neck组成,下面将根据代码简单介绍

attention block

1024x1024x3的图片经过一个patch_size=16的PatchEmbed层,将原图分为一个个小块patch,每个patch会在channel维度展开成向量(16x16x3=768),即得到1x64x64x768的patchs,然后patchs经过attention模块(window attention或者global attention),再经MLP得到与输入x一致大小的1x64x64x768

window_partition

拆分窗口,相当于把原图拆成多个小图叠在一起

ini 复制代码
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) ##先按窗口大小拆分
## .permute将数据按窗口划分,再通过.view展成多个大小一直的小图
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

Attention类

window attention和global attention计算都一样,区别在于输入x的B H W大小,window的B是batch小窗口的个数,H W是窗口大小;二global的则是针对全图计算Attention

qkv是直接用一个全连接层一次性得到然后拆分成3个多个注意力的q, k, v

相对位置编码和swin transformer类似,只针对query计算,直接加到attention上

ini 复制代码
    def forward(self, x: torch.Tensor) -> torch.Tensor:
    ## window attention的B是batchx小窗口的个数,H W是窗口大小;
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)         ## self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
        ## self.scale多头注意力的根号d,self.scale = (dim_input // num_heads )**-0.5
        attn = (q * self.scale) @ k.transpose(-2, -1) 
        if self.use_rel_pos:
        	##self.rel_pos_h是网络学到的参数
            attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
        attn = attn.softmax(dim=-1)
        x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        x = self.proj(x)  ## self.proj = nn.Linear(dim, dim)
        return x

相对位置编码

即add_decomposed_rel_pos函数的实现,self.rel_pos_h就是类似swin transformer的相对位置偏移表(长度为2*M-1,因为可取的值是[-M+1,M-1]共2*M-1个),是网络学习到的

ini 复制代码
# 非global的input_size[0] = input_size[1] = 14,global是64
# multi-head attention中的head维度:head_dim = block输出的维度768 / head的数量12 = 64
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))

根据窗口大小生成对应相对位置索引index,再从self.rel_pos_h中取对应索引位置的偏移向量得到Rh(14x14x64 ),再把query展成多个小窗口形式与Rh矩阵乘法得到最终相对位置编码,attn上直接加上h和w方向的相对位置编码即可

python 复制代码
def add_decomposed_rel_pos(        
    attn: torch.Tensor,        
    q: torch.Tensor,        
    rel_pos_h: torch.Tensor,        
    rel_pos_w: torch.Tensor,        
    q_size: Tuple[int, int],        
    k_size: Tuple[int, int], ) -> torch.Tensor:        
    """ Calculate decomposed Relative Positional Embeddings    
    Args:                
        attn (Tensor): attention map.                
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).                
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.                
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.                
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).                
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).        

    Returns:                
        attn (Tensor): attention map with added relative positional embeddings.        
    """    
    # q: 300x196x64    
    # atten:300x196x196    
    q_h, q_w = q_size        
    k_h, k_w = k_size
    
    # Rh: 14x14x64        
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)        
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)        
    B, _, dim = q.shape
    
    # r_q: 300x14x14x64        
    r_q = q.reshape(B, q_h, q_w, dim) 
    
    # rel_h: 300x14x14x14    
    # 等价于:   
    # rel_h = torch.matmul(r_q, Rh.transpose(1, 2))
    # rel_w = torch.matmul(r_q.transpose(1, 2), Rw.transpose(1, 2)).transpose(1, 2)    
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)        
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
    
    # 将相对位置编码加在atten里面,再resize回300x196x196    
    attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w) 
    return attn

最终attention block的forward代码

window_size为0计算全局attention,否则计算局部的window attention

ini 复制代码
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        shortcut = x
        x = self.norm1(x)
        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)
        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + x
        x = x + self.mlp(self.norm2(x))
        return x

neck降低embedding维度

ini 复制代码
Sequential(
  (0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (1): LayerNorm2d()
  (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (3): LayerNorm2d()
)

neck部分由两个卷积层组成,分别是256x768x1x1和256x256x3x3,最后输出的image imbedding的尺寸是1x256x64x64

encoder的forward

python 复制代码
def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.patch_embed(x) ## 拆成patchs
        ## 网络学习到的绝对位置编码,absolute positional embedding
        if self.pos_embed is not None:
            x = x + self.pos_embed
		### tansformer block由window attention和global attention组成
        for blk in self.blocks:
            x = blk(x)
        ## 上面得到的x是 1*64*64*768
        x = self.neck(x.permute(0, 3, 1, 2))
        ## 最后输出的image imbedding的尺寸是1*256*64*64
        return x

Prompt encoder

输入:

point、box、mask、txt(代码未实现)等prompt,格式一般如下,B为batch size

  • point需要包含点的x,y坐标BxNx2和label(0为前景,1位背景)BxNx1
  • box包含框的左上和右下两个点,BxNx4,对于某个gt即单个mask,只会有1个box;如果输入的是N个box最终会生成N个masks
  • mask一般和SAM最终输出mask的hxw(256x256),Bx1xHxW
  • txt在SAM代码中未实现,这块可以参考Grounded-Segment-Anything

输出两个:

  • sparse_embeddings 点和框的稀疏嵌入,形状为BxNx(embed_dim),其中N由输入点和框的数量确定,如果两者同时有则N的计算方式为(点的个数+2x框的个数)

    • point box 全都没有,输出大小:Bx0x256
    • 如果只有point,输出大小:Bx(N+1)x256,会补充一个[0,0]空点在最后,label为-1,表示只有点提示;
    • 如果只有box,输出大小: (B*N)x2x256
    • piont、box都有,输出大小:BxNx256
  • dense_embeddings 掩码的密集嵌入,形状为Bx(embed_dim)x(embed_H)x(embed_W),默认大小为Bx256x64x64,没有提示时会返回一个网络学习到的no mask默认嵌入

流程

网络已自动学会了针对不通过类型提示的编码信息,输入的point、box、mask等提示加上位置编码后,再加上网络学会的综合编码信息,最终对point、box这种稀疏的提示会返回sparse embedding, 对mask会返回dense embeddings(没有mask提示时是网络学习到的embeddings)。这部分就相当于把各种提示转换为decoder能理解的格式。

point embedding

输入points是个tuple,包含(point_coords, point_labels),point_coords一般是BxNx2的一些列点的xy坐标,point_labels是BXN个点对应的label,0代表是一个背景点,1代表是前景点(需要分割出mask的部分),以下都以1个点和1个box为提示分析

  • step1:首先生成一组可学习的向量point embedding,大小为:4x1x256,即前景/背景和框的左上右下两个点:

    scss 复制代码
    ## ModuleList((0-3): 4 x Embedding(1, 256)); 4个点代表 pos/neg point + 2 box corners
    self.point_embeddings = nn.ModuleList([nn.Embedding(1, 256) for i in range(4)])
  • step2:再生成一组可学习的向量not_a_point_embed,大小为1x256,用于表示该位置不是一个点

    ini 复制代码
    self.not_a_point_embed = nn.Embedding(1, embed_dim)
  • step3:点的padding

    N为传入点的个数,点point表示为BxNx2的tensor,如果prompt里面没有bbox只有点,则补充一个[0,0]点到points后面,其对应的label为-1,此时point大小为Bx(N+1)x2,label为Bx(N+1);

    如果传入的还有bbox,此时的point大小为BxNx2,label为BxN(不加pad)

    ini 复制代码
    ## 没有bbox,补充【0,0】点到每个point后面,其对应的label为-1
    if pad:
        padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device, dtype=points.dtype)
        padding_label = -torch.ones((labels.shape[0], 1), device=labels.device, dtype=points.dtype)
        points = torch.cat([points, padding_point], dim=1)
        labels = torch.cat([labels, padding_label], dim=1)
  • step4:计算Positionally encode:

    点的横纵坐标除以输入尺寸w,h(1024,1024)到[0,1]之间,再和随机高斯矩阵(2x128, positional_encoding_gaussian_matrix)相乘后,算sin cos值后h和w拼在一起得到点的位置编码(BxNx256),N=2表示只有point,N=1表示还有box没有pad一个空点

    python 复制代码
        def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor:
            """Positionally encode points that are normalized to [0,1]."""
            # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape
            coords = 2 * coords - 1
            coords = coords.to(self.positional_encoding_gaussian_matrix.dtype)
            coords = coords @ self.positional_encoding_gaussian_matrix  ##随机高斯矩阵 2x128
            coords = 2 * np.pi * coords
            # outputs d_1 x ... x d_n x C shape
            return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
  • step5:计算最终的point embedding

    上一步得到的256维向量,叠加学习到的embedding向量(非点、背景点、前景点,各自分别学习了)最终为BxNx256的向量

    ini 复制代码
    point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)    
    point_embedding[labels == -1] = 0.0        
    point_embedding[labels == -1] += self.not_a_point_embed.weight      # 对应label为-1的padding点,加上not_a_point_embed
    point_embedding[labels == 0] += self.point_embeddings[0].weight     # neg点加上point_embeddings[0]
    point_embedding[labels == 1] += self.point_embeddings[1].weight     # pos点加上point_embeddings[1]

完整的embedding流程

ini 复制代码
def _embed_points(
    self,        
    points: torch.Tensor,        
    labels: torch.Tensor,        
    pad: bool,     ) -> torch.Tensor:        
    """Embeds point prompts."""        
    points = points + 0.5  # Shift to center of pixel
    # 如果没有输入的box的话,会将points的长度用0补充形成Bx(N+1)x2,label用【-1】补充成Bx(N+1)
    if pad:            
        padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)            
        padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)            
        points = torch.cat([points, padding_point], dim=1)            
        labels = torch.cat([labels, padding_label], dim=1)
    # 将points与一个2x128的随机高斯矩阵相乘再通过进行sin、cos运算,两者的运算结果拼接得到
    # point_embedding: BxNx256 或者 Bx(N+1)x256
    point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)    
    point_embedding[labels == -1] = 0.0        
    point_embedding[labels == -1] += self.not_a_point_embed.weight
    point_embedding[labels == 0] += self.point_embeddings[0].weight
    point_embedding[labels == 1] += self.point_embeddings[1].weight
    return point_embedding

最后输出的结果是sparse_embeddings BxNx256,如果prompt只有点则sparse_embeddings就是Bx(N+1)x256的point_embedding,如果还有box,那最终结果是BxNx256的point_embedding和叠加后面Bx1x256的box embedding拼在一起当作最终的sparse_embeddings

box embedding

对于某个gt即单个mask,只会有1个box;如果输入的是N个box最终会生成N个masks。一个box有两个点,左上和右下角点,embedding步骤如下

  • step1:先resize为Nx2x2;N代表多个框
  • step2:再使用point embedding一样的位置编码方式,得到corner_embedding位置编码后的Nx2x256向量
  • step3:再加上之前网络学习到的box角点embeding向量(0 1是前景背景,2 3是box的左上右下)
python 复制代码
    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embeds box prompts."""
        boxes = boxes + 0.5  # Shift to center of pixel
        coords = boxes.reshape(-1, 2, 2) ## 操作与points类似,将2个点resize成Nx2x2
        ## 使用point embedding编码的方式,得到corner_embedding Nx2x256
        corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
        ## 再加上学习到的点的embedding
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight    #2 3是box的左上右下
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight    #2 3是box的左上右下
        return corner_embedding ##Nx2x256,N代表多个框

最后输出的box的embedding的尺寸是Nx2x256,N代表多个框

sparse embedding

point和box的prompt最终结果都是sparse embedding,因此代码里会将两者合并起来得到最终结果,

  • 如果只有point,输出大小为Bx(N+1)x256,因为padding了1个空点,比如一个点时为Bx2x256
  • 如果只有box,输出大小为 (B*N)x2x256,比如一个box时为Bx2x256
  • point、box都有,输出大小:BxNx256,N的计算方式为(点的个数+2x框的个数),比如一个point和一个box时为Bx3x256

dense_embeddings

dense_embeddings也是mask embedding,即针对全图的编码,分两种情况,输入的prompt有无mask

  • 有mask提示,则经过self.mask_downscaling卷积网络不断下采样,简单粗暴得到Nx256x64x64维的embedding(N是输入mask的个数)

    ini 复制代码
    self.mask_downscaling:
    Sequential(
      (0): Conv2d(1, 4, kernel_size=(2, 2), stride=(2, 2))
      (1): LayerNorm2d()
      (2): GELU(approximate='none')
      (3): Conv2d(4, 16, kernel_size=(2, 2), stride=(2, 2))
      (4): LayerNorm2d()
      (5): GELU(approximate='none')
      (6): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))
  • 无mask提示,网络有一个1x256的向量self.no_mask_embed表示无mask提示时的特征向量,直接将其复制expand成Nx256x64x64的embedding

    ini 复制代码
    self.no_mask_embed = nn.Embedding(1, embed_dim)
    dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                    bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] )

Mask decoder

输入:

  • image encoder得到的image_embeddings和图像的positional encoding
  • prompt encoder得到的prompt embeddings(sparse和dense两种)

输出:

  • masks,如果指定了"multimask_output"参数则会输出3个层级的mask(whole, part, and subpart),否则只输出1个mask
  • IoU scores,可以理解为每个mask的置信度,由网络中的iou token得到

流程

  • 首先会image_embeddings会混入dense embeddings的信息(两者直接相加),sparse embeddings则会与mask token和IoU token拼在一起成为一个新的token,mask token后续会用于生成mask,IoU token用于衡量每个mask的好坏

  • 然后这个新的token和image_embeddings经过一个TwoWayTransformer模块(下图黄色框部分),先做token的self attention,然后做token(作为key)到图像的cross attention,经过MLP更新token,最后再图像(作为key)到token的attention,目的是不断更新图像和token中的信息,会重复两次

  • 更新后token再做一次token(作为key)到图像的cross attention后,又拆出来之前的两个部分mask token和IoU token,后者就代表每个mask的置信度;

    而图像信息经过转置卷积还原到原图大小后,会和mask token做矩阵乘法生成最终的masks,类似 YOLACT中的"prototype masks"和"mask coefficients"矩阵乘法

output_tokens

类似NLP和ViT中的[cls]token用来分类,decoder中定义了两个可学习的token辅助生成mask,两个拼在一起就是output_tokens:

  • iou_token(1x256),会用于计算后续IoU scores(上图绿色部分)
  • mask_tokens(4x256) ,用于生成最终的mask(上图红色框部分),分别对应单张mask(仅在不需要多层mask时启用)+3种层级的mask(whole, part, and subpart)
ini 复制代码
self.iou_token = nn.Embedding(1, transformer_dim) ##transformer_dim=256
self.num_mask_tokens = num_multimask_outputs + 1 ##num_multimask_outputs=3,对应3种层级的mask(whole, part, and subpart),+1是首层单张mask(仅在不需要多层mask时启用)
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)

sparse embedding处理

output_tokens和promp encoder得到的sparse embedding(点、框的稀疏提示)会concate在一起,当成新的tokens(只有1个点作为prompt时,为Nx7x256,后续所有都按只有1个点prompt情况分析),对应上图左下角

ini 复制代码
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) #5x256
##点、框的稀疏提示,沿batch方向复制成相同维度,得到Nx5x256
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)

image_embeddings处理

image encoder得到了1x256x64x64的image_embeddings,将其从batch维度拓展成和tokens一样的Nx256x64x64(N为prompt个数),然后和dense_embeddings相加,即加入全图mask的稠密提示信息(有mask作为prompt时是对应mask的embedding特征,无mask作为prompt时,是网络自己学到的embedding特征)

image_pe

image_pe是1x256x64x64的位置编码特征,编码方式和prompt encoder中给point的Positionally encode方法一样(随机高斯矩阵后正余弦固定编码),同样展成Nx256x64x64的大小

ini 复制代码
# Expand per-image data in batch direction to be per-mask
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape

TwoWayTransformer

将上述src、pos_src、tokens经过TwoWayTransformer得到两个输出hs(即更新后的tokens), src(更新后的image_embeddings+dense_prompt_embeddings,包含原图和mask prompt的信息)

ini 复制代码
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)

网络主要由两个TwoWayAttentionBlock组成,计算有以下特点:

  • tokens(iou+mask+sparse embedding)主要作为query(1x7x256),query_pe即query的位置编码信息使用tokens本身替代;
  • src即image_embeddings,是作为key的(展成了1x4096x256);pos_src即image_pe,是image_embeddings的位置编码信息(1x4096x256)
  • 每次算Attention不管query还是key都会加上其对应的位置编码信息,类似ShortCut

TwoWayAttentionBlock

TwoWayAttentionBlock主要包含四层,对应流程那块图的黄色框四层部分

  • 针对query(即tokens,iou+mask+sparse embedding)的self-atten,但首层无位置编码

    ini 复制代码
            # Self attention block
            if self.skip_first_layer_pe:
                queries = self.self_attn(q=queries, k=queries, v=queries)
            else:
                q = queries + query_pe
                attn_out = self.self_attn(q=q, k=q, v=queries)
                queries = queries + attn_out
            queries = self.norm1(queries)
  • tokens到image的cross attention(tokens作为query, image_embeddings作为key)

    ini 复制代码
            # Cross attention block, tokens attending to image embedding
            q = queries + query_pe
            k = keys + key_pe
            attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)
            queries = queries + attn_out
            queries = self.norm2(queries)
  • 针对tokens的mlp block

    ini 复制代码
            # MLP block
            mlp_out = self.mlp(queries)
            queries = queries + mlp_out
            queries = self.norm3(queries)
  • 与第二层相反,image到tokens的cross attention(image_embeddings作为query, tokens作为key)

    ini 复制代码
            # Cross attention block, image embedding attending to tokens
            q = queries + query_pe
            k = keys + key_pe
            attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)
            keys = keys + attn_out
            keys = self.norm4(keys)

经过两层TwoWayAttentionBlock更新queries和keys后,再来一个final attention层,更新得到最终的queries即hs, keys即src

ini 复制代码
        # Apply the final attention layer from the points to the image
        q = queries + point_embedding
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = queries + attn_out
        queries = self.norm_final_attn(queries)

整个流程较为复杂,参考下图

预测masks

transformer把输入的tokens和image embedding更新后变成了hs(1x7x256), src(1x4096x256),接下来就是利用这两个信息去上采样生成mask。整个流程可以参考下图,queries就是我们的hs,keys就是src

hs,即更新后的tokens

此前的tokens是由iou_token和mask_tokens拼接得到,所以从hs上也可以拆下来更新后的那两个token,其中iou_token_out表示生成mask的质量即IoU scores,mask_tokens_out用于生成不同层级mask

ini 复制代码
# Run the transformer,update tokens和image embedding
hs, src = self.transformer(src, pos_src, tokens) 
## hs 1x7x256; src: 1x4096x256
iou_token_out = hs[:, 0, :]  ##  1x1x256
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] ##  1x4x256

src,即更新后的mask embeddings(从image embeddings+dense embeddings学习得来)

把1x4096x256的mask embeddings变成1x256x64x64(和image encoder输出的大小对应),再经过两层转置卷积进行上采样(到原图大小)变为1x32x256x256的upscaled_embedding

ini 复制代码
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)

基于mask_tokens和上采样后的mask embeddings做矩阵乘法得到mask

1x4x256的mask_tokens,4层包含不同层级only one, whole, part, subpart,所以用4个不同的全连接层(三层MLP网络)降低embedded维度,从256维降低为32维,再把4层结果又concat在一起得到1x4x32维的hyper_in(可以理解为超维压缩特征)

最终由1x4x32的hyper_in和1x32x256x256的upscaled_embedding矩阵乘法得到1x4x256x256的masks结果 (为了可以矩阵运算,需要将后者h w合并成在一个维度运算完再展开) 这块来自 YOLACT的思路 (YOLOV5分割模型也是用的这种,github有个讲解比较详细的github.com/ultralytics... ,也可以参考这篇 YOLACT

此处只针对单个点prompt(N=1)分析,一个prompt对应一组mask结果

ini 复制代码
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
	##1x1x256 变为 1x1x32
    hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1)  ##1x4x32
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w)  ## 1x4x256x256

iou_pred

之前得到的1x256的iou_token经过IoU预测head(全连接+sigmoid)直接得到1x4的iou_pred,即对应上面4层生成mask的IoU值

ini 复制代码
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)

后处理到原图大小

预测得到的1x4x256x256 masks,根据参数multimask_output决定是否输出多层级mask,如果是则取后3层mask,否则取最顶层,iou_pred也取对应的;

得到的mask是256x256大小,先插值到网络输入大小(1024x1024),再去掉padding部分,最后再插值缩放到原图大小hxw

ini 复制代码
## self.image_encoder.img_size=1024
masks = F.interpolate( masks, (self.image_encoder.img_size, self.image_encoder.img_size), mode="bilinear",    align_corners=False,   )
masks = masks[..., : input_size[0], : input_size[1]]
masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)

整图分割推理(segment everything)

流程

在图片上生成32x32的网格,得到1024个采样点,每个采样点都当做1个前景的prompt进入prompt encoder然后和image encoder结果一起生成mask,每次会处理一个batch(默认64)的采样点;每个batch得到的mask都会进行以下几个过滤:

  • predicted IoU过滤,mask decoder除了返回masks还会预测对应mask iou值,过滤低置信度(默认阈值0.88)的mask
  • stability score过滤,stability score是mask在两个阈值下二值化后的IoU值,可以理解为改变过滤阈值后还能得到同样mask的能力,过滤低于0.95的mask
  • mask threshold过滤,直接过滤mask logits值低于mask_threshold(默认0.0)的mask
  • boundary过滤,每个mask生成外界矩形,过滤超过图像边界的mask

所有batch过滤后的的masks结果再进行nms过滤(mask对应外接矩形的nms,阈值0.7)就得到最终的分割结果

图片crop

一般默认是基于整张原图去生成32x32的网格点,代码里可以配置参数crop_n_layers(默认0)将原图进行切分,每条边切分成2**crop_n_layers份,如值为1对应每边切成2份得到2x2个小的子图,值为2对应每边切成4份得到4x4个子图。切分时子图会有一定重叠区域(overlap_ratio参数控制),某种程度防止把一个物体切成两份了?

对应函数是 generate_crop_boxes,返回的是每个切分层级得到的子图坐标,以1024x1024原图,crop_n_layers为1,重叠比例为512/1500为例,返回结果如下:

yaml 复制代码
##第一个代表原图不切分,后面4个即2x2的子图区域
[[0, 0, 1024, 1024], 
[0, 0, 687, 687], [0, 338, 687, 1024], [338, 0, 1024, 687], [338, 338, 1024, 1024]]

为方便理解后续都按crop_n_layers=0不切分情况分析

网格点生成

n_per_side为每个边需要采样的点数,默认是32,坐标范围是[0,1]之间,x,y方向都采样得到32x32的网格,将对应网格坐标乘以原始图片的长宽,得到1024个采样点在原图的坐标,这些点都作为prompt进行mask生成

python 复制代码
def build_point_grid(n_per_side: int) -> np.ndarray:
    """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
    offset = 1 / (2 * n_per_side)
    points_one_side = np.linspace(offset, 1 - offset, n_per_side)
    points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
    points_y = np.tile(points_one_side[:, None], (1, n_per_side))
    points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
    return points

batch个点推理得到masks

1024个点拆分成多个batch推理,1个batch64个点,网络输入大小默认1024x1024,图片原始尺寸会按最长边进行resize到1024(如960x540的图变1024x576),上面得到点坐标是对应原图大小的,需要经apply_coords进一步转换到resize后的图的坐标,每个点的label都是1,即都作为前景点去得到mask;然后整个batch的经过prompt_encoder得到对应点的sparse embedding和网络学到的dense_embeddings(因为无mask prompt),再和image encoder得到的image imbedding一起经过mask decoder得到最终masks

ini 复制代码
        # 坐标位置对应到resize后的图上
        transformed_points = self.predictor.transform.apply_coords(points, im_size)
        in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
        in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)  #1024个都是前景点,label=1
        ## 常规的基于点提示得到masks流程
        masks, iou_preds, _ = self.predictor.predict_torch(
            in_points[:, None, :],
            in_labels[:, None],
            multimask_output=True,
            return_logits=True,
        )

batch内masks过滤

predicted IoU过滤

mask decoder除了返回masks还会预测对应mask iou值,过滤低置信度(默认阈值0.88)的mask

python 复制代码
        # Filter by predicted IoU
        if self.pred_iou_thresh > 0.0:
            keep_mask = data["iou_preds"] > self.pred_iou_thresh ##默认阈值0.88
            data.filter(keep_mask)

stability score过滤

stability score是mask在两个阈值下二值化后的IoU值,可以理解为改变过滤阈值后还能得到同样mask的能力,计算方式如下代码:

ini 复制代码
## mask_threshold=0.0, threshold_offset=1.0
intersections = (   (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32))
unions = (  (masks > (mask_threshold - threshold_offset)) .sum(-1, dtype=torch.int16)  .sum(-1, dtype=torch.int32))
stability_score = intersections / unions

过滤低于self.stability_score_thresh(默认0.95)的mask

mask threshold过滤

直接过滤mask logits值低于mask_threshold(默认0.0)的mask

css 复制代码
data["masks"] = data["masks"] > self.predictor.model.mask_threshold

boundary过滤

每个mask生成外界矩形,过滤超过图像边界的mask

css 复制代码
        data["boxes"] = batched_mask_to_box(data["masks"])
        # Filter boxes that touch crop boundaries
        keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
        if not torch.all(keep_mask):
            data.filter(keep_mask)

nms过滤

所有batch的过滤结果会再经过一次nms过滤,基于对应mask的外接矩形进行nms过滤,如下代码

ini 复制代码
        # Remove duplicates within this crop.
        keep_by_nms = batched_nms(
            data["boxes"].float(),
            data["iou_preds"],
            torch.zeros_like(data["boxes"][:, 0]),  # categories
            iou_threshold=self.box_nms_thresh,
        )
        data.filter(keep_by_nms)

如果配置了图片crop,即多层级的子图得到的mask还会再进行一次nms过滤得到原图最终的masks

最终结果

有个参数output_mode可以设置输出mask的类型,默认是"binary_mask",会比较耗内存,还可以选择coco的RLE编码格式,参数为'uncompressed_rle', or 'coco_rle'

git上也有官方demo可以参考:全图分割的官方demo

数据引擎(data engine)

SAM除了模型外,还公开了一份有10亿个masks的1100万张图的分割数据集**SA-1B**,基于他们提出的data engine方案得到,这块的贡献也是非��显著,也体现了Data-centric AI的惊人能力,[这块知乎上"一堆废纸"博主介绍的比较好](如何评价Meta/FAIR 最新工作Segment Anything? - 一堆废纸的回答 - 知乎 www.zhihu.com/question/59...%25E3%2580%2582%25E4%25BB%258E%25E8%25AE%25BA%25E6%2596%2587%25E9%2587%258C%25E6%2580%25BB%25E7%25BB%2593%25E5%25B0%25B1%25E6%2598%25AF%25E8%25BE%2585%25E5%258A%25A9%25E4%25BA%25BA%25E5%25B7%25A5%25E6%25A0%2587%25E6%25B3%25A8%25E3%2580%2581%25E5%258D%258A%25E8%2587%25AA%25E5%258A%25A8%25E6%25A0%2587%25E6%25B3%25A8%25E3%2580%2581%25E5%2585%25A8%25E8%2587%25AA%25E5%258A%25A8%25E6%25A0%2587%25E6%25B3%25A8%25E4%25B8%2589%25E6%25AD%25A5%25EF%25BC%258C%25E5%2585%25B7%25E4%25BD%2593%25E5%25A6%2582%25E4%25B8%258B%25EF%25BC%259A "https://www.zhihu.com/question/593888697/answer/2972047807)%E3%80%82%E4%BB%8E%E8%AE%BA%E6%96%87%E9%87%8C%E6%80%BB%E7%BB%93%E5%B0%B1%E6%98%AF%E8%BE%85%E5%8A%A9%E4%BA%BA%E5%B7%A5%E6%A0%87%E6%B3%A8%E3%80%81%E5%8D%8A%E8%87%AA%E5%8A%A8%E6%A0%87%E6%B3%A8%E3%80%81%E5%85%A8%E8%87%AA%E5%8A%A8%E6%A0%87%E6%B3%A8%E4%B8%89%E6%AD%A5%EF%BC%8C%E5%85%B7%E4%BD%93%E5%A6%82%E4%B8%8B%EF%BC%9A")

  • 第一步以人工标注为主。初始模型在公开数据集训练后辅助生成masks,再人工精修调整,再用标好的新数据迭代模型。如此重复6次,从12万张图得到430万masks
  • 第二步是模型半自动标注高置信度masks,然后人工标注补充剩余未标出的masks。mask的置信度判断是用一个模型对mask进行目标检测,如果能检测出物体则是置信度较高mask无需再人工标注,这个目标检测模型是基于第一步得到的数据训练的。如此迭代5次,从18万张图新增了590万masks
  • 第三部是模型全自动标注。基于此前两步的数据得到模型,已有较好的分割能力且能适配模糊提示分割(局部mask或者整体mask),对一张图撒32x32的网格点进行segment everything,后处理会挑选搞IoU和搞稳定性的masks并做NMS得到全图最终的masks。针对所有图片自动分割,最终得到了SA-1B数据集
相关推荐
Lenyiin1 小时前
02.01、移除重复节点
c++·算法·leetcode
Lulsj4 小时前
代码随想录day22 | leetcode 39.组合总和 40.组合总和II 131.分割回文串
算法·leetcode
yvestine7 小时前
数据挖掘——支持向量机分类器
人工智能·算法·机器学习·支持向量机·分类·数据挖掘·svm
robin_suli7 小时前
穷举vs暴搜vs深搜vs回溯vs剪枝系列一>
算法·剪枝·深度优先遍历·回溯
魂兮-龙游8 小时前
C语言中的printf、sprintf、snprintf、vsnprintf 函数
c语言·开发语言·算法
陈序缘8 小时前
PyTorch快速入门
人工智能·pytorch·python·深度学习·算法·机器学习
KeyPan8 小时前
【视觉SLAM:四、相机与图像】
人工智能·深度学习·数码相机·算法·机器学习·计算机视觉
HUT_Tyne2659 小时前
力扣--LCR 167.招式拆解I
数据结构·算法·leetcode
hanlin.liu168810 小时前
【刷题日记】455.分发饼干
算法·leetcode·职场和发展
田梓燊10 小时前
01 背包
算法