[Token]ALGM: 基于自适应局部-全局token合并的简单视觉Transformer用于高效语义分割, CVPR2024

ALGM: Adaptive Local-then-Global Token Merging for Efficient Semantic Segmentation with Plain Vision Transformers
paper|code

Background & Motivation

具有高余弦相似度的token可以合并,而不会降低分割质量。

  1. CTS表明,在早期网络阶段进行局部token共享 可以提高效率,而不会影响分割质量,但它需要一个预处理网络 。 因此,我们的第一个目标是在网络浅层合并冗余符元,而无需预处理,同时保持分割质量。
  2. 像ToMe这样的token合并方法表明,逐渐合并整张图像 上的冗余token可以大大提高效率,但全局范围内合并损害分割质量。 因此,我们的第二个目标是应用全局token合并以进一步提高效率,同时不会损害分割质量。

Challenge

如何创造一个新方法,既能像CTS一样在早期就合并局部Token,又能像ToMe一样在全局范围内高效合并,同时没有额外的预处理网络,不损害分割质量。

沿用余弦相似度的标准,发现随着模型加深:

  1. 在早期,它足以在局部区分开不同物体
  2. 在后期,它能在全局上更清晰地区分不同物体

Method

基于这些发现,提出了Adaptive Local-then-Global Merging (ALGM) module,该模块集成了两个token合并阶段。在第一网络层中,ALGM 采用局部合并策略。 在中间层采用全局合并机制,以减少全局token冗余。 此外,不预设token的合并数量,而是根据图像内容的语义复杂度动态决定合并token的数量。

Token相似度分析

在何种情况下以及何时,余弦相似度能够成为一种有效的指标,用于识别代表同一类别的标记,从而使其适合进行局部和全局合并。

提取并比较了分词器生成的token与在 ADE20K训练集中训练的 ViT-S的相似性。

(1)首先,分析了第一层转置前向层中k×k窗口内的局部相似性。如图 2a 所示,窗口大小 k 越小,余弦相似度就越能准确地反映token属于同一类别。因此,在第一层中,在小局部窗口内具有高余弦相似度的token很可能可以合并,而不会导致分割质量下降。

(2)计算整张图像中所有 Transformer 层的类别间和类别内token的余弦相似度来分析全局相似度。如图 2b 所示,早期层中的全局相似度并不能准确反映类别对应关系,因此不应将其用于识别需要合并的token。然而,在网络更深的部分,余弦相似度成为一种更好的衡量标准,可以用于在全局范围内识别可以合并的标记,而不会影响分割质量。

Adaptive Local-then-Global Merging(ALGM)

(a)早期层中的局部token相似性以及(b)中间层中的全局token相似性很可能是衡量token合并能力的指标。提出自适应局部-然后全局合并(ALGM)方法。首先在第一层使用条件局部平均池化(CLAP)模块进行局部合并。在中间层,采用基于 BSM算法的全局二分合并(GBM)模块进行全局合并。整个过程以一个token解合并模块结束,以恢复原始的token解析。

Local token merging.

如果一个Token和它在一个小窗口内的邻居们高度相似,就将它们合并。CLAP模块,它被放置在第一层(L1)的MHSA和MLP模块之间,用来实现这个功能。
Step 1.

它接收来自第一层(L1)的Token T'1,并将其重新排列成一个空间网格 T'G1。然后,定义k×k大小的窗口,并将每个窗口内的Tokens分组到不同的集合W中。
Step 2.

计算小组内所有Token之间的余弦相似度,并求出这些相似度的平均值μw。然后,根据相似度代表可合并性的假设,CLAP模块只合并那些平均相似度 μ w μw μw大于阈值 τ τ τ的窗口。
Step 3.

被选中的窗口 w w w内的所有Tokens,通过计算它们的平均值,合并成一个Token。这些被合并的Tokens的原始索引也会被存储起来,以备后续的"解合并"(unmerging)操作。完成后,合并产生的新Token和那些未被合并的Token被连接在一起,生成最终的输出,其数量小于或等于原始数量。

Global token merging.


Step 1.

token分组与图构建,分成两组,构造二分图
Step 2.

寻找最佳匹配,找到唯一一个最合适的合并对象。
Step 3.

应用相似度阈值,保证足够相似的token对才被允许进入最后的合并阶段。
Step 4.

对于所有经过前两轮筛选后仍然保留下来的边,其连接的token对将被合并,并且存储索引 。所有未参与合并的token和那些合并后更新了的token被拼接在一起,形成一个新的、数量更少的token集合,作为下一层的输入。

Token unmerging.

利用合并时记录的索引信息,通过"复制粘贴"的方式,将被合并的Token还原到其原始位置,从而恢复出与输入图像同样尺寸的特征表示。

这个过程的执行时机取决于下游的解码器:

如果解码器是Transformer(不怕乱序),就先解码,后还原,效率更高。

如果解码器是CNN(要求整齐),就先还原,后解码,以满足其输入要求。

Adaptive token merging.

在训练之前,使用想要应用 ALGM 的基础分割模型,并在训练集中进行对比测试。然后,在每一层 Ll 中提取 MHSA 块之后的token,计算所有token对之间的余弦相似度,并计算整个训练集的平均相似度 µ s i m µsim µsim 和标准差 σ s i m σsim σsim。根据这些统计数据,设置阈值 τ = µ s i m + σ s i m τ = µsim + σsim τ=µsim+σsim。使用此阈值,经过 CLAP 和 GBM 模块后的剩余令牌数量 N' 和 N'' 会因图像而异。在训练过程中,为了便于对图像和标记进行分组处理,确定每次分组的最大剩余token数量 N' 和 N'',然后将这些数值应用于该批次中的所有图像。

算法实现

local_merge.py

code

python 复制代码
import math
from typing import Callable, Tuple

import torch
import torch.nn
import torch.nn.functional as F
from einops import rearrange
import numpy as np


def conditional_pooling(
    feat: torch.Tensor,
    threshold:float,
    window_size: Tuple[int, int],
) -> Tuple[Callable, Callable]:
    
    with torch.no_grad():
        
        ws_h, ws_w = int(window_size[0]), int(window_size[1])
        stride_h, stride_w = ws_h, ws_w
        num_token_window = stride_h * stride_w
        
        x_cls, feat = feat[:, :1, :], feat[:, 1:, :]
        B, N, D = feat.size()
        base_grid_H = int(math.sqrt(N))
        base_grid_W = base_grid_H
        assert base_grid_H * base_grid_W == N and base_grid_H % ws_h == 0 and base_grid_W % ws_w == 0

        feat = rearrange(feat, "b (h w) c -> b c h w", h=base_grid_H)
    
        feat = rearrange(feat, 'b c (gh ps_h) (gw ps_w) -> b gh gw c ps_h ps_w', gh=base_grid_H//ws_h, gw=base_grid_W//ws_w)
        b, gh, gw, c, ps_h, ps_w = feat.shape

        # Flatten mxm window for pairwise operations
        tensor_flattened = feat.reshape(b, gh, gw, c, -1)
    
        # Expand dims for pairwise operations
        tensor_1 = tensor_flattened.unsqueeze(-1)
        tensor_2 = tensor_flattened.unsqueeze(-2)

        # Compute cosine similarities
        sims = F.cosine_similarity(tensor_1, tensor_2, dim=3)

        # Exclude the self-similarity (i.e., similarity with oneself will be 1)
        sims_mask = 1 - torch.eye(ps_h * ps_w).to(sims.device)
        sims = sims * sims_mask

        # Average similarities (excluding the self-similarity)
        similarity_map = sims.sum(-1).sum(-1) / ((ps_h * ps_w) * (ps_h * ps_w - 1))
            
        similarity_map = rearrange(similarity_map.unsqueeze(1), 'b c h w-> b (c h w)')
        
        #--- adaptive section ---#
     
        n_B, n_H = similarity_map.shape
        node_mean = torch.tensor(threshold).cuda(sims.device)
        node_mean=node_mean.repeat(1,n_H)
        r = torch.ge(similarity_map, node_mean).sum(dim=1).min()
        # -------------# 
    
        #   get top k similar super patches 
        _, sim_super_patch_idxs = similarity_map.topk(r,dim=-1)
    
        # --- creating the mergabel and unmergable super  pathes
        tensor = torch.arange(base_grid_H * base_grid_W).reshape(base_grid_H, base_grid_W).to(feat.device)

        # Repeat the tensor to create a batch of size 2
        tensor = tensor.unsqueeze(0).repeat(B, 1, 1)
        

        # Apply unfold operation on last two dimensions to create the sliding window
        windowed_tensor = tensor.unfold(1, ws_h, stride_h).unfold(2, ws_w, stride_w)

        # Reshape the tensor to the desired shape 
        windowed_tensor = windowed_tensor.reshape(B, -1, num_token_window)
    
        # Use torch.gather to collect the desired elements
        gathered_tensor = torch.gather(windowed_tensor, 1, sim_super_patch_idxs.unsqueeze(-1).expand(-1, -1, num_token_window))

        # Create a mask for all indices, for each batch
        mask = torch.ones((B, windowed_tensor.shape[1]), dtype=bool).to(feat.device)

        # Create a tensor that matches the shape of indices and fill it with False
        mask_values = torch.zeros_like(sim_super_patch_idxs, dtype=torch.bool).to(feat.device)

        # Use scatter_ to update the mask. This will set mask[b, indices[b]] = False for all b
        mask.scatter_(1, sim_super_patch_idxs, mask_values)

        # Get the remaining tensor
        remaining_tensor = windowed_tensor[mask.unsqueeze(-1).expand(-1, -1, num_token_window)].reshape(B, -1, num_token_window)
        unm_idx = remaining_tensor.reshape(B, -1).sort(dim=-1).values.unsqueeze(-1)
        dim_index = (num_token_window)- 1 
        src_idx= gathered_tensor[:, :, :dim_index].reshape(B, -1).unsqueeze(-1)
        dst_idx= gathered_tensor[:, :, dim_index].reshape(B, -1).unsqueeze(-1)
        merge_idx = torch.arange(src_idx.shape[1]//dim_index).repeat_interleave(dim_index).repeat(B, 1).unsqueeze(-1).to(feat.device)


    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
       # TODO: num_token_window can be undefined
       
        x_cls , x_feat =  x[:, :1, :], x[:, 1:, :]
        n, t1, c = x_feat.shape
        src = x_feat.gather(dim=-2, index=src_idx.expand(n, r*dim_index, c))
        dst = x_feat.gather(dim=-2, index=dst_idx.expand(n, r, c))
        unm = x_feat.gather(dim=-2, index=unm_idx.expand(n, t1 - (r*num_token_window), c))
        dst = dst.scatter_reduce(-2, merge_idx.expand(n,r*dim_index, c), src, reduce=mode)
        x = torch.cat([dst, unm], dim=1)
        x = torch.cat((x_cls, x), dim=1)
        return x

    return merge

def merge_wavg(
    merge: Callable, x: torch.Tensor, size: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    
    if size is None:
        size = torch.ones_like(x[..., 0, None])

    x = merge(x * size, mode="sum")
    size = merge(size, mode="sum")    
    x = x / size
    
    return x, size


def merge_source(
    merge: Callable, x: torch.Tensor, source: torch.Tensor = None
) -> torch.Tensor:

    if source is None:
        n, t, _ = x.shape
        source = torch.eye(t, device=x.device)[None, ...].expand(n, t, t)

    source = merge(source, mode="amax")
    return source

global_merge.py

code

python 复制代码
import math
from typing import Callable, Tuple
import torch


def do_nothing(x, mode=None):
    return x


def turbo_matching(
    metric: torch.Tensor,
    layer_idx:int,
    source: torch.Tensor,
    class_token: bool = False,
    distill_token: bool = False,
) -> Tuple[Callable, Callable]:
    
    
    protected = 0
    if class_token:
        protected += 1
    if distill_token:
        protected += 1

    t = metric.shape[1]
    r = (t - protected) // 2

    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():

        B,m_t,um_t = source.shape
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., ::2, :], metric[..., 1::2, :]
        scores = a @ b.transpose(-1, -2)
       
        if class_token:
            scores[..., 0, :] = -math.inf
        if distill_token:
            scores[..., :, 0] = -math.inf


        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        # ------------------ start  addaptive section --------- 
        i = layer_idx
        n_B, n_H = node_max.shape
        node_mean= torch.add(node_max[:,1:].mean(dim=1).mean(),node_max[:,1:].std(dim=1).mean()/i)
        node_mean=node_mean.repeat(1,n_H)
        r = torch.ge(node_max, node_mean).sum(dim=1).min()

        # ------------------ end addaptive section --------- 
        
        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

        if class_token:
            # Sort to ensure the class token is at the start
            unm_idx = unm_idx.sort(dim=1)[0]


    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        if distill_token:
            return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1)
        else:
            return torch.cat([unm, dst], dim=1)

    return merge

Apply ALGM between the attention and mlp blocks

code

python 复制代码
class TurboBlock(Block):
    """
    Modifications:
     - Apply ALGM between the attention and mlp blocks
    """

    def _drop_path1(self, x):
        return self.drop_path1(x) if hasattr(self, "drop_path1") else self.drop_path(x)

    def _drop_path2(self, x):
        return self.drop_path2(x) if hasattr(self, "drop_path2") else self.drop_path(x)

    def forward(self, x: torch.Tensor ) -> torch.Tensor:
      
        attn_size = self._turbo_info["size"] if self._turbo_info["prop_attn"] else None
        x_attn, metric  = self.attn(self.norm1(x),attn_size)
        x =  x + self._drop_path1(x_attn)
        layer_idx = self._turbo_info["selected_layers"].pop(0)
           
        if self._turbo_info["source"] is None: # if layer_idx == 1:
                
                merge  = conditional_pooling(
                    x,
                    self._turbo_info["threshold"],
                    self._turbo_info["window_size"],
                )
                if self._turbo_info["trace_source"]:
                        self._turbo_info["source"] = merge_source(
                            merge, x, self._turbo_info["source"]
                        )
                x, self._turbo_info["size"] = merge_wavg(merge, x, self._turbo_info["size"])
                
              
        else:
              
                merge = turbo_matching(
                    x,
                    layer_idx,
                    self._turbo_info["source"],
                    self._turbo_info["class_token"],
                    self._turbo_info["distill_token"],
                )
                if self._turbo_info["trace_source"]:
                    self._turbo_info["source"] = merge_source(
                        merge, x, self._turbo_info["source"]
                    )
                x, self._turbo_info["size"] = merge_wavg(merge, x, self._turbo_info["size"])
           
        
        x = x + self._drop_path2(self.mlp(self.norm2(x)))
        
        return x 

实验结果


Inspire

  1. local划分、合并的策略是否在low-level像素级任务上是有效的,替代window attention(复杂度)
相关推荐
007tg12 小时前
从ChatGPT家长控制功能看AI合规与技术应对策略
人工智能·chatgpt·企业数据安全
Memene摸鱼日报12 小时前
「Memene 摸鱼日报 2025.9.11」腾讯推出命令行编程工具 CodeBuddy Code, ChatGPT 开发者模式迎来 MCP 全面支持
人工智能·chatgpt·agi
linjoe9912 小时前
【Deep Learning】Ubuntu配置深度学习环境
人工智能·深度学习·ubuntu
先做个垃圾出来………13 小时前
残差连接的概念与作用
人工智能·算法·机器学习·语言模型·自然语言处理
AI小书房14 小时前
【人工智能通识专栏】第十三讲:图像处理
人工智能
fanstuck14 小时前
基于大模型的个性化推荐系统实现探索与应用
大数据·人工智能·语言模型·数据挖掘
多看书少吃饭15 小时前
基于 OpenCV 的眼球识别算法以及青光眼算法识别
人工智能·opencv·计算机视觉
一条数据库15 小时前
南京方言数据集|300小时高质量自然对话音频|专业录音棚采集|方言语音识别模型训练|情感计算研究|方言保护文化遗产数字化|语音情感识别|方言对话系统开发
人工智能·音视频·语音识别
Yingjun Mo16 小时前
1. 统计推断-基于神经网络与Langevin扩散的自适应潜变量建模与优化
人工智能·神经网络·算法·机器学习·概率论