目录
[4.2 回到MaskDecoder.forward](#4.2 回到MaskDecoder.forward)
[5.1 obj_ptr = self.obj_ptr_proj(sam_output_token)](#5.1 obj_ptr = self.obj_ptr_proj(sam_output_token))
[5.2 对象记忆的"褪色与重置"机制](#5.2 对象记忆的"褪色与重置"机制)
一、前言
上一篇中我们走出了MaskDecoder.forward的self._dynamic_multimask_via_stability函数,这一篇我们看看接下来发生什么,其实就是选择了一个稳定掩码之后返回这个掩码masks、掩码分数iou_pred、目标存在性分数object_score_logits,还有把之前我们得到的(1,4,256)的mask_tokens_out取第0个token给到变量sam_tokens_out作为物体记忆token也返回 。
masks: torch.Size([1, 1, 256, 256])
iou_pred:tensor([[0.9436]], device='cuda:0') 数值只是示例
sam_tokens_out: torch.Size([1, 1, 256])
object_score_logits: torch.Size([1, 1]) 即 tensor([[24.0962]], device='cuda:0') 数值只是示例
然后就回到_forward_sam_heads,主要操作就是把256的mask上采样到1024,然后就是之前的(1,4,256)的mask_tokens_out取第0个token得到的(1,1,256)的sam_output_token,然后取第0个得到(1,256)的sam_output_token,就是那4个掩码token的第0个。然后经过MLP变成obj_ptr,然后如果物体消失的话obj_ptr要加self.no_obj_ptr(self.no_obj_ptr是一个(1,256)的可学习的矩阵),obj_ptr就是物体的记忆指纹,如果物体消失的话记忆要归零。
最后返回的就是低分辨的256的mask,高分辨率的1024的mask,ious就是最佳mask对应的置信度,obj_ptr就是用于记忆物体的,object_score_logits就是用于判断物体是否存在的(如果大于0就是存在)
low_res_multimasks: torch.Size([1, 1, 256, 256])
high_res_multimasks: torch.Size([1, 1, 1024, 1024])
ious: tensor([[0.9436]], device='cuda:0')
low_res_masks: torch.Size([1, 1, 256, 256])
low_res_masks: torch.Size([1, 1, 1024, 1024])
obj_ptr: torch.Size([1, 256])
object_score_logits: object_score_logits: tensor([[24.0962]], device='cuda:0')
感觉现在还没有完全理解SAM2跟踪对物体消失的处理方式,而且我觉得需要新开一篇文章将SAM2记忆机制与DeepSeek OCR的记忆机制进行比对。
四、MaskDecoder
4.2 回到MaskDecoder.forward
sam2/modeling/sam/mask_decoder.py
python
def forward(
self,
image_embeddings: torch.Tensor,
image_pe: torch.Tensor,
sparse_prompt_embeddings: torch.Tensor,
dense_prompt_embeddings: torch.Tensor,
multimask_output: bool,
repeat_image: bool,
high_res_features: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Predict masks given image and prompt embeddings.
Arguments:
image_embeddings (torch.Tensor): the embeddings from the image encoder
image_pe (torch.Tensor): positional encoding with the shape of image_embeddings
sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes
dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs
multimask_output (bool): Whether to return multiple masks or a single
mask.
Returns:
torch.Tensor: batched predicted masks
torch.Tensor: batched predictions of mask quality
torch.Tensor: batched SAM token for mask output
"""
# 输入:
# image_embeddings: torch.Size([1, 256, 64, 64])
# image_pe:torch.Size([1, 256, 64, 64])
# sparse_embeddings: torch.Size([1, 3, 256])
# dense_embeddings : torch.Size([1, 256, 64, 64])
# multimask_output:False
# repeat_image: False
# high_res_features:[
# torch.Size([1, 32, 256, 256]),
# torch.Size([1, 64, 128, 128])
# ]
# >>> 1. 先把所有 embedding 喂给 mask decoder,拿到 4 个输出
masks, iou_pred, mask_tokens_out, object_score_logits = self.predict_masks(
image_embeddings=image_embeddings,
image_pe=image_pe,
sparse_prompt_embeddings=sparse_prompt_embeddings,
dense_prompt_embeddings=dense_prompt_embeddings,
repeat_image=repeat_image,
high_res_features=high_res_features,
)
# mask: torch.Size([1, 4, 256, 256])
# iou_pred: tensor([[0.9436, 0.9098, 0.9337, 0.9457]], device='cuda:0')
# mask_tokens_out: torch.Size([1, 4, 256])
# object_score_logits: torch.Size([1, 1]) 比如 tensor([[24.0962]], device='cuda:0')
# Select the correct mask or masks for output
# multimask_output:False
if multimask_output:
# >>> 2-a. 训练/多 mask 模式:只要后 3 个 mask(跳过第 0 个"默认" mask)
masks = masks[:, 1:, :, :]
iou_pred = iou_pred[:, 1:]
# iou_pred: tensor([[0.9436]], device='cuda:0')
# self.dynamic_multimask_via_stability:True self.training:False
elif self.dynamic_multimask_via_stability and not self.training:
# >>> 2-b. 测试阶段且开 stability 筛 mask:自动挑一个最稳的
# masks: torch.Size([1, 4, 256, 256])
# iou_pred: torch.Size([1, 4])
masks, iou_pred = self._dynamic_multimask_via_stability(masks, iou_pred)
# masks: torch.Size([1, 1, 256, 256])
# iou_scores_out: tensor([[0.9436]], device='cuda:0')
else:
# >>> 2-c. 默认单 mask 模式:直接取第 0 个通道
masks = masks[:, 0:1, :, :]
iou_pred = iou_pred[:, 0:1]
# multimask_output: False self.use_multimask_token_for_obj_ptr:True
if multimask_output and self.use_multimask_token_for_obj_ptr:
# >>> 3-a. 多 mask 且要把 token 当 object pointer:用后 3 个 token
sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c] shape
else:
# >>> 3-b. 其余情况(包括单 mask)一律用第 0 个 token 当"物体记忆"
# 获取掩码输出token。这里我们*总是*使用单掩码输出的token。
# 在测试时,即使我们在1次点击后进行跟踪(且设置multimask_output=True),
# 这里仍然使用单掩码token。理由是:在训练期间我们总是在多次点击后才进行跟踪,
# 因此训练过程中看到的过往token始终是单掩码token(我们将它作为对象记忆token)。
# mask_tokens_out: torch.Size([1, 4, 256])
sam_tokens_out = mask_tokens_out[:, 0:1] # [b, 1, c] shape
# sam_tokens_out: torch.Size([1, 1, 256])
# Prepare output
# masks: torch.Size([1, 1, 256, 256])
# iou_pred:tensor([[0.8732]], device='cuda:0')
# sam_tokens_out: torch.Size([1, 1, 256])
# object_score_logits: torch.Size([1, 1]) 即 tensor([[24.0962]], device='cuda:0')
return masks, iou_pred, sam_tokens_out, object_score_logits
代码整体流程一句话总结
用 predict_masks 一次性生成 4 组 mask 及其对应 IoU、token、objectness。
根据 multimask_output 标志和 dynamic_multimask_via_stability 策略,决定到底留几个 mask:
训练/多 mask 模式 → 留 3 个;
测试开 stability → 自动挑 1 个最稳的;
其余 → 直接拿第 0 个。
再按同样逻辑挑一个(或 3 个)token 作为后续跟踪用的"物体记忆"。
把最终 mask、IoU、token、objectness 分数一起返回。
4.2.2 sam_tokens_out是什么?
你困惑的核心是:sam_tokens_out到底是什么?为什么要这样选?
这是SAM在视频跟踪 场景下的关键设计,我把它叫 "物体身份证" 。
1. sam_tokens_out的本质
sam_tokens_out = tensor([[[ 0.2, -0.5, 1.1, ..., 0.8 ]]]) # 形状: [1, 1, 256]
这是一个 256维的向量 ,是模型对"这个物体是谁"的抽象编码
就像人的身份证号,同一物体在不同帧中保持相同的身份特征
在视频跟踪时,这个token会被保存下来,在下一帧作为"记忆"提示模型
2. 为什么总是用第0个token?(最关键)
# 即使 multimask_output=True,这里仍然用第0个 sam_tokens_out = mask_tokens_out[:, 0:1]这是SAM刻意训练的结果:
阶段 发生了什么 记忆token是谁 训练时 用户点击物体5次 → 模型输出单掩码 → 保存token 第0个token 测试时 第1帧:用户点击1次 → 模型输出多掩码 仍然用第0个token做记忆 跟踪时 第2-N帧:用保存的token识别物体,无需再点击 第0个token持续使用 核心逻辑:
模型在训练时只见过第0个token代表物体身份
如果测试时突然用第1-3个token,相当于给模型一个它不认识的新身份证
结果:身份混乱,跟踪会漂移(drift)
3. 形象的比喻:海关安检
# 第1帧(首次见到物体): 海关(模型)问:"你是谁?" 游客(物体)拿出护照(token 0):"我是张三,护照号12345" 海关记录:张三 → 12345 # 第2帧(跟踪): 海关问:"你是谁?" 游客拿出护照(token 0):"我是张三,护照号12345" 海关核对记录:"对上了,是同一个人,放行" # 如果游客拿出token 1的护照: "我是张三,护照号67890" 海关:"???我记录的是12345,你换了护照,不是同一个人"问题:为什么解码掩码用多掩码,但身份用单掩码?
# 解码掩码(当前帧"长什么样") masks = multimask_logits # 候选1: 全身, 候选2: 头部, 候选3: 尾巴 # 保存身份("你是谁") sam_tokens_out = token_0 # 不管长什么样,身份证始终是0号分离设计:
掩码 :描述外观(可变,多候选)
Token :描述身份(固定,单身份证)
4. 什么时候用第1-3个token?
if multimask_output and self.use_multimask_token_for_obj_ptr: sam_tokens_out = mask_tokens_out[:, 1:] # [b, 3, c]这只在纯多掩码输出模式 下启用,不用于跟踪。比如:
用户只想看当前帧的多个可能结果
不保存身份信息,看完就丢弃
总结
sam_tokens_out就是物体的"记忆指纹",为什么选第0个?因为SAM在训练时就是这么教的 :"第0个token,你给我记住,你就是物体的唯一身份证,不管后续有多少个候选掩码,身份不能乱。"
这就是为什么注释里强调 "we always use the token for single mask output" ------这是SAM能稳定跟踪的秘密武器。
五、回到_forward_sam_heads
sam2/modeling/sam2_base.py
python
@torch.inference_mode()
def _forward_sam_heads(
self,
backbone_features,
point_inputs=None,
mask_inputs=None,
high_res_features=None,
multimask_output=False,
):
"""
完整的 SAM 风格「提示编码 + 掩膜解码」一条龙:
把 backbone 特征与用户提示(点/框/mask)一起送 SAM,
输出多组或单组掩膜 logits、IoU 估计、对象指针等。
参数:
backbone_features (Tensor):
已融合记忆的图像嵌入,形状 (B, C, H, W),
其中 H=W=sam_image_embedding_size(默认 64)。
point_inputs (dict | None):
{"point_coords": (B, P, 2), "point_labels": (B, P)}
坐标为**绝对像素**(会在内部归一化),label: 1=前, 0=背, -1=pad。
mask_inputs (Tensor | None):
低分辨率掩膜提示 (B,1,H',W'),与点提示互斥。
high_res_features (list[Tensor] | None):
额外两层更高分辨率特征 [4H,4W] 与 [2H,2W],供 decoder refine 边缘。
multimask_output (bool):
True → 输出 3 候选掩膜 + 3 IoU;False → 1 掩膜 + 1 IoU。
返回:
tuple:
0. low_res_multimasks -- (B,M,H/4,W/4) 低分多掩膜 logits
1. high_res_multimasks -- (B,M,H,W) 高分多掩膜 logits
2. ious -- (B,M) 各掩膜 IoU 估计
3. low_res_masks -- (B,1,H/4,W/4) **最佳**低分掩膜
4. high_res_masks -- (B,1,H,W) **最佳**高分掩膜
5. obj_ptr -- (B,C) 对象指针(用于记忆)
6. object_score_logits -- (B,) 对象出现置信度(可软可硬)
"""
# backbone_features: torch.Size([1, 256, 64, 64])
# point_inputs: {
# 'point_coords': tensor([[[421.3333, 540.4445],
# [654.9333, 921.6000]]], device='cuda:0'),
# 'point_labels': tensor([[2, 3]],
# device='cuda:0', dtype=torch.int32)
# }
# mask_inputs:None
# high_res_features:[
# torch.Size([1, 32, 256, 256]),
# torch.Size([1, 64, 128, 128])
# ]
# multimask_output: False
B = backbone_features.size(0)
# B: 1
device = backbone_features.device
# device: device(type='cuda', index=0)
# --- 1. 输入断言:尺寸必须匹配 SAM 预设 ---
assert backbone_features.size(1) == self.sam_prompt_embed_dim
assert backbone_features.size(2) == self.sam_image_embedding_size
assert backbone_features.size(3) == self.sam_image_embedding_size
# --- 2. 构造点提示 ---
if point_inputs is not None:
sam_point_coords = point_inputs["point_coords"] # (B, P, 2)
# sam_point_coords: torch.Size([1, 2, 2])
# sam_point_coords: tensor([[[421.3333, 540.4445],
# [654.9333, 921.6000]]], device='cuda:0'),
sam_point_labels = point_inputs["point_labels"] # (B, P)
# sam_point_labels: torch.Size([1, 2])
# sam_point_labels: tensor([[2, 3]], device='cuda:0', dtype=torch.int32)
assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
else:
# 无点提示时,用 1 个 pad 点(label=-1)占位,保证 prompt encoder 正常 forward
sam_point_coords = torch.zeros(B, 1, 2, device=device)
sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
# --- 3. 构造掩膜提示 ---
# mask_inputs: None
if mask_inputs is not None:
# 若外部 mask 分辨率不符,先双线性下采样到 prompt encoder 期望尺寸
assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
sam_mask_prompt = F.interpolate(
mask_inputs.float(),
size=self.sam_prompt_encoder.mask_input_size,
align_corners=False,
mode="bilinear",
antialias=True, # 抗锯齿,减少下采样 aliasing
)
else:
sam_mask_prompt = mask_inputs
else:
# 无 mask 时,prompt encoder 内部会加 learned `no_mask_embed`
sam_mask_prompt = None
# --- 4. 送进 SAM 提示编码器 ---
sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
points=(sam_point_coords, sam_point_labels),
boxes=None, # 框提示在外部已转成 2 点,不再走这里
masks=sam_mask_prompt,
)
# sparse_embeddings: torch.Size([1, 3, 256])
# dense_embeddings: torch.Size([1, 256, 64, 64])
# backbone_features: torch.Size([1, 256, 64, 64])
# multimask_output: False
# high_res_features:[
# torch.Size([1, 32, 256, 256]),
# torch.Size([1, 64, 128, 128])
# ]
# --- 5. 送进 SAM mask 解码器 ---
(
low_res_multimasks, # (B, M, H/4, W/4) M=3 or 1
ious, # (B, M) IoU 估计
sam_output_tokens, # (B, M, C) 解码器输出 token
object_score_logits, # (B,) 对象出现/消失 logits
) = self.sam_mask_decoder(
image_embeddings=backbone_features,
image_pe=self.sam_prompt_encoder.get_dense_pe(), # 固定 2D 位置编码
sparse_prompt_embeddings=sparse_embeddings,
dense_prompt_embeddings=dense_embeddings,
multimask_output=multimask_output,
repeat_image=False, # image 已 batch,无需再 repeat
high_res_features=high_res_features, # 高分特征供 refine 边缘
)
# low_res_multimasks: torch.Size([1, 1, 256, 256])
# ious: tensor([[0.9436]], device='cuda:0')
# sam_output_tokens: torch.Size([1, 1, 256])
# object_score_logits: torch.Size([1, 1]) 即 tensor([[24.0962]], device='cuda:0')
# --- 6. 对象 score 后处理:若模型预测"无对象",把掩膜 logits 置为 NO_OBJ_SCORE ---
# self.pred_obj_scores: True
if self.pred_obj_scores:
# object_score_logits: tensor([[24.0962]], device='cuda:0')
is_obj_appearing = object_score_logits > 0 # 硬阈值
# is_obj_appearing: tensor([[True]], device='cuda:0')
# NO_OBJ_SCORE: -1024.0
# low_res_multimasks: torch.Size([1, 1, 256, 256])
# 记忆用掩膜必须**硬**选择:有对象才保留,否则置极大负值
low_res_multimasks = torch.where(
is_obj_appearing[:, None, None],
low_res_multimasks,
NO_OBJ_SCORE,
)
# --- 7. 数据类型转换:bf16/fp16 -> fp32(老版本 PyTorch interpolate 不支持 bf16)---
low_res_multimasks = low_res_multimasks.float()
# 上采样到图像原分辨率(stride=1)
# self.image_size: 1024
high_res_multimasks = F.interpolate(
low_res_multimasks,
size=(self.image_size, self.image_size),
mode="bilinear",
align_corners=False,
)
# high_res_multimasks: torch.Size([1, 1, 1024, 1024])
# --- 8. 选取最佳掩膜 ---
# sam_output_tokens: torch.Size([1, 1, 256])
sam_output_token = sam_output_tokens[:, 0] # 默认取第 1 个 token(单掩膜时即自身)
# sam_output_token :torch.Size([1, 256])
# multimask_output: False
if multimask_output:
# 多掩膜时,选 IoU 估计最高的那个
best_iou_inds = torch.argmax(ious, dim=-1) # (B,)
batch_inds = torch.arange(B, device=device)
low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
# 若解码器输出了多个 token,同样要选最佳
if sam_output_tokens.size(1) > 1:
sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
else:
# 单掩膜时,最佳即唯一
low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
# low_res_masks: torch.Size([1, 1, 256, 256])
# low_res_masks: torch.Size([1, 1, 1024, 1024])
# --- 9. 从最佳 token 提取对象指针(用于记忆)---
# sam_output_token: torch.Size([1, 256])
obj_ptr = self.obj_ptr_proj(sam_output_token) # (B, C)
# obj_ptr: torch.Size([1, 256])
# --- 10. 对象指针后处理:若模型认为"无对象",指针也被削弱或替换 ---
# self.pred_obj_scores: True
if self.pred_obj_scores:
# self.soft_no_obj_ptr: False
if self.soft_no_obj_ptr:
# 软削弱:用 sigmoid 概率加权
assert not self.teacher_force_obj_scores_for_mem
lambda_is_obj_appearing = object_score_logits.sigmoid()
else:
# 硬削弱:0/1 加权
# is_obj_appearing: tensor([[True]], device='cuda:0')
lambda_is_obj_appearing = is_obj_appearing.float()
# lambda_is_obj_appearing: tensor([[1.]], device='cuda:0')
# self.fixed_no_obj_ptr: True
if self.fixed_no_obj_ptr:
obj_ptr = lambda_is_obj_appearing * obj_ptr
# obj_ptr: torch.Size([1, 256])
# 剩余权重用"无对象指针"补齐,保证指针和为 1
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
# obj_ptr: torch.Size([1, 256])
# low_res_multimasks: torch.Size([1, 1, 256, 256])
# high_res_multimasks: torch.Size([1, 1, 1024, 1024])
# ious: tensor([[0.9436]], device='cuda:0')
# low_res_masks: torch.Size([1, 1, 256, 256])
# low_res_masks: torch.Size([1, 1, 1024, 1024])
# obj_ptr: torch.Size([1, 256])
# object_score_logits: object_score_logits: tensor([[24.0962]], device='cuda:0')
# --- 11. 返回打包结果 ---
return (
low_res_multimasks, # 0 低分多掩膜
high_res_multimasks, # 1 高分多掩膜
ious, # 2 各掩膜 IoU 估计
low_res_masks, # 3 最佳低分掩膜
high_res_masks, # 4 最佳高分掩膜
obj_ptr, # 5 对象指针(记忆 key/query)
object_score_logits, # 6 对象出现置信度 logits
)
_forward_sam_heads核心总结:视频对象跟踪的"推理引擎"这是一个纯推理 函数(
@torch.inference_mode()),负责将用户提示(点击、框、前一帧掩码)与历史记忆融合,生成当前帧的掩码预测和对象身份指针。
核心逻辑:5步流水线
第1-3步:输入准备(预处理)
# 输入:backbone特征 + 用户提示(点/掩码)+ 高分辨率细节特征 # 输出:标准化的稀疏/密集嵌入
点提示 :坐标归一化到
[0,1],标签1=前景, 0=背景, -1=填充掩码提示 :下采样到
256×256输入尺寸无输入时:用默认填充值保证流程不中断
第4-5步:SAM核心推理
# 输入:预处理后的嵌入 # 输出:4个候选掩码 + IoU分数 + 4个token + 对象分数
提示编码器:将用户点击/掩码编码成256维向量
掩码解码器 :生成低分辨率 掩码(
64×64)和高分辨率细节多候选机制:默认输出4个token(0=单掩码主输出,1-3=多掩码备选)
第6-7步:后处理与分辨率提升
# 输入:低分辨率掩码 # 输出:原始分辨率掩码 + 处理无对象情况
对象分数过滤 :如果模型认为"没有物体",掩码置为极大负值(
NO_OBJ_SCORE)上采样 :双线性插值到原始图像尺寸(
1024×1024)第8步:掩码选择策略
# 关键决策:单掩码 vs 多掩码
单掩码模式 (
multimask_output=False):直接用token 0的掩码多掩码模式 (
multimask_output=True):从4个候选中选IoU最高的那个**
dynamic_multimask_via_stability**:如果token 0不稳定(轮廓模糊),自动回退到最佳多掩码第9-10步:生成"对象身份证"(核心创新)
# 输入:最佳token(256维) # 输出:对象指针(256维,用于记忆机制)
对象指针 (
obj_ptr):从最佳token投影得到的身份特征向量跨帧记忆:这个指针会被存入记忆库,下一帧作为"我是谁"的提示
动态加权 :根据对象分数(
object_score_logits)削弱指针强度
soft_no_obj_ptr=True:用sigmoid概率软性削弱
fixed_no_obj_ptr=True:用0/1硬性削弱
设计哲学:为什么这样设计?
设计点 原因 类比 总是用token 0做记忆 训练时只见过token 0代表身份,其他token是"短期临时工" 护照号 vs 临时通行证 多掩码候选 处理歧义点击(边界、重叠区域) 专家会诊,取最佳方案 对象指针 将掩码"外观"与物体"身份"解耦 人脸(可变化) vs 身份证号(不变) 动态削弱 物体消失时,记忆应衰减避免干扰 记忆随时间褪色
输入输出全貌
输入:
backbone_features:融合历史记忆的图像特征[1, 256, 64, 64]
point_inputs:用户点击坐标和标签
high_res_features:边缘细节特征[1, 32, 256, 256]和[1, 64, 128, 128]输出(7个值):
low_res_multimasks:4个低分辨率候选[1, 4, 256, 256]
high_res_multimasks:4个高分辨率候选[1, 4, 1024, 1024]
ious:4个IoU分数[1, 4]
low_res_masks:最佳低分辨率掩码[1, 1, 256, 256]
high_res_masks:最佳高分辨率掩码[1, 1, 1024, 1024]
obj_ptr:对象指针[1, 256]← 用于下一帧记忆
object_score_logits:对象出现置信度[1, 1]
一句话总结
这个函数是视频SAM的推理中枢 :它接收带记忆的图像 和用户提示 ,输出当前帧掩码 的同时,生成一个256维的身份指针,告诉记忆系统"这个物体长这样,下次见到要认出来"。
5.1 obj_ptr = self.obj_ptr_proj(sam_output_token)
这个会进入MLP
python
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/
transformer/transformer_predictor.py # noqa
class MLP(nn.Module):
def __init__(
self,
input_dim: int,
hidden_dim: int,
output_dim: int,
num_layers: int,
activation: nn.Module = nn.ReLU,
sigmoid_output: bool = False,
) -> None:
super().__init__()
self.num_layers = num_layers
h = [hidden_dim] * (num_layers - 1)
self.layers = nn.ModuleList(
nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
)
self.sigmoid_output = sigmoid_output
self.act = activation()
def forward(self, x):
# x: torch.Size([1, 256])
for i, layer in enumerate(self.layers):
# self.num_layers:3
x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)
# i=0 x: torch.Size([1, 256])
# i=1 x: torch.Size([1, 256])
# i=2 x: torch.Size([1, 256])
# self.sigmoid_output: False
if self.sigmoid_output:
x = F.sigmoid(x)
return x
我们去sam2/modeling/sam2_base.py,在这个文件搜索self.obj_ptr_proj,发现它在_build_sam_heads函数中。而且确实在类SAM2Base的初始化中存在self._build_sam_heads()。
if self.use_obj_ptrs_in_encoder:
a linear projection on SAM output tokens to turn them into object pointers
self.obj_ptr_proj = torch.nn.Linear(self.hidden_dim, self.hidden_dim)
if self.use_mlp_for_obj_ptr_proj:
self.obj_ptr_proj = MLP(# self.hidden_dim:256
self.hidden_dim, self.hidden_dim, self.hidden_dim, 3
)else:
self.obj_ptr_proj = torch.nn.Identity()
5.2 对象记忆的"褪色与重置"机制
--- 9. 从最佳 token 提取对象指针(用于记忆)---
sam_output_token: torch.Size([1, 256])
obj_ptr = self.obj_ptr_proj(sam_output_token) # (B, C)
obj_ptr: torch.Size([1, 256])
--- 10. 对象指针后处理:若模型认为"无对象",指针也被削弱或替换 ---
self.pred_obj_scores: True
if self.pred_obj_scores:
self.soft_no_obj_ptr: False
if self.soft_no_obj_ptr:
软削弱:用 sigmoid 概率加权
assert not self.teacher_force_obj_scores_for_mem
lambda_is_obj_appearing = object_score_logits.sigmoid()
else:
硬削弱:0/1 加权
is_obj_appearing: tensor([[True]], device='cuda:0')
lambda_is_obj_appearing = is_obj_appearing.float()
lambda_is_obj_appearing: tensor([[1.]], device='cuda:0')
self.fixed_no_obj_ptr: True
if self.fixed_no_obj_ptr:
obj_ptr = lambda_is_obj_appearing * obj_ptr
obj_ptr: torch.Size([1, 256])
剩余权重用"无对象指针"补齐,保证指针和为 1
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
obj_ptr: torch.Size([1, 256])
什么意思?
这段代码是对象记忆的"褪色与重置"机制,防止物体消失后残留的错误记忆干扰后续跟踪。
核心思想:物体消失时,记忆应该"归零"或"模糊化"
obj_ptr是物体的"记忆指纹",但如果物体离开画面,这个指纹会变成 "幽灵记忆" ------下一帧会错误地匹配背景或其他物体。这段代码就是动态削弱或替换这个指纹。
分步拆解
第1步:提取初始对象指针
obj_ptr = self.obj_ptr_proj(sam_output_token) # [1, 256]
sam_output_token:最佳掩码对应的token(物体的"身份编码")
obj_ptr_proj:线性投影层,将token转换为记忆指针(用于下一帧检索)第2步:计算"存在权重"
lambda_is_obj_appearing
# 硬模式(当前配置) lambda_is_obj_appearing = is_obj_appearing.float() # 0.0 或 1.0 # 结果:tensor([[1.0]]) 表示"物体存在"
lambda的含义:
λ = 1.0 :物体明确存在 → 记忆指针100%保留
λ = 0.0 :物体消失 → 记忆指针被削弱或替换
第3步:削弱或替换指针(关键操作)
pythonif self.fixed_no_obj_ptr: # True obj_ptr = lambda_is_obj_appearing * obj_ptr # 物体消失时乘以0,直接清零 # 用"空指针"补全剩余权重 obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr两种模式对比:
模式 soft_no_obj_ptr=False(硬)soft_no_obj_ptr=True(软)λ计算 is_obj_appearing.float()(0/1)object_score_logits.sigmoid()(0~1连续值)物体存在 λ=1.0 → obj_ptr不变λ≈1.0 → obj_ptr几乎不变物体消失 λ=0.0 → obj_ptr = 0+no_obj_ptrλ≈0.0 → obj_ptr很弱 +no_obj_ptr主导优点 干净利落的开关,无记忆残留 平滑过渡,避免突变抖动 缺点 可能因误判导致记忆突然中断 物体消失后仍残留微弱记忆
什么是
self.no_obj_ptr?这是一个可学习的参数向量 ,形状
[B,256],代表 "没有物体时的默认记忆状态" 。类比 :就像数据库里的 "空值(NULL)" 或 "默认用户",当物体不存在时,记忆系统应该指向这个"空身份",而不是一个过期的物体身份。
完整流程可视化
场景A:物体正常存在
is_obj_appearing = True → λ = 1.0 obj_ptr = 1.0 * "猫的指纹" + (1-1.0) * "空指针" = "猫的指纹" # 完全保留记忆场景B:物体消失(硬模式)
is_obj_appearing = False → λ = 0.0 obj_ptr = 0.0 * "猫的指纹" + (1-0.0) * "空指针" = "空指针" # 彻底清空记忆,避免"猫"的幽灵残留场景C:物体可能消失(软模式)
object_score_logits = -2.0 → λ ≈ 0.12 obj_ptr = 0.12 * "猫的指纹" + 0.88 * "空指针" = "模糊的、几乎空的指针" # 记忆很弱,但仍有微弱痕迹
为什么要这么设计?
防止"记忆污染":
如果没有这个机制:
# 第1帧:跟踪猫 obj_ptr = "猫的指纹" # 第10帧:猫跑出画面,但记忆未清空 # 第11帧:出现一只狗 obj_ptr 仍然 ≈ "猫的指纹" → 模型误以为狗是猫 → **跟踪失败**有了这个机制:
# 第10帧:模型检测不到猫 → λ=0 obj_ptr = "空指针" # 第11帧:出现狗 obj_ptr = "空指针" → 模型需要用户重新提示 → **避免误匹配**
一句话总结
这段代码实现了记忆的"生命周期管理" :物体存在时强化记忆,物体消失时主动遗忘,防止过时的记忆成为下一帧的"幽灵干扰"。
在sam2/modeling/sam2_base.py搜索self.no_obj_ptr,发现在SAM2Base初始化中,
条件:只有当模型需要预测对象分数 AND 在编码器中使用对象指针时,才创建这个参数
if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
创建一个形状为 [1, hidden_dim] 的可学习参数(默认 hidden_dim=256)
self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
用截断正态分布重新初始化(均值=0,标准差=0.02)
trunc_normal_(self.no_obj_ptr, std=0.02)
这段代码是初始化"无对象指针" ,让它成为模型的一个可学习参数。为什么这样做?
torch.nn.Parameter将
no_obj_ptr注册为模型参数,意味着:
它会被优化器更新(可学习)
会出现在
model.parameters()中会被保存到检查点(checkpoint)
torch.zeros然后trunc_normal_
先清零:确保初始值不会太大,避免训练初期梯度爆炸
再加微小噪声 :
std=0.02的截断正态分布,让初始值在零附近小幅扰动目的 :模型能自动学习什么样的"空状态"最合适,而不是人为设定死值
为什么不是固定值?
如果
no_obj_ptr是固定零向量
self.no_obj_ptr = torch.zeros(1, 256) # 固定值
模型无法适应不同场景
可能与其他真实对象指针"撞车"(距离太近)
作为可学习参数,模型会学会:
远离所有真实对象的指针分布
成为记忆空间中的 "无效身份"锚点
类比: learnable "NULL" token
这就像数据库设计中的 "可学习的NULL值" :
# 固定NULL NULL = [0, 0, 0, ..., 0] # 死板,可能与某些数据冲突 # 可学习NULL NULL = [0.01, -0.02, 0.005, ...] # 训练后自动调整,确保唯一性
一句话总结
no_obj_ptr不是随意设定的零向量,而是一个可训练的参数,模型通过反向传播自动学会"当物体不存在时,记忆系统应该处于什么状态",从而更智能地清除过时记忆。
所以其实就是obj_ptr是sam_output_token经过MLP后得到的,如果没消失的话obj_ptr就不用加self.no_obj_ptr,如果消失的话就要加。我知道(1,4,256)的mask_tokens_out取第0个token得到的(1,1,256)sam_output_token,然后取第0个得到(1,256)的sam_output_token,但是这个sam_output_token是啥?是用于判断物体是否存在的token吗?(不是,判断物体是否存在是object_score_logits,这个是那4个输出掩码的第一个)。那为什么要经过MLP变成obj_ptr,然后为什么物体消失要加self.no_obj_ptr,我理解物体是有可能是慢慢消失的,就是物体的一部分慢慢变小直至消失,它是怎么记忆这一点呢?
你问到了视频跟踪最核心的设计,我来逐层拆解这三个关键概念:
1.
sam_output_token到底是什么?
# 形状: [1, 256] sam_output_token = tensor([-0.2, 0.5, 1.2, ..., -0.8])本质 :这是解码器对"这个物体长什么样"的语义指纹,不是"存在判断器"。
类比:
人脸识别:128维的脸部特征向量(描述五官特征)
不是:一个开关信号(0=没人,1=有人)
它不负责说"物体在不在",只负责说"如果物体在,它的身份特征是什么"。
2. 为什么要经过MLP变成
obj_ptr?
obj_ptr = self.obj_ptr_proj(sam_output_token) # 线性投影关键原因:空间对齐 + 可学习记忆空间
对比 sam_output_tokenobj_ptr来源 掩码解码器输出 经过投影层 空间 解码空间(偏外观描述) 记忆空间(偏身份检索) 用途 生成掩码 检索/存储记忆 可学习性 固定语义 可训练调整
翻译官作用:MLP把"外观描述"翻译成"身份ID",让记忆系统能看懂
训练效果 :自动拉近同物体不同帧的
obj_ptr距离,推远不同物体的距离如果没有投影:
# 直接存储 token 记忆库 = { "第1帧": token_0, "第5帧": token_0, "第10帧": token_0 } # 第11帧检索 查询 = token_0_new 结果 = 找与 token_0_new 最相似的记忆 → 可能匹配错误(token空间≠记忆空间)
3. 物体消失机制:怎么感知"慢慢消失"?
这是独立判断的! 物体消失由
object_score_logits分支 负责,不是sam_output_token。工作流程:
第N帧:物体完整
object_score_logits = 5.0 # λ = sigmoid(5.0) ≈ 0.99 sam_output_token = "猫的完整身份编码" obj_ptr = 0.99 * "猫" + 0.01 * "空指针" # 几乎全保留第N+5帧:被遮挡50%
object_score_logits = 1.0 # λ ≈ 0.73 sam_output_token = "猫的模糊身份" obj_ptr = 0.73 * "模糊猫" + 0.27 * "空指针" # 记忆开始模糊第N+10帧:完全消失
object_score_logits = -3.0 # λ ≈ 0.05 sam_output_token = "垃圾编码(解码器乱猜)" obj_ptr = 0.05 * "垃圾" + 0.95 * "空指针" # 几乎清空
4. 为什么需要
self.no_obj_ptr?如果物体消失后
obj_ptr直接清零:
# 第10帧:obj_ptr = [0, 0, ..., 0] # 第11帧:新物体出现 obj_ptr_new = "狗的身份" 记忆检索:找与 [0,0,...,0] 最相似的特征 → 可能匹配到背景 → 跟踪漂移正确设计:
# 第10帧:obj_ptr = 学习的"空指针"(远离所有真实物体) # 第11帧:新物体出现 记忆检索:找不到有效匹配 → 触发"需要用户重新提示" → 避免误匹配
self.no_obj_ptr就是记忆空间里的"无人区" ,告诉模型:"这里什么都没有,别瞎匹配"。
总结:三个变量的区别
变量 形状 作用 判断是否消失 sam_output_token [256]身份编码(物体是谁) ❌ 不判断 object_score_logits [1]存在置信度(物体在不在) ✅ 输出λ obj_ptr [256]记忆指针(用于检索) ❌ 但会被λ削弱 关系链:
物体外观 → sam_output_token → obj_ptr_proj → obj_ptr → 存入记忆 ↓ object_score_logits → λ → 控制 obj_ptr 的削弱程度
一句话回答
物体消失不是由token判断,而是由独立的
object_score_logits分支 输出λ值;当λ→0时,obj_ptr会向学习的"空指针" 渐进靠拢,实现软性遗忘,避免硬清零导致记忆空间污染。
软模式和硬模式?
代码真相:加权和 vs 硬切换
这段代码做的是向量插值,不是"硬清零":
# 当 λ=0.8 时 obj_ptr = 0.8 * "猫的指针" + 0.2 * "空指针" # 混合状态 # 当 λ=0.0 时 obj_ptr = 0.0 * "猫的指针" + 1.0 * "空指针" # 完全切换关键看
lambda_is_obj_appearing的取值方式:
两种模式:硬 vs 软
硬模式(你代码里的配置)
self.fixed_no_obj_ptr = True self.soft_no_obj_ptr = False # λ = 0.0 或 1.0(离散) lambda_is_obj_appearing = is_obj_appearing.float() # tensor([[0.0]]) 或 tensor([[1.0]])行为:
物体存在 :λ=1.0 →
obj_ptr = 猫的指针(纯猫)物体消失 :λ=0.0 →
obj_ptr = 空指针(纯空)确实是"突然切换",没有中间态。物体消失的瞬间,记忆指针立刻变成"无效身份"。
软模式
self.soft_no_obj_ptr = True # λ = 0.0 ~ 1.0(连续) lambda_is_obj_appearing = object_score_logits.sigmoid() # tensor([[0.73]])行为:
物体清晰 :λ≈0.99 →
obj_ptr ≈ 猫的指针物体模糊 :λ≈0.50 →
obj_ptr = 0.5*猫 + 0.5*空(混合状态)物体消失 :λ≈0.01 →
obj_ptr ≈ 空指针这才是"慢慢消失" ,通过
object_score_logits的平滑变化实现渐进过渡。
为什么你看到的是硬模式?
因为
fixed_no_obj_ptr=True是默认配置 ,它在论文和官方代码中被验证最稳定。原因:
避免模糊记忆:软模式可能导致"半透明幽灵"------物体消失后,记忆处于猫和空的叠加态,下一帧可能匹配到背景
训练更简单:硬开关让梯度更干净,没有中间态的歧义
实际足够好 :
object_score_logits本身可以快速响应 ,0.0 和 1.0 之间的切换在时序上是平滑的(因为 logits 是连续变化的)
条件:只有当模型需要预测对象分数 AND 在编码器中使用对象指针时,才创建这个参数
if self.pred_obj_scores and self.use_obj_ptrs_in_encoder:
创建一个形状为 [1, hidden_dim] 的可学习参数(默认 hidden_dim=256)
self.no_obj_ptr = torch.nn.Parameter(torch.zeros(1, self.hidden_dim))
用截断正态分布重新初始化(均值=0,标准差=0.02)
trunc_normal_(self.no_obj_ptr, std=0.02)
这个到底学习什么?
这就是 视频记忆系统的"空值锚点" ,它学习的东西非常精妙:
它学习的是:记忆空间中的"无人区"坐标
想象一个256维的向量空间,每个点代表一个可能的物体身份:
区域A:猫的指针(如
[0.5, -0.2, ..., 1.0])区域B:狗的指针(如
[-0.8, 0.3, ..., -0.5])区域C:车的指针(如
[0.1, 0.9, ..., 0.4])
no_obj_ptr学习的就是这个空间里"不属于任何物体"的那个点 。
为什么要 学习 这个?(而不是用固定的零向量)
问题1:零向量可能"太普通"
# 如果 no_obj_ptr = [0, 0, ..., 0] # 猫的指针 = [0.01, -0.02, 0.005, ...] # 小噪声 距离(猫, 空) = ||[0.01, -0.02, ...]|| ≈ 很小 → 模型可能分不清"猫"和"空"问题2:零向量可能"太特殊"
某些激活函数(如ReLU)会让零区域梯度消失,导致"空状态"无法有效传播。
解决方案:可学习的空指针
# 训练后,no_obj_ptr 可能变成: no_obj_ptr = [3.5, -3.5, 3.5, -3.5, ...] # 远离所有物体分布 距离(猫, 空) = 很大 # 明确分离 距离(狗, 空) = 很大 # 明确分离
它到底怎么学习?(训练动力学)
Step 1: 数据驱动
当视频中物体消失 时,
λ → 0,记忆指针被强制靠近no_obj_ptr下一帧如果仍然没有物体,模型损失会惩罚错误的匹配
反向传播 :梯度更新
no_obj_ptr,让它成为一个"不会导致误匹配"的状态Step 2: 竞争学习
# 损失函数隐含的竞争关系: # 猫的指针 ← 靠近"猫的所有历史帧" # 狗的指针 ← 靠近"狗的所有历史帧" # 空指针 ← 远离"所有物体",靠近"所有无物体帧"Step 3: 收敛到最优解
经过大量数据训练后,
no_obj_ptr自动找到:
最佳位置 :记忆空间中距离所有真实物体最远的区域
最佳编码 :能够让下一帧的预测误差最小的向量
类比:学习数据的"聚类中心"
就像K-means聚类,但反着来:
普通聚类 :找到数据最密集的中心点
no_obj_ptr :找到最远离所有数据 的外围点
图示:
向量空间: 猫 ● \ \ ● 狗 / / 车 ● 空 ● ← no_obj_ptr 学习到这里(数据荒漠)
为什么用
trunc_normal_初始化?
trunc_normal_(self.no_obj_ptr, std=0.02)
小方差(0.02):初始值接近0但不精确为0,给优化留出空间
截断分布:防止初始值过大,远离原点,破坏记忆空间结构
从零开始:让模型自己决定"空"应该是什么,而不是预设偏见
一句话总结
no_obj_ptr学习的是记忆空间中的"数据荒漠"坐标 ------它通过反向传播自动找到一个远离所有真实物体、且最有利于消除误匹配的向量表示,作为物体消失时记忆的"安全锚点"。