[Token]ALGM: 基于自适应局部-全局token合并的简单视觉Transformer用于高效语义分割, CVPR2024

ALGM: Adaptive Local-then-Global Token Merging for Efficient Semantic Segmentation with Plain Vision Transformers
paper|code

Background & Motivation

具有高余弦相似度的token可以合并,而不会降低分割质量。

  1. CTS表明,在早期网络阶段进行局部token共享 可以提高效率,而不会影响分割质量,但它需要一个预处理网络 。 因此,我们的第一个目标是在网络浅层合并冗余符元,而无需预处理,同时保持分割质量。
  2. 像ToMe这样的token合并方法表明,逐渐合并整张图像 上的冗余token可以大大提高效率,但全局范围内合并损害分割质量。 因此,我们的第二个目标是应用全局token合并以进一步提高效率,同时不会损害分割质量。

Challenge

如何创造一个新方法,既能像CTS一样在早期就合并局部Token,又能像ToMe一样在全局范围内高效合并,同时没有额外的预处理网络,不损害分割质量。

沿用余弦相似度的标准,发现随着模型加深:

  1. 在早期,它足以在局部区分开不同物体
  2. 在后期,它能在全局上更清晰地区分不同物体

Method

基于这些发现,提出了Adaptive Local-then-Global Merging (ALGM) module,该模块集成了两个token合并阶段。在第一网络层中,ALGM 采用局部合并策略。 在中间层采用全局合并机制,以减少全局token冗余。 此外,不预设token的合并数量,而是根据图像内容的语义复杂度动态决定合并token的数量。

Token相似度分析

在何种情况下以及何时,余弦相似度能够成为一种有效的指标,用于识别代表同一类别的标记,从而使其适合进行局部和全局合并。

提取并比较了分词器生成的token与在 ADE20K训练集中训练的 ViT-S的相似性。

(1)首先,分析了第一层转置前向层中k×k窗口内的局部相似性。如图 2a 所示,窗口大小 k 越小,余弦相似度就越能准确地反映token属于同一类别。因此,在第一层中,在小局部窗口内具有高余弦相似度的token很可能可以合并,而不会导致分割质量下降。

(2)计算整张图像中所有 Transformer 层的类别间和类别内token的余弦相似度来分析全局相似度。如图 2b 所示,早期层中的全局相似度并不能准确反映类别对应关系,因此不应将其用于识别需要合并的token。然而,在网络更深的部分,余弦相似度成为一种更好的衡量标准,可以用于在全局范围内识别可以合并的标记,而不会影响分割质量。

Adaptive Local-then-Global Merging(ALGM)

(a)早期层中的局部token相似性以及(b)中间层中的全局token相似性很可能是衡量token合并能力的指标。提出自适应局部-然后全局合并(ALGM)方法。首先在第一层使用条件局部平均池化(CLAP)模块进行局部合并。在中间层,采用基于 BSM算法的全局二分合并(GBM)模块进行全局合并。整个过程以一个token解合并模块结束,以恢复原始的token解析。

Local token merging.

如果一个Token和它在一个小窗口内的邻居们高度相似,就将它们合并。CLAP模块,它被放置在第一层(L1)的MHSA和MLP模块之间,用来实现这个功能。
Step 1.

它接收来自第一层(L1)的Token T'1,并将其重新排列成一个空间网格 T'G1。然后,定义k×k大小的窗口,并将每个窗口内的Tokens分组到不同的集合W中。
Step 2.

计算小组内所有Token之间的余弦相似度,并求出这些相似度的平均值μw。然后,根据相似度代表可合并性的假设,CLAP模块只合并那些平均相似度 μ w μw μw大于阈值 τ τ τ的窗口。
Step 3.

被选中的窗口 w w w内的所有Tokens,通过计算它们的平均值,合并成一个Token。这些被合并的Tokens的原始索引也会被存储起来,以备后续的"解合并"(unmerging)操作。完成后,合并产生的新Token和那些未被合并的Token被连接在一起,生成最终的输出,其数量小于或等于原始数量。

Global token merging.


Step 1.

token分组与图构建,分成两组,构造二分图
Step 2.

寻找最佳匹配,找到唯一一个最合适的合并对象。
Step 3.

应用相似度阈值,保证足够相似的token对才被允许进入最后的合并阶段。
Step 4.

对于所有经过前两轮筛选后仍然保留下来的边,其连接的token对将被合并,并且存储索引 。所有未参与合并的token和那些合并后更新了的token被拼接在一起,形成一个新的、数量更少的token集合,作为下一层的输入。

Token unmerging.

利用合并时记录的索引信息,通过"复制粘贴"的方式,将被合并的Token还原到其原始位置,从而恢复出与输入图像同样尺寸的特征表示。

这个过程的执行时机取决于下游的解码器:

如果解码器是Transformer(不怕乱序),就先解码,后还原,效率更高。

如果解码器是CNN(要求整齐),就先还原,后解码,以满足其输入要求。

Adaptive token merging.

在训练之前,使用想要应用 ALGM 的基础分割模型,并在训练集中进行对比测试。然后,在每一层 Ll 中提取 MHSA 块之后的token,计算所有token对之间的余弦相似度,并计算整个训练集的平均相似度 µ s i m µsim µsim 和标准差 σ s i m σsim σsim。根据这些统计数据,设置阈值 τ = µ s i m + σ s i m τ = µsim + σsim τ=µsim+σsim。使用此阈值,经过 CLAP 和 GBM 模块后的剩余令牌数量 N' 和 N'' 会因图像而异。在训练过程中,为了便于对图像和标记进行分组处理,确定每次分组的最大剩余token数量 N' 和 N'',然后将这些数值应用于该批次中的所有图像。

算法实现

local_merge.py

code

python 复制代码
import math
from typing import Callable, Tuple

import torch
import torch.nn
import torch.nn.functional as F
from einops import rearrange
import numpy as np


def conditional_pooling(
    feat: torch.Tensor,
    threshold:float,
    window_size: Tuple[int, int],
) -> Tuple[Callable, Callable]:
    
    with torch.no_grad():
        
        ws_h, ws_w = int(window_size[0]), int(window_size[1])
        stride_h, stride_w = ws_h, ws_w
        num_token_window = stride_h * stride_w
        
        x_cls, feat = feat[:, :1, :], feat[:, 1:, :]
        B, N, D = feat.size()
        base_grid_H = int(math.sqrt(N))
        base_grid_W = base_grid_H
        assert base_grid_H * base_grid_W == N and base_grid_H % ws_h == 0 and base_grid_W % ws_w == 0

        feat = rearrange(feat, "b (h w) c -> b c h w", h=base_grid_H)
    
        feat = rearrange(feat, 'b c (gh ps_h) (gw ps_w) -> b gh gw c ps_h ps_w', gh=base_grid_H//ws_h, gw=base_grid_W//ws_w)
        b, gh, gw, c, ps_h, ps_w = feat.shape

        # Flatten mxm window for pairwise operations
        tensor_flattened = feat.reshape(b, gh, gw, c, -1)
    
        # Expand dims for pairwise operations
        tensor_1 = tensor_flattened.unsqueeze(-1)
        tensor_2 = tensor_flattened.unsqueeze(-2)

        # Compute cosine similarities
        sims = F.cosine_similarity(tensor_1, tensor_2, dim=3)

        # Exclude the self-similarity (i.e., similarity with oneself will be 1)
        sims_mask = 1 - torch.eye(ps_h * ps_w).to(sims.device)
        sims = sims * sims_mask

        # Average similarities (excluding the self-similarity)
        similarity_map = sims.sum(-1).sum(-1) / ((ps_h * ps_w) * (ps_h * ps_w - 1))
            
        similarity_map = rearrange(similarity_map.unsqueeze(1), 'b c h w-> b (c h w)')
        
        #--- adaptive section ---#
     
        n_B, n_H = similarity_map.shape
        node_mean = torch.tensor(threshold).cuda(sims.device)
        node_mean=node_mean.repeat(1,n_H)
        r = torch.ge(similarity_map, node_mean).sum(dim=1).min()
        # -------------# 
    
        #   get top k similar super patches 
        _, sim_super_patch_idxs = similarity_map.topk(r,dim=-1)
    
        # --- creating the mergabel and unmergable super  pathes
        tensor = torch.arange(base_grid_H * base_grid_W).reshape(base_grid_H, base_grid_W).to(feat.device)

        # Repeat the tensor to create a batch of size 2
        tensor = tensor.unsqueeze(0).repeat(B, 1, 1)
        

        # Apply unfold operation on last two dimensions to create the sliding window
        windowed_tensor = tensor.unfold(1, ws_h, stride_h).unfold(2, ws_w, stride_w)

        # Reshape the tensor to the desired shape 
        windowed_tensor = windowed_tensor.reshape(B, -1, num_token_window)
    
        # Use torch.gather to collect the desired elements
        gathered_tensor = torch.gather(windowed_tensor, 1, sim_super_patch_idxs.unsqueeze(-1).expand(-1, -1, num_token_window))

        # Create a mask for all indices, for each batch
        mask = torch.ones((B, windowed_tensor.shape[1]), dtype=bool).to(feat.device)

        # Create a tensor that matches the shape of indices and fill it with False
        mask_values = torch.zeros_like(sim_super_patch_idxs, dtype=torch.bool).to(feat.device)

        # Use scatter_ to update the mask. This will set mask[b, indices[b]] = False for all b
        mask.scatter_(1, sim_super_patch_idxs, mask_values)

        # Get the remaining tensor
        remaining_tensor = windowed_tensor[mask.unsqueeze(-1).expand(-1, -1, num_token_window)].reshape(B, -1, num_token_window)
        unm_idx = remaining_tensor.reshape(B, -1).sort(dim=-1).values.unsqueeze(-1)
        dim_index = (num_token_window)- 1 
        src_idx= gathered_tensor[:, :, :dim_index].reshape(B, -1).unsqueeze(-1)
        dst_idx= gathered_tensor[:, :, dim_index].reshape(B, -1).unsqueeze(-1)
        merge_idx = torch.arange(src_idx.shape[1]//dim_index).repeat_interleave(dim_index).repeat(B, 1).unsqueeze(-1).to(feat.device)


    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
       # TODO: num_token_window can be undefined
       
        x_cls , x_feat =  x[:, :1, :], x[:, 1:, :]
        n, t1, c = x_feat.shape
        src = x_feat.gather(dim=-2, index=src_idx.expand(n, r*dim_index, c))
        dst = x_feat.gather(dim=-2, index=dst_idx.expand(n, r, c))
        unm = x_feat.gather(dim=-2, index=unm_idx.expand(n, t1 - (r*num_token_window), c))
        dst = dst.scatter_reduce(-2, merge_idx.expand(n,r*dim_index, c), src, reduce=mode)
        x = torch.cat([dst, unm], dim=1)
        x = torch.cat((x_cls, x), dim=1)
        return x

    return merge

def merge_wavg(
    merge: Callable, x: torch.Tensor, size: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    
    if size is None:
        size = torch.ones_like(x[..., 0, None])

    x = merge(x * size, mode="sum")
    size = merge(size, mode="sum")    
    x = x / size
    
    return x, size


def merge_source(
    merge: Callable, x: torch.Tensor, source: torch.Tensor = None
) -> torch.Tensor:

    if source is None:
        n, t, _ = x.shape
        source = torch.eye(t, device=x.device)[None, ...].expand(n, t, t)

    source = merge(source, mode="amax")
    return source

global_merge.py

code

python 复制代码
import math
from typing import Callable, Tuple
import torch


def do_nothing(x, mode=None):
    return x


def turbo_matching(
    metric: torch.Tensor,
    layer_idx:int,
    source: torch.Tensor,
    class_token: bool = False,
    distill_token: bool = False,
) -> Tuple[Callable, Callable]:
    
    
    protected = 0
    if class_token:
        protected += 1
    if distill_token:
        protected += 1

    t = metric.shape[1]
    r = (t - protected) // 2

    if r <= 0:
        return do_nothing, do_nothing

    with torch.no_grad():

        B,m_t,um_t = source.shape
        metric = metric / metric.norm(dim=-1, keepdim=True)
        a, b = metric[..., ::2, :], metric[..., 1::2, :]
        scores = a @ b.transpose(-1, -2)
       
        if class_token:
            scores[..., 0, :] = -math.inf
        if distill_token:
            scores[..., :, 0] = -math.inf


        node_max, node_idx = scores.max(dim=-1)
        edge_idx = node_max.argsort(dim=-1, descending=True)[..., None]

        # ------------------ start  addaptive section --------- 
        i = layer_idx
        n_B, n_H = node_max.shape
        node_mean= torch.add(node_max[:,1:].mean(dim=1).mean(),node_max[:,1:].std(dim=1).mean()/i)
        node_mean=node_mean.repeat(1,n_H)
        r = torch.ge(node_max, node_mean).sum(dim=1).min()

        # ------------------ end addaptive section --------- 
        
        unm_idx = edge_idx[..., r:, :]  # Unmerged Tokens
        src_idx = edge_idx[..., :r, :]  # Merged Tokens
        dst_idx = node_idx[..., None].gather(dim=-2, index=src_idx)

        if class_token:
            # Sort to ensure the class token is at the start
            unm_idx = unm_idx.sort(dim=1)[0]


    def merge(x: torch.Tensor, mode="mean") -> torch.Tensor:
        src, dst = x[..., ::2, :], x[..., 1::2, :]
        n, t1, c = src.shape
        unm = src.gather(dim=-2, index=unm_idx.expand(n, t1 - r, c))
        src = src.gather(dim=-2, index=src_idx.expand(n, r, c))
        dst = dst.scatter_reduce(-2, dst_idx.expand(n, r, c), src, reduce=mode)

        if distill_token:
            return torch.cat([unm[:, :1], dst[:, :1], unm[:, 1:], dst[:, 1:]], dim=1)
        else:
            return torch.cat([unm, dst], dim=1)

    return merge

Apply ALGM between the attention and mlp blocks

code

python 复制代码
class TurboBlock(Block):
    """
    Modifications:
     - Apply ALGM between the attention and mlp blocks
    """

    def _drop_path1(self, x):
        return self.drop_path1(x) if hasattr(self, "drop_path1") else self.drop_path(x)

    def _drop_path2(self, x):
        return self.drop_path2(x) if hasattr(self, "drop_path2") else self.drop_path(x)

    def forward(self, x: torch.Tensor ) -> torch.Tensor:
      
        attn_size = self._turbo_info["size"] if self._turbo_info["prop_attn"] else None
        x_attn, metric  = self.attn(self.norm1(x),attn_size)
        x =  x + self._drop_path1(x_attn)
        layer_idx = self._turbo_info["selected_layers"].pop(0)
           
        if self._turbo_info["source"] is None: # if layer_idx == 1:
                
                merge  = conditional_pooling(
                    x,
                    self._turbo_info["threshold"],
                    self._turbo_info["window_size"],
                )
                if self._turbo_info["trace_source"]:
                        self._turbo_info["source"] = merge_source(
                            merge, x, self._turbo_info["source"]
                        )
                x, self._turbo_info["size"] = merge_wavg(merge, x, self._turbo_info["size"])
                
              
        else:
              
                merge = turbo_matching(
                    x,
                    layer_idx,
                    self._turbo_info["source"],
                    self._turbo_info["class_token"],
                    self._turbo_info["distill_token"],
                )
                if self._turbo_info["trace_source"]:
                    self._turbo_info["source"] = merge_source(
                        merge, x, self._turbo_info["source"]
                    )
                x, self._turbo_info["size"] = merge_wavg(merge, x, self._turbo_info["size"])
           
        
        x = x + self._drop_path2(self.mlp(self.norm2(x)))
        
        return x 

实验结果


Inspire

  1. local划分、合并的策略是否在low-level像素级任务上是有效的,替代window attention(复杂度)
相关推荐
xingxing_F12 分钟前
Topaz Video AI for Mac AI视频无损放大 视频画质增强
人工智能·macos·音视频
普蓝机器人22 分钟前
面向智慧农业的自主移动果蔬采摘机器人:融合视觉识别与自动驾驶的智能化农作系统研究
人工智能·学习·机器人·移动机器人·三维仿真导航
卷福同学24 分钟前
AI浏览器comet拉新,一单20美元(附详细教程)
人工智能·后端
说私域1 小时前
基于开源链动2+1模式AI智能名片S2B2C商城小程序的市场份额扩张路径研究
人工智能·小程序·开源
文火冰糖的硅基工坊1 小时前
[人工智能-大模型-72]:模型层技术 - 模型训练六大步:①数据预处理 - 基本功能与对应的基本组成函数
开发语言·人工智能·python
东经116度2 小时前
权重初始化方法详解
深度学习·机器学习·xavier初始化·全零初始化·随机初始化·he初始化
晚霞apple2 小时前
三维重建技术的未来创新方向
论文阅读·人工智能·深度学习·神经网络·机器学习
NocoBase2 小时前
GitHub 上最值得关注的 14 个开源 AI 低代码工具
人工智能·低代码·github
无风听海2 小时前
神经网络之语义空间
人工智能·深度学习·神经网络