前言
本系列文章是博主在工作中使用SAM模型时的学习笔记,包含三部分:
- SAM初步理解,简单介绍模型框架,不涉及细节和代码
- SAM细节理解,对各模块结合代码进一步分析
- SAM微调实例,原始代码涉及隐私,此部分使用公开的VOC2007数据集,Point和Box作为提示进行mask decoder微调讲解
此篇为第二部分,如果已看过第一部分的,可以跳过下文的模型总览中介绍输入输出和流程及最后的数据引擎part,和第一篇一致。本篇很多图和部分内容参考自【大模型系列】一文看懂SAM大模型,感谢原作者。
模型总览
SAM论文: arxiv.org/abs/2304.02...
SAM Github:github.com/facebookres...
SAM在线demo: segment-anything.com/demo
SAM的一部分灵感是来源于NLP中的基座模型(Foundation Model),Foundation Model是OpenAI提出的一个概念,它指的是在超大量数据集上预训练过的大模型(如GPT系列、BERT),这些模型具有非常强大的 zero-shot 和 few-shot能力,结合prompt engineering和fine tuning等技术可以将基座模型应用在各种下游任务中并实现惊人的效果。
SAM就是想构建一个这样的图像分割基座模型,即使是一个未见过的数据集,模型也能自动或半自动(基于prompt)地完成下游的分割任务。为了实现这个目标,SAM定义了一种可提示化的分割任务(promptable segmentation task),这个提示可以是点、框、掩码、文本 (代码中未实现)等形式,基于这个提示模型就能分割出提示处所在物体的masks。同时这种提示可以是模糊的,比如以下图剪刀握手那的黄色部分点为提示,分割掩码可以是下图最右边三种情况中任意一种,从上到下分别代表whole, part, subpart 三种层级的分割,这也是SAM兼容的。要达到这种效果就需要足够的高质量分割数据,SAM团队用他们提出的Data Engine策略成功使用人工加模型自动标注的方式制作除了一个有10亿个masks的分割数据集**SA-1B**,这也是他们核心的贡献之一,本文尾部会介绍相关流程。模型架构来说相对比较常规,主要是借鉴了ViT和DETR,本身创新不大。
如上图,SAM模型架构主要包括image encoder,prompt encoder和mask decoder三部分:
- image encoder,使用了ViT模型将图像编码得到image embedding
- prompt encoder,将point、box、mask、txt等提示信息进行编码,后续会和image embedding一起用于生成masks
- mask decoder,将上述两个模块得到的embeddings整合,然后结合两个可学习的tokens生成不同层级的masks和对应的置信度值
值得一提的是,prompt encoder和mask decoder都是非常轻量的,主要的计算开销都在image encoder上,这点从模型权重上也能看出来。以ViT_B为基础的SAM权重是375M,其中prompt encoder只有32.8k,mask decoder是16.3M(4.35%),剩余则是image encoder,可想而知图像编码这块是非常耗时的。因此在实际推理中,一般单张图的image embedding只计算一次,然后将结果缓存起来,需要的时候直接调用。在image embedding已经计算好的情况下,论文中说给定一个prompt,生成mask时prompt encoder和mask decoder在浏览器中的计算耗时也仅需50ms。下面会具体介绍下各模块的输入输出和流程,均只考虑batch size为1的情况,代码讲解在下一篇。
Image encoder
输入:
默认是1024x1024的图像,如尺寸不一致会将原图按最长边resize
输出:
单张图的1x256x64x64的image embedding,即编码后的图像特征
流程
上图是ViTViT论文中的结构图,image encoder整体流程和ViT是一样的,区别在于不需要[class]token做分类,只输出最终的图像编码张量
- 输入1024的图,拆分成64x64的768维patchs
- 经过attention block(window和global的MSA,相对位置编码)和MLP得到同样大小64x64x768embbeding特征
- 再经过neck得到1x256x64x64的图片embedding
这块有一篇文字介绍的更详细,如果想了解更多细节可以看这篇:Image encoder模块Vision Transformer网络解析。
Image encoder主要由attention block和neck组成,下面将根据代码简单介绍
attention block
1024x1024x3的图片经过一个patch_size=16的PatchEmbed层,将原图分为一个个小块patch,每个patch会在channel维度展开成向量(16x16x3=768),即得到1x64x64x768的patchs,然后patchs经过attention模块(window attention或者global attention),再经MLP得到与输入x一致大小的1x64x64x768
window_partition
拆分窗口,相当于把原图拆成多个小图叠在一起
ini
x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) ##先按窗口大小拆分
## .permute将数据按窗口划分,再通过.view展成多个大小一直的小图
windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
Attention类
window attention和global attention计算都一样,区别在于输入x的B H W大小,window的B是batch小窗口的个数,H W是窗口大小;二global的则是针对全图计算Attention
qkv是直接用一个全连接层一次性得到然后拆分成3个多个注意力的q, k, v
相对位置编码和swin transformer类似,只针对query计算,直接加到attention上
ini
def forward(self, x: torch.Tensor) -> torch.Tensor:
## window attention的B是batchx小窗口的个数,H W是窗口大小;
B, H, W, _ = x.shape
# qkv with shape (3, B, nHead, H * W, C) ## self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
# q, k, v with shape (B * nHead, H * W, C)
q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)
## self.scale多头注意力的根号d,self.scale = (dim_input // num_heads )**-0.5
attn = (q * self.scale) @ k.transpose(-2, -1)
if self.use_rel_pos:
##self.rel_pos_h是网络学到的参数
attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))
attn = attn.softmax(dim=-1)
x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
x = self.proj(x) ## self.proj = nn.Linear(dim, dim)
return x
相对位置编码
即add_decomposed_rel_pos函数的实现,self.rel_pos_h就是类似swin transformer的相对位置偏移表(长度为2*M-1,因为可取的值是[-M+1,M-1]共2*M-1个),是网络学习到的
ini
# 非global的input_size[0] = input_size[1] = 14,global是64
# multi-head attention中的head维度:head_dim = block输出的维度768 / head的数量12 = 64
self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim))
self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim))
根据窗口大小生成对应相对位置索引index,再从self.rel_pos_h中取对应索引位置的偏移向量得到Rh(14x14x64 ),再把query展成多个小窗口形式与Rh矩阵乘法得到最终相对位置编码,attn上直接加上h和w方向的相对位置编码即可
python
def add_decomposed_rel_pos(
attn: torch.Tensor,
q: torch.Tensor,
rel_pos_h: torch.Tensor,
rel_pos_w: torch.Tensor,
q_size: Tuple[int, int],
k_size: Tuple[int, int], ) -> torch.Tensor:
""" Calculate decomposed Relative Positional Embeddings
Args:
attn (Tensor): attention map.
q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
k_size (Tuple): spatial sequence size of key k with (k_h, k_w).
Returns:
attn (Tensor): attention map with added relative positional embeddings.
"""
# q: 300x196x64
# atten:300x196x196
q_h, q_w = q_size
k_h, k_w = k_size
# Rh: 14x14x64
Rh = get_rel_pos(q_h, k_h, rel_pos_h)
Rw = get_rel_pos(q_w, k_w, rel_pos_w)
B, _, dim = q.shape
# r_q: 300x14x14x64
r_q = q.reshape(B, q_h, q_w, dim)
# rel_h: 300x14x14x14
# 等价于:
# rel_h = torch.matmul(r_q, Rh.transpose(1, 2))
# rel_w = torch.matmul(r_q.transpose(1, 2), Rw.transpose(1, 2)).transpose(1, 2)
rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)
# 将相对位置编码加在atten里面,再resize回300x196x196
attn = (attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]).view(B, q_h * q_w, k_h * k_w)
return attn
最终attention block的forward代码
window_size为0计算全局attention,否则计算局部的window attention
ini
def forward(self, x: torch.Tensor) -> torch.Tensor:
shortcut = x
x = self.norm1(x)
# Window partition
if self.window_size > 0:
H, W = x.shape[1], x.shape[2]
x, pad_hw = window_partition(x, self.window_size)
x = self.attn(x)
# Reverse window partition
if self.window_size > 0:
x = window_unpartition(x, self.window_size, pad_hw, (H, W))
x = shortcut + x
x = x + self.mlp(self.norm2(x))
return x
neck降低embedding维度
ini
Sequential(
(0): Conv2d(768, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
(1): LayerNorm2d()
(2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
(3): LayerNorm2d()
)
neck部分由两个卷积层组成,分别是256x768x1x1和256x256x3x3,最后输出的image imbedding的尺寸是1x256x64x64
encoder的forward
python
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = self.patch_embed(x) ## 拆成patchs
## 网络学习到的绝对位置编码,absolute positional embedding
if self.pos_embed is not None:
x = x + self.pos_embed
### tansformer block由window attention和global attention组成
for blk in self.blocks:
x = blk(x)
## 上面得到的x是 1*64*64*768
x = self.neck(x.permute(0, 3, 1, 2))
## 最后输出的image imbedding的尺寸是1*256*64*64
return x
Prompt encoder
输入:
point、box、mask、txt(代码未实现)等prompt,格式一般如下,B为batch size
- point需要包含点的x,y坐标BxNx2和label(0为前景,1位背景)BxNx1
- box包含框的左上和右下两个点,BxNx4,对于某个gt即单个mask,只会有1个box;如果输入的是N个box最终会生成N个masks
- mask一般和SAM最终输出mask的hxw(256x256),Bx1xHxW
- txt在SAM代码中未实现,这块可以参考Grounded-Segment-Anything
输出两个:
-
sparse_embeddings 点和框的稀疏嵌入,形状为BxNx(embed_dim),其中N由输入点和框的数量确定,如果两者同时有则N的计算方式为(点的个数+2x框的个数)
- point box 全都没有,输出大小:Bx0x256
- 如果只有point,输出大小:Bx(N+1)x256,会补充一个[0,0]空点在最后,label为-1,表示只有点提示;
- 如果只有box,输出大小: (B*N)x2x256
- piont、box都有,输出大小:BxNx256
-
dense_embeddings 掩码的密集嵌入,形状为Bx(embed_dim)x(embed_H)x(embed_W),默认大小为Bx256x64x64,没有提示时会返回一个网络学习到的no mask默认嵌入
流程
网络已自动学会了针对不通过类型提示的编码信息,输入的point、box、mask等提示加上位置编码后,再加上网络学会的综合编码信息,最终对point、box这种稀疏的提示会返回sparse embedding, 对mask会返回dense embeddings(没有mask提示时是网络学习到的embeddings)。这部分就相当于把各种提示转换为decoder能理解的格式。
point embedding
输入points是个tuple,包含(point_coords, point_labels),point_coords一般是BxNx2的一些列点的xy坐标,point_labels是BXN个点对应的label,0代表是一个背景点,1代表是前景点(需要分割出mask的部分),以下都以1个点和1个box为提示分析
-
step1:首先生成一组可学习的向量point embedding,大小为:4x1x256,即前景/背景和框的左上右下两个点:
scss## ModuleList((0-3): 4 x Embedding(1, 256)); 4个点代表 pos/neg point + 2 box corners self.point_embeddings = nn.ModuleList([nn.Embedding(1, 256) for i in range(4)])
-
step2:再生成一组可学习的向量not_a_point_embed,大小为1x256,用于表示该位置不是一个点
iniself.not_a_point_embed = nn.Embedding(1, embed_dim)
-
step3:点的padding
N为传入点的个数,点point表示为BxNx2的tensor,如果prompt里面没有bbox只有点,则补充一个[0,0]点到points后面,其对应的label为-1,此时point大小为Bx(N+1)x2,label为Bx(N+1);
如果传入的还有bbox,此时的point大小为BxNx2,label为BxN(不加pad)
ini## 没有bbox,补充【0,0】点到每个point后面,其对应的label为-1 if pad: padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device, dtype=points.dtype) padding_label = -torch.ones((labels.shape[0], 1), device=labels.device, dtype=points.dtype) points = torch.cat([points, padding_point], dim=1) labels = torch.cat([labels, padding_label], dim=1)
-
step4:计算Positionally encode:
点的横纵坐标除以输入尺寸w,h(1024,1024)到[0,1]之间,再和随机高斯矩阵(2x128, positional_encoding_gaussian_matrix)相乘后,算sin cos值后h和w拼在一起得到点的位置编码(BxNx256),N=2表示只有point,N=1表示还有box没有pad一个空点
pythondef _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: """Positionally encode points that are normalized to [0,1].""" # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape coords = 2 * coords - 1 coords = coords.to(self.positional_encoding_gaussian_matrix.dtype) coords = coords @ self.positional_encoding_gaussian_matrix ##随机高斯矩阵 2x128 coords = 2 * np.pi * coords # outputs d_1 x ... x d_n x C shape return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
-
step5:计算最终的point embedding
上一步得到的256维向量,叠加学习到的embedding向量(非点、背景点、前景点,各自分别学习了)最终为BxNx256的向量
inipoint_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) point_embedding[labels == -1] = 0.0 point_embedding[labels == -1] += self.not_a_point_embed.weight # 对应label为-1的padding点,加上not_a_point_embed point_embedding[labels == 0] += self.point_embeddings[0].weight # neg点加上point_embeddings[0] point_embedding[labels == 1] += self.point_embeddings[1].weight # pos点加上point_embeddings[1]
完整的embedding流程
ini
def _embed_points(
self,
points: torch.Tensor,
labels: torch.Tensor,
pad: bool, ) -> torch.Tensor:
"""Embeds point prompts."""
points = points + 0.5 # Shift to center of pixel
# 如果没有输入的box的话,会将points的长度用0补充形成Bx(N+1)x2,label用【-1】补充成Bx(N+1)
if pad:
padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
points = torch.cat([points, padding_point], dim=1)
labels = torch.cat([labels, padding_label], dim=1)
# 将points与一个2x128的随机高斯矩阵相乘再通过进行sin、cos运算,两者的运算结果拼接得到
# point_embedding: BxNx256 或者 Bx(N+1)x256
point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)
point_embedding[labels == -1] = 0.0
point_embedding[labels == -1] += self.not_a_point_embed.weight
point_embedding[labels == 0] += self.point_embeddings[0].weight
point_embedding[labels == 1] += self.point_embeddings[1].weight
return point_embedding
最后输出的结果是sparse_embeddings BxNx256,如果prompt只有点则sparse_embeddings就是Bx(N+1)x256的point_embedding,如果还有box,那最终结果是BxNx256的point_embedding和叠加后面Bx1x256的box embedding拼在一起当作最终的sparse_embeddings
box embedding
对于某个gt即单个mask,只会有1个box;如果输入的是N个box最终会生成N个masks。一个box有两个点,左上和右下角点,embedding步骤如下
- step1:先resize为Nx2x2;N代表多个框
- step2:再使用point embedding一样的位置编码方式,得到corner_embedding位置编码后的Nx2x256向量
- step3:再加上之前网络学习到的box角点embeding向量(0 1是前景背景,2 3是box的左上右下)
python
def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
"""Embeds box prompts."""
boxes = boxes + 0.5 # Shift to center of pixel
coords = boxes.reshape(-1, 2, 2) ## 操作与points类似,将2个点resize成Nx2x2
## 使用point embedding编码的方式,得到corner_embedding Nx2x256
corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size)
## 再加上学习到的点的embedding
corner_embedding[:, 0, :] += self.point_embeddings[2].weight #2 3是box的左上右下
corner_embedding[:, 1, :] += self.point_embeddings[3].weight #2 3是box的左上右下
return corner_embedding ##Nx2x256,N代表多个框
最后输出的box的embedding的尺寸是Nx2x256,N代表多个框
sparse embedding
point和box的prompt最终结果都是sparse embedding,因此代码里会将两者合并起来得到最终结果,
- 如果只有point,输出大小为Bx(N+1)x256,因为padding了1个空点,比如一个点时为Bx2x256
- 如果只有box,输出大小为 (B*N)x2x256,比如一个box时为Bx2x256
- point、box都有,输出大小:BxNx256,N的计算方式为(点的个数+2x框的个数),比如一个point和一个box时为Bx3x256
dense_embeddings
dense_embeddings也是mask embedding,即针对全图的编码,分两种情况,输入的prompt有无mask
-
有mask提示,则经过self.mask_downscaling卷积网络不断下采样,简单粗暴得到Nx256x64x64维的embedding(N是输入mask的个数)
iniself.mask_downscaling: Sequential( (0): Conv2d(1, 4, kernel_size=(2, 2), stride=(2, 2)) (1): LayerNorm2d() (2): GELU(approximate='none') (3): Conv2d(4, 16, kernel_size=(2, 2), stride=(2, 2)) (4): LayerNorm2d() (5): GELU(approximate='none') (6): Conv2d(16, 256, kernel_size=(1, 1), stride=(1, 1))
-
无mask提示,网络有一个1x256的向量self.no_mask_embed表示无mask提示时的特征向量,直接将其复制expand成Nx256x64x64的embedding
iniself.no_mask_embed = nn.Embedding(1, embed_dim) dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] )
Mask decoder
输入:
- image encoder得到的image_embeddings和图像的positional encoding
- prompt encoder得到的prompt embeddings(sparse和dense两种)
输出:
- masks,如果指定了"multimask_output"参数则会输出3个层级的mask(whole, part, and subpart),否则只输出1个mask
- IoU scores,可以理解为每个mask的置信度,由网络中的iou token得到
流程
-
首先会image_embeddings会混入dense embeddings的信息(两者直接相加),sparse embeddings则会与mask token和IoU token拼在一起成为一个新的token,mask token后续会用于生成mask,IoU token用于衡量每个mask的好坏
-
然后这个新的token和image_embeddings经过一个TwoWayTransformer模块(下图黄色框部分),先做token的self attention,然后做token(作为key)到图像的cross attention,经过MLP更新token,最后再图像(作为key)到token的attention,目的是不断更新图像和token中的信息,会重复两次
-
更新后token再做一次token(作为key)到图像的cross attention后,又拆出来之前的两个部分mask token和IoU token,后者就代表每个mask的置信度;
而图像信息经过转置卷积还原到原图大小后,会和mask token做矩阵乘法生成最终的masks,类似 YOLACT中的"prototype masks"和"mask coefficients"矩阵乘法
output_tokens
类似NLP和ViT中的[cls]token用来分类,decoder中定义了两个可学习的token辅助生成mask,两个拼在一起就是output_tokens:
- iou_token(1x256),会用于计算后续IoU scores(上图绿色部分)
- mask_tokens(4x256) ,用于生成最终的mask(上图红色框部分),分别对应单张mask(仅在不需要多层mask时启用)+3种层级的mask(whole, part, and subpart)
ini
self.iou_token = nn.Embedding(1, transformer_dim) ##transformer_dim=256
self.num_mask_tokens = num_multimask_outputs + 1 ##num_multimask_outputs=3,对应3种层级的mask(whole, part, and subpart),+1是首层单张mask(仅在不需要多层mask时启用)
self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim)
sparse embedding处理
output_tokens和promp encoder得到的sparse embedding(点、框的稀疏提示)会concate在一起,当成新的tokens(只有1个点作为prompt时,为Nx7x256,后续所有都按只有1个点prompt情况分析),对应上图左下角
ini
output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) #5x256
##点、框的稀疏提示,沿batch方向复制成相同维度,得到Nx5x256
output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1)
tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1)
image_embeddings处理
image encoder得到了1x256x64x64的image_embeddings,将其从batch维度拓展成和tokens一样的Nx256x64x64(N为prompt个数),然后和dense_embeddings相加,即加入全图mask的稠密提示信息(有mask作为prompt时是对应mask的embedding特征,无mask作为prompt时,是网络自己学到的embedding特征)
image_pe
image_pe是1x256x64x64的位置编码特征,编码方式和prompt encoder中给point的Positionally encode方法一样(随机高斯矩阵后正余弦固定编码),同样展成Nx256x64x64的大小
ini
# Expand per-image data in batch direction to be per-mask
src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0)
src = src + dense_prompt_embeddings
pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0)
b, c, h, w = src.shape
TwoWayTransformer
将上述src、pos_src、tokens经过TwoWayTransformer得到两个输出hs(即更新后的tokens), src(更新后的image_embeddings+dense_prompt_embeddings,包含原图和mask prompt的信息)
ini
# Run the transformer
hs, src = self.transformer(src, pos_src, tokens)
网络主要由两个TwoWayAttentionBlock组成,计算有以下特点:
- tokens(iou+mask+sparse embedding)主要作为query(1x7x256),query_pe即query的位置编码信息使用tokens本身替代;
- src即image_embeddings,是作为key的(展成了1x4096x256);pos_src即image_pe,是image_embeddings的位置编码信息(1x4096x256)
- 每次算Attention不管query还是key都会加上其对应的位置编码信息,类似ShortCut
TwoWayAttentionBlock
TwoWayAttentionBlock主要包含四层,对应流程那块图的黄色框四层部分
-
针对query(即tokens,iou+mask+sparse embedding)的self-atten,但首层无位置编码
ini# Self attention block if self.skip_first_layer_pe: queries = self.self_attn(q=queries, k=queries, v=queries) else: q = queries + query_pe attn_out = self.self_attn(q=q, k=q, v=queries) queries = queries + attn_out queries = self.norm1(queries)
-
tokens到image的cross attention(tokens作为query, image_embeddings作为key)
ini# Cross attention block, tokens attending to image embedding q = queries + query_pe k = keys + key_pe attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) queries = queries + attn_out queries = self.norm2(queries)
-
针对tokens的mlp block
ini# MLP block mlp_out = self.mlp(queries) queries = queries + mlp_out queries = self.norm3(queries)
-
与第二层相反,image到tokens的cross attention(image_embeddings作为query, tokens作为key)
ini# Cross attention block, image embedding attending to tokens q = queries + query_pe k = keys + key_pe attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) keys = keys + attn_out keys = self.norm4(keys)
经过两层TwoWayAttentionBlock更新queries和keys后,再来一个final attention层,更新得到最终的queries即hs, keys即src
ini
# Apply the final attention layer from the points to the image
q = queries + point_embedding
k = keys + image_pe
attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)
queries = queries + attn_out
queries = self.norm_final_attn(queries)
整个流程较为复杂,参考下图
预测masks
transformer把输入的tokens和image embedding更新后变成了hs(1x7x256), src(1x4096x256),接下来就是利用这两个信息去上采样生成mask。整个流程可以参考下图,queries就是我们的hs,keys就是src
hs,即更新后的tokens
此前的tokens是由iou_token和mask_tokens拼接得到,所以从hs上也可以拆下来更新后的那两个token,其中iou_token_out表示生成mask的质量即IoU scores,mask_tokens_out用于生成不同层级mask
ini
# Run the transformer,update tokens和image embedding
hs, src = self.transformer(src, pos_src, tokens)
## hs 1x7x256; src: 1x4096x256
iou_token_out = hs[:, 0, :] ## 1x1x256
mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] ## 1x4x256
src,即更新后的mask embeddings(从image embeddings+dense embeddings学习得来)
把1x4096x256的mask embeddings变成1x256x64x64(和image encoder输出的大小对应),再经过两层转置卷积进行上采样(到原图大小)变为1x32x256x256的upscaled_embedding
ini
# Upscale mask embeddings and predict masks using the mask tokens
src = src.transpose(1, 2).view(b, c, h, w)
upscaled_embedding = self.output_upscaling(src)
基于mask_tokens和上采样后的mask embeddings做矩阵乘法得到mask
1x4x256的mask_tokens,4层包含不同层级only one, whole, part, subpart,所以用4个不同的全连接层(三层MLP网络)降低embedded维度,从256维降低为32维,再把4层结果又concat在一起得到1x4x32维的hyper_in(可以理解为超维压缩特征)
最终由1x4x32的hyper_in和1x32x256x256的upscaled_embedding矩阵乘法得到1x4x256x256的masks结果 (为了可以矩阵运算,需要将后者h w合并成在一个维度运算完再展开) 这块来自 YOLACT的思路 (YOLOV5分割模型也是用的这种,github有个讲解比较详细的github.com/ultralytics... ,也可以参考这篇 YOLACT)
此处只针对单个点prompt(N=1)分析,一个prompt对应一组mask结果
ini
hyper_in_list: List[torch.Tensor] = []
for i in range(self.num_mask_tokens):
##1x1x256 变为 1x1x32
hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :]))
hyper_in = torch.stack(hyper_in_list, dim=1) ##1x4x32
b, c, h, w = upscaled_embedding.shape
masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) ## 1x4x256x256
iou_pred
之前得到的1x256的iou_token经过IoU预测head(全连接+sigmoid)直接得到1x4的iou_pred,即对应上面4层生成mask的IoU值
ini
# Generate mask quality predictions
iou_pred = self.iou_prediction_head(iou_token_out)
后处理到原图大小
预测得到的1x4x256x256 masks,根据参数multimask_output决定是否输出多层级mask,如果是则取后3层mask,否则取最顶层,iou_pred也取对应的;
得到的mask是256x256大小,先插值到网络输入大小(1024x1024),再去掉padding部分,最后再插值缩放到原图大小hxw
ini
## self.image_encoder.img_size=1024
masks = F.interpolate( masks, (self.image_encoder.img_size, self.image_encoder.img_size), mode="bilinear", align_corners=False, )
masks = masks[..., : input_size[0], : input_size[1]]
masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False)
整图分割推理(segment everything)
流程
在图片上生成32x32的网格,得到1024个采样点,每个采样点都当做1个前景的prompt进入prompt encoder然后和image encoder结果一起生成mask,每次会处理一个batch(默认64)的采样点;每个batch得到的mask都会进行以下几个过滤:
- predicted IoU过滤,mask decoder除了返回masks还会预测对应mask iou值,过滤低置信度(默认阈值0.88)的mask
- stability score过滤,stability score是mask在两个阈值下二值化后的IoU值,可以理解为改变过滤阈值后还能得到同样mask的能力,过滤低于0.95的mask
- mask threshold过滤,直接过滤mask logits值低于mask_threshold(默认0.0)的mask
- boundary过滤,每个mask生成外界矩形,过滤超过图像边界的mask
所有batch过滤后的的masks结果再进行nms过滤(mask对应外接矩形的nms,阈值0.7)就得到最终的分割结果
图片crop
一般默认是基于整张原图去生成32x32的网格点,代码里可以配置参数crop_n_layers(默认0)将原图进行切分,每条边切分成2**crop_n_layers份,如值为1对应每边切成2份得到2x2个小的子图,值为2对应每边切成4份得到4x4个子图。切分时子图会有一定重叠区域(overlap_ratio参数控制),某种程度防止把一个物体切成两份了?
对应函数是 generate_crop_boxes,返回的是每个切分层级得到的子图坐标,以1024x1024原图,crop_n_layers为1,重叠比例为512/1500为例,返回结果如下:
yaml
##第一个代表原图不切分,后面4个即2x2的子图区域
[[0, 0, 1024, 1024],
[0, 0, 687, 687], [0, 338, 687, 1024], [338, 0, 1024, 687], [338, 338, 1024, 1024]]
为方便理解后续都按crop_n_layers=0不切分情况分析
网格点生成
n_per_side为每个边需要采样的点数,默认是32,坐标范围是[0,1]之间,x,y方向都采样得到32x32的网格,将对应网格坐标乘以原始图片的长宽,得到1024个采样点在原图的坐标,这些点都作为prompt进行mask生成
python
def build_point_grid(n_per_side: int) -> np.ndarray:
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
offset = 1 / (2 * n_per_side)
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
return points
batch个点推理得到masks
1024个点拆分成多个batch推理,1个batch64个点,网络输入大小默认1024x1024,图片原始尺寸会按最长边进行resize到1024(如960x540的图变1024x576),上面得到点坐标是对应原图大小的,需要经apply_coords进一步转换到resize后的图的坐标,每个点的label都是1,即都作为前景点去得到mask;然后整个batch的经过prompt_encoder得到对应点的sparse embedding和网络学到的dense_embeddings(因为无mask prompt),再和image encoder得到的image imbedding一起经过mask decoder得到最终masks
ini
# 坐标位置对应到resize后的图上
transformed_points = self.predictor.transform.apply_coords(points, im_size)
in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) #1024个都是前景点,label=1
## 常规的基于点提示得到masks流程
masks, iou_preds, _ = self.predictor.predict_torch(
in_points[:, None, :],
in_labels[:, None],
multimask_output=True,
return_logits=True,
)
batch内masks过滤
predicted IoU过滤
mask decoder除了返回masks还会预测对应mask iou值,过滤低置信度(默认阈值0.88)的mask
python
# Filter by predicted IoU
if self.pred_iou_thresh > 0.0:
keep_mask = data["iou_preds"] > self.pred_iou_thresh ##默认阈值0.88
data.filter(keep_mask)
stability score过滤
stability score是mask在两个阈值下二值化后的IoU值,可以理解为改变过滤阈值后还能得到同样mask的能力,计算方式如下代码:
ini
## mask_threshold=0.0, threshold_offset=1.0
intersections = ( (masks > (mask_threshold + threshold_offset)).sum(-1, dtype=torch.int16).sum(-1, dtype=torch.int32))
unions = ( (masks > (mask_threshold - threshold_offset)) .sum(-1, dtype=torch.int16) .sum(-1, dtype=torch.int32))
stability_score = intersections / unions
过滤低于self.stability_score_thresh(默认0.95)的mask
mask threshold过滤
直接过滤mask logits值低于mask_threshold(默认0.0)的mask
css
data["masks"] = data["masks"] > self.predictor.model.mask_threshold
boundary过滤
每个mask生成外界矩形,过滤超过图像边界的mask
css
data["boxes"] = batched_mask_to_box(data["masks"])
# Filter boxes that touch crop boundaries
keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
if not torch.all(keep_mask):
data.filter(keep_mask)
nms过滤
所有batch的过滤结果会再经过一次nms过滤,基于对应mask的外接矩形进行nms过滤,如下代码
ini
# Remove duplicates within this crop.
keep_by_nms = batched_nms(
data["boxes"].float(),
data["iou_preds"],
torch.zeros_like(data["boxes"][:, 0]), # categories
iou_threshold=self.box_nms_thresh,
)
data.filter(keep_by_nms)
如果配置了图片crop,即多层级的子图得到的mask还会再进行一次nms过滤得到原图最终的masks
最终结果
有个参数output_mode可以设置输出mask的类型,默认是"binary_mask",会比较耗内存,还可以选择coco的RLE编码格式,参数为'uncompressed_rle', or 'coco_rle'
git上也有官方demo可以参考:全图分割的官方demo
数据引擎(data engine)
SAM除了模型外,还公开了一份有10亿个masks的1100万张图的分割数据集**SA-1B**,基于他们提出的data engine方案得到,这块的贡献也是非��显著,也体现了Data-centric AI的惊人能力,[这块知乎上"一堆废纸"博主介绍的比较好](如何评价Meta/FAIR 最新工作Segment Anything? - 一堆废纸的回答 - 知乎 www.zhihu.com/question/59...%25E3%2580%2582%25E4%25BB%258E%25E8%25AE%25BA%25E6%2596%2587%25E9%2587%258C%25E6%2580%25BB%25E7%25BB%2593%25E5%25B0%25B1%25E6%2598%25AF%25E8%25BE%2585%25E5%258A%25A9%25E4%25BA%25BA%25E5%25B7%25A5%25E6%25A0%2587%25E6%25B3%25A8%25E3%2580%2581%25E5%258D%258A%25E8%2587%25AA%25E5%258A%25A8%25E6%25A0%2587%25E6%25B3%25A8%25E3%2580%2581%25E5%2585%25A8%25E8%2587%25AA%25E5%258A%25A8%25E6%25A0%2587%25E6%25B3%25A8%25E4%25B8%2589%25E6%25AD%25A5%25EF%25BC%258C%25E5%2585%25B7%25E4%25BD%2593%25E5%25A6%2582%25E4%25B8%258B%25EF%25BC%259A "https://www.zhihu.com/question/593888697/answer/2972047807)%E3%80%82%E4%BB%8E%E8%AE%BA%E6%96%87%E9%87%8C%E6%80%BB%E7%BB%93%E5%B0%B1%E6%98%AF%E8%BE%85%E5%8A%A9%E4%BA%BA%E5%B7%A5%E6%A0%87%E6%B3%A8%E3%80%81%E5%8D%8A%E8%87%AA%E5%8A%A8%E6%A0%87%E6%B3%A8%E3%80%81%E5%85%A8%E8%87%AA%E5%8A%A8%E6%A0%87%E6%B3%A8%E4%B8%89%E6%AD%A5%EF%BC%8C%E5%85%B7%E4%BD%93%E5%A6%82%E4%B8%8B%EF%BC%9A")
- 第一步以人工标注为主。初始模型在公开数据集训练后辅助生成masks,再人工精修调整,再用标好的新数据迭代模型。如此重复6次,从12万张图得到430万masks
- 第二步是模型半自动标注高置信度masks,然后人工标注补充剩余未标出的masks。mask的置信度判断是用一个模型对mask进行目标检测,如果能检测出物体则是置信度较高mask无需再人工标注,这个目标检测模型是基于第一步得到的数据训练的。如此迭代5次,从18万张图新增了590万masks
- 第三部是模型全自动标注。基于此前两步的数据得到模型,已有较好的分割能力且能适配模糊提示分割(局部mask或者整体mask),对一张图撒32x32的网格点进行segment everything,后处理会挑选搞IoU和搞稳定性的masks并做NMS得到全图最终的masks。针对所有图片自动分割,最终得到了SA-1B数据集