SAM2跟踪的理解7——mask decoder

目录

一、前言

四、MaskDecoder.forward

[4.1 MaskDecoder.predict_masks](#4.1 MaskDecoder.predict_masks)

[4.1.2 TwoWayTransformer.forward](#4.1.2 TwoWayTransformer.forward)

[4.1.2.1 TwoWayAttentionBlock.forward](#4.1.2.1 TwoWayAttentionBlock.forward)

[4.1.2.2 self.self_attn------Attention.forward](#4.1.2.2 self.self_attn——Attention.forward)

线性映射前后维度是不变的,那它里面做了什么?有什么作用?

你的意思是,本来q,k,v都是相同的,线性映射之后就不同了是吗,那如何理解线性映射呢?后面什么要对q、k、v拆头,为什么是拆8个头?F.scaled_dot_product_attention做了什么?out_proj做了什么?

这里很多函数进不去,标记一下

如何理解自注意力机制呢?我自己还不了解我自己吗,还要用这种方式?

所以说其实是如果我是token的话,我的身份是受整个向量的其他token影响的,每个token都去询问一遍其他所有token以确定自己的身份,是这样理解吗

这里感觉没有理清楚,标记一下

[4.1.2.3 self.cross_attn_token_to_image------Attention.forward](#4.1.2.3 self.cross_attn_token_to_image——Attention.forward)

[为什么out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 之后out: torch.Size([1, 8, 9, 16])](#为什么out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 之后out: torch.Size([1, 8, 9, 16]))

[如何理解attn = softmax(q_h @ k_h.transpose(-2,-1) / sqrt(d_k))和out_h = attn @ v_h](#如何理解attn = softmax(q_h @ k_h.transpose(-2,-1) / sqrt(d_k))和out_h = attn @ v_h)

[代入SAM2分割这一情景,如何理解attn = softmax(q_h @ k_h.transpose(-2,-1) / sqrt(d_k))和out_h = attn @ v_h](#代入SAM2分割这一情景,如何理解attn = softmax(q_h @ k_h.transpose(-2,-1) / sqrt(d_k))和out_h = attn @ v_h)

[4.1.2.4 自注意力和交叉注意力有什么区别?为什么要先自注意力?](#4.1.2.4 自注意力和交叉注意力有什么区别?为什么要先自注意力?)

[4.1.2.5 为什么q和k都要加个绝对位置编码,然后v没加而直接就是keys?](#4.1.2.5 为什么q和k都要加个绝对位置编码,然后v没加而直接就是keys?)

[4.1.2.6 MLP.forward](#4.1.2.6 MLP.forward)


一、前言

下面是第一帧情况下的函数调用顺序。

2.12 <重点> add_new_prompt

2.13 <重点> _run_single_frame_inference

2.14 <重点> track_step

2.15 <重点> _prepare_memory_conditioned_features

2.16 _use_multimask

2.17 <重点> _forward_sam_heads

2.18 提示编码器:类PromptEncoder.forward

2.19 类PositionEmbeddingRandom.forward_with_coords

2.20 类PromptEncoder.get_dense_pe

2.21 掩码解码器 类MaskDecoder.forward

2.22 类MaskDecoder.predict_masks

2.23 TwoWayTransformer.forward(这篇开头在这)

2.24 TwoWayAttentionBlock.forward

2.25 Attention.forward(这篇结束在这)

2.26 <重点> MLP.forward

2.27 Attention.forward

2.28 LayerNorm2d.forward

2.29 MaskDecoder._dynamic_multimask_via_stability

2.30 MaskDecoder._get_stability_scores

2.31 fill_holes_in_mask_scores

2.32 _get_maskmem_pos_enc

2.33 _consolidate_temp_output_across_obj

2.34 _get_orig_video_res_output

四、MaskDecoder.forward

4.1 MaskDecoder.predict_masks

4.1.2 TwoWayTransformer.forward

sam2/modeling/sam/transformer.py

hs, src = self.transformer(src, pos_src, tokens)

上面这句进去就是调用TwoWayTransformer的forward函数。

python 复制代码
class TwoWayTransformer(nn.Module):
    """
    双向 Transformer:  
    1. 先让「稀疏点 token」(queries) 与「稠密图像 token」(keys) 做若干层双向 cross-attention;  
    2. 最后再让 queries 单独对图像做一次 attention,得到增强后的 queries 作为最终输出。  
    图像 token 只做中间传递,最终原样返回。  
    """

    def __init__(
        self,
        depth: int,                         # 双向 attention block 重复次数
        embedding_dim: int,                 # 通道维度 C
        num_heads: int,                     # 多头注意力的头数
        mlp_dim: int,                       # FFN 中间层维度
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2, # Attention 内部 Q/K 下采样比例
    ) -> None:
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        # 堆叠 depth 个双向 attention block
        for i in range(depth):
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim=embedding_dim,
                    num_heads=num_heads,
                    mlp_dim=mlp_dim,
                    activation=activation,
                    attention_downsample_rate=attention_downsample_rate,
                    skip_first_layer_pe=(i == 0),   # 第一层无需给 query 加 PE(已在输入时加好)
                )
            )

        # 最后一层:queries → 图像的 attention
        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,   # [B, C, H, W]
        image_pe: Tensor,          # [B, C, H, W]  图像位置编码
        point_embedding: Tensor,   # [B, Np, C]     点提示的 embedding(已含 PE)
    ) -> Tuple[Tensor, Tensor]:
        """
        Returns:
          processed point embedding  [B, Np, C]
          processed image embedding  [B, H*W, C]   (与输入内容相同,仅 reshape)
        """
        # image_embedding:  torch.Size([1, 256, 64, 64])
        # image_pe: torch.Size([1, 256, 64, 64])
        # point_embedding: torch.Size([1, 9, 256])

        # 1. 把图像展平成 token 序列
        bs, c, h, w = image_embedding.shape
        # bs:1 c:256  h:64  w:64

        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)  # [B, H*W, C]
        # image_embedding: torch.Size([1, 4096, 256])

        image_pe = image_pe.flatten(2).permute(0, 2, 1)              # [B, H*W, C]
        # image_pe: torch.Size([1, 4096, 256])

        queries = point_embedding                                      # [B, Np, C]
        # queries: torch.Size([1, 9, 256])

        keys = image_embedding                                         # [B, H*W, C]
        # keys: torch.Size([1, 4096, 256])

        # 2. 逐层双向 attention 更新 queries 和 keys
        for layer in self.layers:
            # 进入TwoWayAttentionBlock.forward
            queries, keys = layer(
                queries=queries,
                keys=keys,
                query_pe=point_embedding,   # 每次把原始 PE 作为 Q 的偏置传进去
                key_pe=image_pe,
            )

        # 3. 最后一层:queries 再对图像做一次 attention
        q = queries + point_embedding                                  # 残差加回原始 PE
        k = keys + image_pe
        attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys)   # [B, Np, C]
        queries = queries + attn_out                                   # 残差连接
        queries = self.norm_final_attn(queries)                        # LayerNorm

        # 4. 返回增强后的 queries 和原图 token(下游只拿 queries 用即可)
        return queries, keys

整体流程一句话总结

"稀疏点 token" 先和"稠密图像 token"在多层的双向 cross-attention 里互相更新;

最后再把更新后的点 token 单独对图像做一次 attention 并残差+Norm,得到最终点特征。

图像 token 只充当信息搬运工,原样返回即可。

4.1.2.1 TwoWayAttentionBlock.forward

sam2/modeling/sam/transformer.py

for layer in self.layers:

进入TwoWayAttentionBlock.forward

queries, keys = layer(

queries=queries,

keys=keys,

query_pe=point_embedding, # 每次把原始 PE 作为 Q 的偏置传进去

key_pe=image_pe,

)

python 复制代码
class TwoWayAttentionBlock(nn.Module):
    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
    ) -> None:
        """
        一个 Transformer 块,内部 4 步:
        1) sparse queries 自注意力  
        2) queries cross-attend 到 dense keys(token→image)  
        3) 对 queries 做 MLP  
        4) dense keys cross-attend 到 sparse queries(image→token)  
        通过双向交叉,实现"稀疏点"与"稠密图"信息互通。
        """
        super().__init__()

        # 1. 自注意力
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1   = nn.LayerNorm(embedding_dim)

        # 2. token→image 交叉注意力
        # 又进入TwoWayAttentionBlock.forward
        # attention_downsample_rate:2
        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        # 3. MLP
        self.mlp = MLP(
            embedding_dim, mlp_dim, embedding_dim, num_layers=2, activation=activation
        )
        self.norm3 = nn.LayerNorm(embedding_dim)

        # 4. image→token 交叉注意力
        self.norm4 = nn.LayerNorm(embedding_dim)

        # attention_downsample_rate:2
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )

        self.skip_first_layer_pe = skip_first_layer_pe   # 首块是否给 Q 加 PE

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        # 输入形状示例:
        # queries: torch.Size([1, 9, 256])  稀疏点 token
        # keys: torch.Size([1, 4096, 256])  稠密图像 token
        # query_pe:torch.Size([1, 9, 256]) 稀疏点token的绝对位置编码
        # key_pe:torch.Size([1, 4096, 256]) 稠密图像token的绝对位置编码

        # ---------- 1. 自注意力 ----------
        # self.skip_first_layer_pe: True
        if self.skip_first_layer_pe:                 # 首层不加 PE,直接 self-attn
            # queries: torch.Size([1, 9, 256]) 
            queries = self.self_attn(q=queries, k=queries, v=queries)
            # queries: torch.Size([1, 9, 256])
        else:
            q = queries + query_pe                   # 残差加 PE
            attn_out = self.self_attn(q=q, k=q, v=queries)
            queries = queries + attn_out             # 残差连接

        queries = self.norm1(queries)                # [B, 9, 256]
        # queries: torch.Size([1, 9, 256])

        # ---------- 2. token→image 交叉注意力 ----------
        q = queries + query_pe                       # 给 query 加 PE
        # q: torch.Size([1, 9, 256])

        k = keys + key_pe                         # 给 key   加 PE
        # k: torch.Size([1, 4096, 256])
        
        # q: torch.Size([1, 9, 256])
        # k: torch.Size([1, 4096, 256])
        # keys: torch.Size([1, 4096, 256])
        attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys)  # 下采样在内部完成
        # attn_out: torch.Size([1, 9, 256])

        queries = queries + attn_out                 # 残差
        # queries: torch.Size([1, 9, 256])

        queries = self.norm2(queries)                # [B, 9, 256]
        # queries: torch.Size([1, 9, 256])

        # ---------- 3. MLP ----------
        mlp_out = self.mlp(queries)
        queries = queries + mlp_out                  # 残差
        queries = self.norm3(queries)                # [B, 9, 256]

        # ---------- 4. image→token 交叉注意力 ----------
        # 注意:这里"角色互换"------用图像 token 做 Q,去 attend 稀疏点
        q = queries + query_pe                       # 稀疏点继续当"被 attend"的 K/V
        k = keys    + key_pe                         # 图像当 Q
        attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries)  # 形状 [B, 4096, 256]
        keys = keys + attn_out                       # 残差更新图像 token
        keys = self.norm4(keys)                      # [B, 4096, 256]

        # 返回更新后的 (queries, keys),供下一层或下游使用
        return queries, keys

总结

  1. 稀疏点先 self-attn,增强自身上下文。

  2. 再把增强后的点去 attend 图像,提取对应位置特征。

  3. 过一遍 MLP,进一步非线性变换。

  4. 最后让图像 token 反过来看这些点,把"哪些区域有点"信息写回图像特征。

    于是"点"与"图"完成一次双向融合,形状全程保持不变:

    queries 始终 [B, Np, C],keys 始终 [B, H·W, C]。

注意上面TwoWayAttentionBlock初始化里面,创建Attention的时候,自注意力是没有传入downsample_rate,所以是默认的1,而交叉注意力是传入downsample_rate=2的,这也是为什么后面线性映射的时候自注意力的Attention是没有降维的(一直是256),而交叉注意力的Attention里面线性映射的时候降维到128了(不过最后会升回256)

4.1.2.2 **self.self_attn------**Attention.forward

sam2/modeling/sam/transformer.py

TwoWayAttentionBlock.forward中

queries = self.self_attn(q=queries, k=queries, v=queries)

上面的语句进入Attention的forward

python 复制代码
class Attention(nn.Module):
    """
    标准多头注意力,但支持「把 Q/K/V 映射到更低维度」以节省计算。
    内部使用 PyTorch 2.x 的 scaled_dot_product_attention,可自动选 Flash / Mem-efficient / Math kernel。
    """

    def __init__(
        self,
        embedding_dim: int,        # 输入 token 的通道数 C
        num_heads: int,            # 头数 h
        downsample_rate: int = 1,  # 把 C 压缩成 C//downsample_rate,再分头
        dropout: float = 0.0,
        kv_in_dim: int = None,     # 如果 K/V 的输入维度与 Q 不同,可单独指定
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
        # downsample_rate:1
        self.internal_dim = embedding_dim // downsample_rate  # 压缩后的通道
        self.num_heads = num_heads
        assert self.internal_dim % num_heads == 0, "num_heads 必须整除 internal_dim"

        # 线性映射:Q 来自 embedding_dim,K/V 可能来自别的维度
        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)

        # 输出再映射回原始通道
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
        self.dropout_p = dropout

    # 把 [B, N, C] 拆成 [B, h, N, C//h] 以便并行算多头
    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        # x:[1,9,256] num_heads:8
        b, n, c = x.shape                     
        x = x.reshape(b, n, num_heads, c // num_heads)  
        # [1,9,8,32]

        return x.transpose(1, 2)               # [1,8,9,32]

    # 逆操作:把 [B, h, N, C//h] 还原成 [B, N, C]
    def _recombine_heads(self, x: Tensor) -> Tensor:
        b, n_heads, n_tokens, c_per_head = x.shape  # [1,8,9,32]
        x = x.transpose(1, 2)                       # [1,9,8,32]
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # [1,9,256]

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        # 输入示例:q/k/v 均为 [1,9,256]

        # 1. 线性映射 + 降维
        q = self.q_proj(q)  
        # q: [1,9,256]

        k = self.k_proj(k)  
        # k: [1,9,256]

        v = self.v_proj(v)  
        # v: [1,9,256]

        # 2. 拆头
        q = self._separate_heads(q, self.num_heads)  
        # q: [1,8,9,32]

        k = self._separate_heads(k, self.num_heads)  
        # k: [1,8,9,32]

        v = self._separate_heads(v, self.num_heads)  
        # v: [1,8,9,32]

        # 3. 计算 dropout 开关(推理时关闭)
        # self.training: False
        dropout_p = self.dropout_p if self.training else 0.0
        # dropout_p: 0.0

        # 4. 调用最优 kernel 做 scaled dot-product attention
        with torch.backends.cuda.sdp_kernel(
            # USE_FLASH_ATTN: False
            enable_flash=USE_FLASH_ATTN,     # 全局开关:是否尝试 FlashAttention
            
            # OLD_GPU:True   dropout_p:0.0   MATH_KERNEL_ON:True
            enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,  # 兜底 math kernel
            
            enable_mem_efficient=OLD_GPU,                     # 老 GPU 可开 memory-efficient
        ):
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)  # [1,8,9,32]

        # 5. 合并头 + 输出投影
        # out: torch.Size([1, 8, 9, 32])
        out = self._recombine_heads(out)  
        # out: torch.Size([1, 9, 256])

        out = self.out_proj(out)          
        # out: torch.Size([1, 9, 256])

        return out
  1. 先把 Q/K/V 线性映射到更低维度 internal_dim,省计算。

  2. 拆成多头形状,用 PyTorch 官方 scaled_dot_product_attention 自动选最快 kernel。

  3. 计算完再合并头,并线性映射回原始通道数。

  4. 整个模块对外保持"输入 [B, N, C] → 输出 [B, N, C]"的透明接口,内部按需降维。

线性映射前后维度是不变的,那它里面做了什么?有什么作用?

"维度不变"只是 最外圈 的假象------
forward真正做矩阵乘的维度 已经被 q_proj / k_proj / v_proj 压成了 internal_dim = embedding_dim // downsample_rate,自注意力这里是取downsample_rate=1

算完注意力再用 out_proj 把通道升回来。

所以"线性映射"里面干了三件事:

  1. 降维投影 (省钱)

    把原始 C 映射到 C/r(r=downsample_rate,默认 1 即不降;SAM 里常取 2)。

    参数量与计算量都近似降为 1/r²

  2. 多头拆分准备

    投影后的通道必须能被 num_heads 整除,才能直接 .reshape(..., heads, dim_per_head)

    这一步把"物理维度"转成"逻辑头维度",后面 sdpa 并行算 h 个头。

  3. 可学习的线性组合 (混通道)

    即使 r=1 不降维,投影矩阵仍是 满秩可训练参数 ,给每个头提供不同的子空间,

    否则所有头都看到完全相同的 Q/K/V,多头就退化成"单头复制"。


形象一句话
"先压通道省钱,再拆头并行,最后升通道还给你" ------

外表看 [B,N,C]→[B,N,C],里面其实偷偷把计算搬到了更小的空间。

你的意思是,本来q,k,v都是相同的,线性映射之后就不同了是吗,那如何理解线性映射呢?后面什么要对q、k、v拆头,为什么是拆8个头?F.scaled_dot_product_attention做了什么?out_proj做了什么?

把四个问题拆成四步,一次讲清:

  1. 线性映射到底干了什么

    输入 q=k=v=x 确实完全一样,但

    • q_projk_projv_proj三个独立的权重矩阵 (形状都是 [C, C/r])。

    • 各自乘一遍后,得到三套不同Q/K/V,通道数只剩 C/r

      这就是 Attention 论文里最经典的"用三个线性层把同一批 token 变成查询-键-值"。

  2. 为什么要"拆 8 个头"

    • 投影后的 internal_dim 必须能被头数整除,于是
      internal_dim = num_heads × (dim_per_head)

    • SAM 把 num_heads 设成 8,相当于把 256 → 8×32

      每个头只看 32 维的子空间,并行学 8 组不同的相似度函数,最后再拼回来。

    • 头数=8 是超参,经验值;越大越分散,计算量也线性涨。

  3. F.scaled_dot_product_attention 内部做了什么

    给定 [B, h, N, d]Q/K/V,它一步完成:

    复制代码
    scores = Q @ K.transpose(-2,-1) / sqrt(d)   # [B,h,N,N] 相似度
    attn   = softmax(scores, dim=-1)            # 行归一化得权重
    output = attn @ V                           # 加权求和

    并自动选 FlashAttention / memory-efficient / math 三种 kernel 中最快且支持你硬件的那一条。

  4. out_proj 的作用

    多头输出拼完后形状是 [B, N, internal_dim],但下游网络期望原通道 C
    out_proj 就是一个 Linear(internal_dim, C),把压缩后的特征升维+线性混合 ,再残差接回原路径。

    它同时让不同头的信息再次融合,并恢复网络下一层所需的通道数。

一句话串起来:

"同一份 x三个独立线性层 得到 Q/K/V → 压通道 → 拆 8 个头并行算相似度 → 拼回 → 再线性映射回原始通道。"

这里很多函数进不去,标记一下
如何理解自注意力机制呢?我自己还不了解我自己吗,还要用这种方式?

"自己还不了解自己吗?"------这句话把"自注意力"拟人化了,结果就绕进了死胡同。

把它当成纯数学的"特征再分配" 来看,就通了:

  1. token ≠ 人

    每个 token 只是 C 维向量,同一位置在不同层含义一直在变

    "了解我自己"不是指它突然有了意识,而是用全局信息重新给这 C 个数加权,让"该强调的子空间放大,该抑制的噪声归零"。

  2. 没有全局视野的向量 = 局部盲人

    单层 Conv 或 FFN 只能看到邻域(甚至只能看单点);

    自注意力把整句/整图所有位置 拉成一张相似度表,一次完成任意两两通信

    结果:

    • "猫" token 发现当前特征里"狗"分量很高,就把这部分降权;

    • "边界" token 发现远处还有一条同向边界,就把对应特征加强。

      这就是上下文校正,而不是"自我怀疑"。

  3. 数学上就是"可学习的加权平均"

    new_x_i = Σ_j softmax( (x_i W_q)(x_j W_k)^T /√d ) · (x_j W_v)

    权重矩阵 W 训练好后,每个位置自动知道该抄谁、抄多少

    所谓"自"仅表示 Query、Key、Value 都来自同一组 x,不引入额外外部信号

  4. 为什么一定要"投影三次"

    如果直接用 x 点积 x,只能得到原始特征 的相似度;

    投影成 Q/K/V 后,网络可以给"查询侧""键侧""值侧"各自学不同的子空间

    否则就退化成固定核的平滑,没有可学习参数

  5. 拆多头 = 多组"聚光灯"

    8 个头 ≈ 8 台不同颜色的探照灯,每台只照 32 维,

    有的专盯"颜色相似",有的专盯"空间相邻",最后拼起来就同时兼顾多种上下文线索

一句话收尸:
自注意力不是让 token"反省人生",而是给它一次"看完全局再重新整理自己特征"的机会;
三次线性映射 + 多头 + softmax 加权,只是让这次"整理"变得可学习、可并行、多视角。

你的意思是token是我的局部,我的局部与局部之间要互相了解是吗

可以这么类比,但别再往"人"上靠------把"局部"想成纯数据块就不会绕:

  1. 一张图被切成 64×64 的小 patch,每个 patch 只"看见"自己那 8×8 像素;

    它对应的 token 向量里,也许有"边缘""红色""纹理"这些分量,但不知道 20 个 patch 之外还有一条同方向的边缘

    自注意力就是让这条"局部边缘" token 去查整张图里所有其他边缘 token,发现"哦,原来咱俩在同一条大轮廓上",于是把自己的"轮廓置信度"拉高。

  2. 一句话里,"bank" 这个 token 如果只盯自己,无法区分"河岸"还是"银行";

    它扫完全句,发现前面有"river",就把"水"相关的分量加大,把"金融"分量压掉。

    这就是局部与局部之间互相校正,而不是"谈心"。

  3. 数学上就是用相似度做加权平均

    每个局部发出一个"查询向量",和其他所有局部的"键向量"点积,得到"我跟你有多相关";

    再用这组相关度去加权所有局部的"值向量",更新自己

    所以"局部了解局部"=用全局相似度重新加权融合信息,让原本孤立的 C 维向量变成"全局-aware"的新向量。

一句话:
不是 token 有了情感,而是它的数值列表通过"与所有邻居算相似度→加权抄作业"被重新填写了一次,从而带上全局上下文。

所以说其实是如果我是token的话,我的身份是受整个向量的其他token影响的,每个token都去询问一遍其他所有token以确定自己的身份,是这样理解吗

可以这么"拟人"地打比方,但一定记住------这只是帮你建立直觉的比喻,底层全是矩阵乘法,没有"身份""询问"这些意识活动:

  1. 在自注意力前,你(某个 token)手里的 C 个数字只来自你自己的小区域(8×8 像素或一个子词)。

  2. 自注意力让你把这张"全局相似度表"当成权重 ,去加权抄全场所有 token 的特征 ,然后用抄来的结果替换/加回自己原来的 C 个数字

  3. 于是,新的 C 个数字里既保留了原来的局部线索,又混入了"谁跟我很像/谁对我重要"的全局信息。

  4. 下一层网络看到的新向量,就不再只是局部,而是全局-aware 的表示------这就是你说的"身份被其他 token 影响"。

所以,"询问"=点积算相似度,"确定身份"=用相似度加权平均后更新自己的特征向量

整个过程是可学习的线性投影 + 数值加权,没有任何语义层面的"自我认知"。

全局相似度表这是指q、k、v都有各自的可学习的权重矩阵是吗?那这个权重矩阵是怎么被初始化,怎么被学习的呢?它的数值是怎么来的? 是不是有种真值知道更新后的是不是正确的,然后反向更新这个权重矩阵呢?但这个真值是什么?

把"比喻"全部剥掉,只剩可学习的张量可求导的 loss,就清楚了:

  1. 三个权重矩阵哪来的

    • 形状:q_proj.weight [C, C/r]k_proj.weight [C, C/r]v_proj.weight [C, C/r]

    • 初始化:PyTorch 默认用 KaimingUniformxavier_uniform 随机采样,数值服从某种 scaled 均匀分布,与任务无关。

    • 它们就是普通 Linear 层 ,跟 CNN 的卷积核一样,一开始完全随机,没有任何"真值"先验

  2. "全局相似度表"怎么来的

    前向时:

    Q = x W_q, K = x W_k, S = Q K^T /√d

    S 就是 [B, h, N, N] 的"相似度表",每一次 forward 都重新算一次,不存储、不保留,也没有可学习参数

    可学习的只有 W_qW_k(以及 W_vout_proj)。

  3. 数值怎么被"纠正"------没有"真值相似度"

    下游任务给的才是终极真值:

    • 分割任务:像素级 mask

    • 分类任务:类别标签

    • 生成任务:下一个 token ID

    网络输出 → 与这些标签算交叉熵、MSE、Dice ... → 得到标量 loss → 反向传播一直走到 W_qW_kW_vout_proj → 自动求导调整矩阵元素。
    没有任何人告诉模型"第 i 与 j 的相似度应该是 0.7",它只是发现"把这两个 patch 的权重调高,最终 mask 的 IoU 会变大",于是下次继续调高。

  1. 总结一句话
    可学习的是四个线性矩阵,相似度表只是它们在前向时临时算出的副产品;
    "对不对"由下游任务 loss 说了算,loss 只认最终输出与人工标签的差距,不认中间相似度。
这里感觉没有理清楚,标记一下
4.1.2.3 self.cross_attn_token_to_image**------**Attention.forward

TwoWayAttentionBlock.forward中

---------- 2. token→image 交叉注意力 ----------

q = queries + query_pe # 给 query 加 PE

q: torch.Size([1, 9, 256])

k = keys + key_pe # 给 key 加 PE

k: torch.Size([1, 4096, 256])

又进入Attention.forward

attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) # 下采样在内部完成

queries = queries + attn_out # 残差

queries = self.norm2(queries) # [B, 9, 256]

python 复制代码
class Attention(nn.Module):
    """
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
        dropout: float = 0.0,
        kv_in_dim: int = None,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim          # 原始输入维度(q 的输入维度)
        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim  # k/v 的输入维度,可与 q 不同
        # self.internal_dim = 256 // 2 = 128
        self.internal_dim = embedding_dim // downsample_rate  # 经过降采样后的"内部"维度,用于多头计算
        self.num_heads = num_heads                  # 注意力头数
        assert (
            self.internal_dim % num_heads == 0
        ), "num_heads must divide embedding_dim."

        # 线性映射:把输入映射到统一的 internal_dim 空间
        # embedding_dim:256  self.internal_dim:128
        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)      # 仅 q 来自 embedding_dim
        # self.kv_in_dim:256
        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)     # k/v 可能来自不同维度
        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        # 输出映射:把拼接后的多头结果再映射回原始 embedding_dim
        # embedding_dim:256  self.internal_dim:128
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)

        self.dropout_p = dropout  # attention dropout 比例

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        """把 [B, N, C] 拆成 [B, num_heads, N, C//num_heads],方便并行算多头"""
        b, n, c = x.shape
        x = x.reshape(b, n, num_heads, c // num_heads)
        return x.transpose(1, 2)  # B x N_heads x N_tokens x C_per_head

    def _recombine_heads(self, x: Tensor) -> Tensor:
        """与 _separate_heads 相反,把多头结果重新拼接回 [B, N, C]"""
        # x: torch.Size([1, 8, 9, 16])
        b, n_heads, n_tokens, c_per_head = x.shape

        # b:1  n_heads:8   n_tokens:9  c_per_head:16
        x = x.transpose(1, 2)  # 先交换维度,变成 [B, N_tokens, N_heads, C_per_head]

        # x: torch.Size([1, 9, 8, 16])
        return x.reshape(b, n_tokens, n_heads * c_per_head)  # B x N_tokens x C

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """
        参数:
            q: [B, Nq, embedding_dim]  查询序列
            k: [B, Nk, kv_in_dim]      键序列
            v: [B, Nk, kv_in_dim]      值序列
        返回:
            out: [B, Nq, embedding_dim]
        """
        # 输入:
        # q: torch.Size([1, 9, 256])
        # k: torch.Size([1, 4096, 256])
        # v: torch.Size([1, 4096, 256])

        # Input projections
        # 初始化的时候 self.internal_dim = embedding_dim // downsample_rate 
        # downsample_rate = 2, 所以交叉注意力里的线性映射发生降维了
        q = self.q_proj(q)  # q: torch.Size([1, 9, 128])
        k = self.k_proj(k)  # k: torch.Size([1, 4096, 128])
        v = self.v_proj(v)  # v: torch.Size([1, 4096, 128])

        # Separate into heads
        q = self._separate_heads(q, self.num_heads)  # q: torch.Size([1, 8, 9, 16])
        k = self._separate_heads(k, self.num_heads)  # k: torch.Size([1, 8, 4096, 16])
        v = self._separate_heads(v, self.num_heads)  # v:torch.Size([1, 8, 4096, 16])

        # self.dropout_p:0  self.training:False
        dropout_p = self.dropout_p if self.training else 0.0  # 推理时关闭 dropout
        # dropout_p: 0.0

        # Attention
        # 根据 GPU 能力及配置选择最优 kernel:FlashAttention / Math / MemoryEfficient
        with torch.backends.cuda.sdp_kernel(
            enable_flash=USE_FLASH_ATTN,                       # USE_FLASH_ATTN:False
            enable_math=(OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON,  # OLD_GPU:True   dropout: 0.0  MATH_KERNEL_ON: True
            enable_mem_efficient=OLD_GPU,
        ):
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
            # out: torch.Size([1, 8, 9, 16])

        out = self._recombine_heads(out)  # out: torch.Size([1, 9, 128])
        out = self.out_proj(out)          # out: torch.Size([1, 9, 256])

        return out

代码整体解释

  1. 作用:实现一个支持"降维"的多头交叉注意力层。

  2. 降维:通过 downsample_rate 把 q/k/v 映射到更低的 internal_dim 再计算注意力,减少计算量;计算完再映射回原始 embedding_dim

  3. 支持异构输入:k/v 的输入维度 kv_in_dim 可与 q 的 embedding_dim 不同,方便像 ViT 编码器-解码器结构那样使用。

  4. 多头拆分/合并:_separate_heads_recombine_heads 负责把张量维度在"token"与"head"之间来回变换,以调用 PyTorch 的高效 scaled_dot_product_attention

  5. 后端选择:sdp_kernel 的上下文管理器根据硬件与超参自动挑选 FlashAttention、Math 或 MemoryEfficient kernel,兼顾速度与显存。

为什么out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p) 之后out: torch.Size([1, 8, 9, 16])

一句话:因为 F.scaled_dot_product_attention"多头已经拆好" 的 4-D 张量上做的是 逐头计算 ,输出形状和输入的"头维"保持一致,所以
[B, num_heads, Nq, C_per_head] 进去 → 同样形状出来。


详细推导

  1. 输入形状

    q: [1, 8, 9, 16] → 含义 [batch=1, heads=8, query_tokens=9, dim_per_head=16]

    k: [1, 8, 4096, 16]

    v: [1, 8, 4096, 16]

  2. scaled_dot_product_attention 内部逻辑

    对每一个头(dim=1 上切片)独立地执行:

    python 复制代码
    # 对于1个头,q乘以k的转置 [9, 16] x [16, 4096] = [9, 4096]
    attn = softmax(q_h @ k_h.transpose(-2,-1) / sqrt(d_k))   # [9, 4096]
    # 对于1个头,[9, 4096] x [4096, 16] = [9, 16]
    out_h = attn @ v_h                                       # [9, 16]

    因此每个头都会得到一个 [9, 16] 的输出。

  3. 把 8 个头的结果按"头维"堆回去,就是 [1, 8, 9, 16]

  4. 输出形状

    函数约定:

    • 不改变头数

    • 不改变 batch 大小

    • token 数始终与 q 的 token 数相同

    • 特征维度始终与 v 的 per-head 维度相同

    所以出来还是 [1, 8, 9, 16]


小结

"多头"维度在 sdpa 眼里只是普通的 batch 维度之一,它不会合并或重塑它;合并/重塑是 _recombine_headsout_proj 后续才做的事。

如何理解attn = softmax(q_h @ k_h.transpose(-2,-1) / sqrt(d_k))和out_h = attn @ v_h

把"注意力"想成查资料

  • q_h:你此刻想提的 问题(9 个关键词,每个 16 维向量)。

  • k_h:资料库里所有 标题(4096 条,每条也是 16 维向量)。

  • v_h:资料库里对应标题的 正文(4096 条,每条 16 维向量)。


1. 打分:问题与标题有多匹配

复制代码
score = q_h @ k_h.transpose(-2, -1)   # [9, 4096]

每条问题向量与每条标题向量做 点积 (cosine 的分子),数值越大表示越相关。

除以 √d_k√16 = 4)防止点积绝对值太大导致 softmax 饱和,这一步叫 scale


2. 转成概率:softmax

复制代码
attn = softmax(score, dim=-1)         # [9, 4096],每行和为 1

对每条问题(每一行)做 softmax,把"相关分"变成 选资料的概率分布

结果 attn[i, j] 就是"问题 i 对标题 j 的关注权重"。


3. 拿概率去加权正文

复制代码
out_h = attn @ v_h                    # [9, 16]

按关注权重把 4096 条正文向量做 加权平均

  • 权重大的正文对结果贡献大;

  • 权重小的几乎被忽略。

    于是 4096 条信息被压缩成 9 条"精炼答案",维度仍保持 16。


一句话总结

"先算相关性,再按相关性加权求和"------这就是注意力机制的核心。

代入SAM2分割这一情景,如何理解attn = softmax(q_h @ k_h.transpose(-2,-1) / sqrt(d_k))和out_h = attn @ v_h

在 SAM 2(Segment Anything Model 2)里,这套注意力被用来做**"记忆-查询"** 式的跨帧传播

  • 不是普通 NLP 里的"单词→单词",而是
    当前帧的像素查询 ← → 记忆库中过去帧的像素特征

把符号代入视频分割场景,就能一眼看懂那两行公式在干什么。


  1. 张量含义(单头视角,h 代表"每头")

    • q_h :当前帧 待解码的 9 个像素 token(可以是某个 prompt 点对应的窗口特征,也可以是整张特征图展平后采样得到的 9 个向量)。

    • k_h :记忆库里的 4096 个记忆位置(过去帧+当前帧已编码的高置信度像素特征,带空间-时间坐标)。

    • v_h :与 k_h 一一对应的 16 维掩码嵌入(memory value),后续要拿来加权求和,生成当前像素的掩码 logits。


  1. 打分 → softmax attn = softmax(q_h @ k_h.T / √d_k)

    "当前 9 个像素" 分别对 "记忆库里 4096 个历史像素" 算相似度,再归一化。

    结果每行是一个 4096 维的注意力权重 ,告诉模型:

    "为了知道眼前这个像素是不是'狗',我最该看记忆库里的哪几个历史像素?"


  1. 加权求和 out_h = attn @ v_h

    → 用刚刚的权重,把记忆库里的 4096 条掩码嵌入 做加权平均,得到 9 条 16 维的聚合嵌入

    这 16 维向量已经融合了跨帧、跨视角的物体一致性信息 ,后续再喂给轻量级解码器就能直接吐出当前 9 个像素的前景/背景概率实例掩码 logits


一句话场景化总结

"当前像素问:'我以前长什么样?'

记忆库回答:'你跟我这些历史像素最像,按相似程度给你加权平均一下以前的掩码特征,你就知道自己该属于哪一类了。 '"

那两行公式就是**"找最像的历史像素 → 把它们的掩码特征借过来"** 的数学实现,保证视频里物体 mask 在帧间时序一致、ID 不跳变

但现在只是第一帧图像预测

第一帧就用"记忆-查询"式注意力,看起来像是"自己查自己",但这一步仍然有意义,而且并不浪费。把公式放到 SAM 2 第一帧的上下文里,可以这样理解:

  1. 记忆库 ≠ 下一帧,而是同一帧内部

    在 SAM 2 的 encoder 端,输入图像先被拆成 16×16 的窗口 token(外加绝对位置编码),这些 token 同时扮演 kv ;而 q 则是更稀疏的 prompt token (用户给的 1 个点或 9 个框采样点对应的特征)。

    因此

    q_h:9 个 prompt token

    k_h / v_h:4096 个图像 patch token

    注意力就是"prompt 点应该去看图像的哪些区域 "------本质上是一次空间非局部聚合,而不是时序聚合。

  2. softmax 打分 → 拿到"空间关联图"

    attn = softmax(q_h @ k_h.T / √d_k)

    这一步得到 9×4096 的权重:每个 prompt 点对全图所有 patch 的关注程度。

    第一帧没有"历史",但全图 patch 本身就是最丰富的上下文 ;让 prompt 点直接"看到"整幅图,就能迅速把局部提示传播到全局目标区域,解决"一个点太小,边界在哪"的问题。

  3. 加权求和 → 生成"全局感知的 prompt 嵌入"

    out_h = attn @ v_h

    把权重乘回图像特征,得到 9 个已经融合了整幅图语义 的 16 维向量。

    这些向量再送进解码器,就能回归出精细的实例掩码,而不会出现"只给一个点就只能猜出一个小圆" 的局部歧义。

一句话总结

第一帧的"记忆库"就是当前帧自身 的全部 patch;注意力做的事是

"用稀疏提示去查询整张图,把提示特征扩散成完整的物体表示 "------

所以它依旧是那两行公式,只是查询范围从'跨帧历史'变成了'同帧全局' ,目的都是把局部线索扩展成完整目标

4.1.2.4 自注意力和交叉注意力有什么区别?为什么要先自注意力?

把"自注意力"和"交叉注意力"放到 SAM 2 的 TwoWayAttentionBlock 里,可以一句话区分:

  • 自注意力 :同一组 token 内部"互相聊天"------稀疏点之间交换信息。

  • 交叉注意力 :两组不同 token"跨组对话"------稀疏点 ⇄ 整张图交换信息。


1. 为什么先跑自注意力?

  1. 稀疏提示点先自我对齐

    用户可能一次给了 1 个点、3 个点或 9 个框,这些提示彼此独立喂进网络。自注意力让它们先"通气":

    • 点 A 知道点 B 也在物体上,于是更确信"这是一整个物体"而非孤立边缘。

    • 点 C 落在背景,通过自注意力被其他前景点"投票"抑制,减少误报。

      结果:每个 query token 的 256 维向量已经携带了多提示共识 ,再拿去和图像特征交叉时,查询信号更稳定、噪声更小

  2. 保持排列等变性 (permutation-equivariant)

    自注意力对输入顺序不敏感,保证"交换两个提示点顺序"输出不变,符合提示点的集合特性。

  3. 工程经验:先内聚、再外联

    几乎所有 Transformer 块都遵循
    "Self-Attn → Cross-Attn → FFN" 三部曲:

    • Self:把自己人说明白;

    • Cross:再出去谈合作;

    • FFN:最后各自消化。

      这样梯度、收敛、精度都更友好。


2. 自注意力 vs. 交叉注意力 对照表

维度 自注意力 交叉注意力
Q/K/V 来源 同一组 token 不同组 token
输入形状 [B, 9, 256] Q: [B, 9, 256] K/V: [B, 4096, 256]
作用范围 稀疏点之间 稀疏点 ↔ 图像块
目的 提示点内部共识 把提示扩散到全图 / 把图特征回传提示
参数量 一套投影矩阵 两套投影矩阵(q_proj 与 kv_proj 分开)
计算复杂度 O(9²) 可忽略 O(9×4096) 主导

一句话总结

先自注意力 = 让"点"先达成内部共识 ,再代表整个物体去和整张图谈判;

否则每个点各说各话,交叉注意力会被噪声带偏,分割边缘就飘了。

4.1.2.5 为什么q和k都要加个绝对位置编码,然后v没加而直接就是keys?

q = queries + query_pe # 给 query 加 PE

q: torch.Size([1, 9, 256])

k = keys + key_pe # 给 key 加 PE

k: torch.Size([1, 4096, 256])

q: torch.Size([1, 9, 256])

k: torch.Size([1, 4096, 256])

keys: torch.Size([1, 4096, 256])

attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) # 下采样在内部完成

attn_out: torch.Size([1, 9, 256])

为什么q和k都要加个绝对位置编码,然后v没加而直接就是keys?

一句话:
Q 和 K 需要"位置"来对齐"谁问谁",V 只负责"给答案"不需要位置。


  1. 注意力打分靠 Q·Kᵀ

    相似度 scoreᵢⱼ = (qᵢ + peᵢ) · (kⱼ + peⱼ)

    这里 pe 让网络知道

    "第 i 个提示点" 与 "第 j 个图像块" 在图像坐标上是否对应

    去掉 pe,所有向量就只剩"外观"没有"坐标",点乘结果会乱掉------提示点可能去 attend 很远但外观相似的区域


  1. V 只是"被加权取平均"的原材料

    attn_out = Σⱼ αᵢⱼ · vⱼ

    权重 αᵢⱼ 已经由 带位置的 Q、K 算好,vⱼ 用纯外观特征即可;

    再把 pe 加进 V 只会把位置信息重复注入,反而让网络难以区分"外观"与"位置",实验上通常掉点。


  1. 工程惯例

    Transformer 原始论文、DETR、SAM/SAM2 都沿用
    "Q、K 加位置,V 不加" 的套路,已成为视觉任务默认配置。


记忆口诀
"打分需要地址,送货只看内容。"

Q、K 带地址(pe)才能寄对快递;V 只管把货(特征)搬过来,地址早由权重 α 指定好了。

4.1.2.6 MLP.forward

sam2/modeling/sam2_utils.py

TwoWayAttentionBlock.forward里面调用了

mlp_out = self.mlp(queries)

python 复制代码
# Lightly adapted from
# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa


class MLP(nn.Module):
    """
    经典多层感知机(MLP):
    - 支持任意层数
    - 最后一层不加激活
    - 可选 sigmoid 输出
    常用于 Transformer 中的 FFN 子模块。
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        output_dim: int,
        num_layers: int,
        activation: nn.Module = nn.ReLU,
        sigmoid_output: bool = False,
    ) -> None:
        super().__init__()
        self.num_layers = num_layers
        # 构造隐藏层维度列表:中间层全部用 hidden_dim
        h = [hidden_dim] * (num_layers - 1)
        # 顺序拼接 Linear:输入 → 隐藏 → ... → 输出
        self.layers = nn.ModuleList(
            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
        )
        self.sigmoid_output = sigmoid_output  # 是否对最后一层加 sigmoid
        self.act = activation()               # 实例化激活函数

    def forward(self, x):
        # x: torch.Size([1, 9, 256])

        # 逐层前向:除最后一层外均接激活
        for i, layer in enumerate(self.layers):
            x = self.act(layer(x)) if i < self.num_layers - 1 else layer(x)

            # i=0 x: torch.Size([1, 9, 2048])   # 第一层升维
            # i=1 x: torch.Size([1, 9, 256])    # 第二层降回原维(残差分支用)

        # self.sigmoid_output: False
        if self.sigmoid_output:
            x = F.sigmoid(x)  # 若需要 0~1 范围则再套 sigmoid

        # x: torch.Size([1, 9, 256])
        return x
  1. 一个可复用的 MLP 积木,通常作为 Transformer 块里的 FFN(Feed-Forward Network)。

  2. 默认 2 层:先升维到 2048,再降回 256,配合残差连接,给模型增加非线性且保持通道维度一致。

  3. sigmoid_output 开关方便在需要概率输出(如 mask logits 后处理)时直接得到 0~1 值。

相关推荐
古城小栈4 小时前
AI + 区块链:去中心化智能的未来形态
人工智能·去中心化·区块链
心疼你的一切4 小时前
自然语言处理_NLP与Transformer架构
人工智能·深度学习·目标检测·机器学习·计算机视觉·自然语言处理·transformer
望外追晚4 小时前
mask_color_map.json丢失,导致分割标签.png无法导入X-Anylabeling的解决办法
人工智能·计算机视觉·json·paddlepaddle
沫儿笙4 小时前
安川YASKAWA焊接机器人管材焊接节气
人工智能·机器人
五月君_4 小时前
Node.js 企业级框架 Egg 4.0 发布:原生支持 AI 开发,架构全面革新
人工智能·架构·node.js
Java后端的Ai之路4 小时前
【分析式AI】-机器学习的分类以及学派
人工智能·机器学习·分类·aigc·分析式ai
飞哥数智坊4 小时前
Cursor 可视化编辑器实测:前端效率新利器,但仍需完善
人工智能·ai编程·cursor
海棠AI实验室4 小时前
从“会出图”到“能交付”:用 ChatGPT + Nano Banana/Midjourney 做一套现代高校图书馆方案
人工智能·chatgpt·midjourney·图书馆
Baihai_IDP4 小时前
对长上下文能力有不同要求,怎么选择合适的模型?
人工智能·面试·llm