Qwen2.5-VL - 多模态旋转位置嵌入(Multimodal Rotary Position Embedding, MRoPE)

Qwen2.5-VL - 多模态旋转位置嵌入(Multimodal Rotary Position Embedding, MRoPE)

flyfish

多模态旋转位置嵌入(Multimodal Rotary Position Embedding, MRoPE) 是 Qwen2-VL 及 Qwen2.5-VL 模型中用于处理多模态输入的关键技术,它通过扩展传统 RoPE(Rotary Position Embedding),实现了对文本、图像和视频等不同模态数据的统一位置编码。

Multimodal:多模态(指支持文本、图像、视频等多种数据类型)。

Rotary Position:旋转位置(源于旋转位置编码技术 RoPE,通过三角函数实现位置信息的编码)。

Embedding:嵌入(在深度学习中,指将位置信息映射为模型可处理的向量表示)。

传统RoPE在处理文本时只能理解一维的序列顺序,就像看一部没有时间轴的纪录片,面对视频和图像时完全无法感知时间流动和空间布局------比如两段帧率不同的视频,传统方法会把"10秒20帧"和"5秒20帧"都当作简单的帧编号序列,根本分不清动作快慢;面对图片时,也无法区分"左上角的猫"和"右下角的球"的空间位置。而MRoPE的核心创新在于给AI构建了一套"三维时空坐标系":在时间维度上,它用second_per_grid_ts参数将视频帧绑定真实时间,比如10秒20帧的视频每帧代表0.5秒,时间ID按0.5、1.0秒递增,5秒20帧的视频每帧代表0.25秒,时间ID按0.25、0.5秒递增,让AI能感知事件的真实节奏;在空间维度上,它把图像/视频帧划分为类似棋盘格的网格,用h_indexw_index生成每个块的(高度,宽度)坐标,比如2×2网格中左上角块对应(0,0),右上角对应(0,1),使模型能识别空间位置;在多模态衔接上,它确保视觉和文本的位置ID连续,比如视觉部分ID到100,文本就从101开始,让AI理解文字与画面的对应关系。

Qwen2.5的MRoPE进一步实现了绝对时间对齐,不再依赖帧编号,而是按视频实际时长计算时间ID------比如3秒的视频,不管采样3帧还是6帧,第1秒对应ID=1,第2秒对应ID=2,时间ID间隔会根据采样帧自动调整(3帧间隔1,6帧间隔0.5),代码中的time_tensor = expanded_range * second_per_grid_t * 2里,second_per_grid_t就是每秒的时间ID增量,2是放大倍数,让AI更敏感捕捉时间差异。这种设计让MRoPE在实际应用中展现出强大能力:在视频问答中,能通过时间ID精准定位"球员射门在第5秒";在图文文档解析时,给图片"左侧柱状图"标上(0,0)坐标,让文字"左侧"与之关联;在动态手势识别中,通过时间ID间隔区分"1秒1帧的缓慢挥手"和"1秒4帧的快速挥手"。

MRoPE就像AI的"时空翻译器",将视频的时间先后、图像的上下左右、文字的段落顺序,全部转化为"时间-高度-宽度"的三维坐标语言,让多模态模型不仅能"看到"信息,还能理解信息间的时空逻辑------这就好比人类看电影时能同时把握剧情的时间线、画面的空间布局和台词的前后关联。

MRoPE算法

传统RoPE

1. 旋转操作的复数表示

对于位置 m m m处的向量 x m ∈ R d x_m \in \mathbb{R}^d xm∈Rd,将其拆分为两个维度为 d / 2 d/2 d/2的子向量 x m ( 1 ) x_m^{(1)} xm(1)和 x m ( 2 ) x_m^{(2)} xm(2),RoPE的旋转操作可表示为:
RoPE ( x m , m ) = [ x m ( 1 ) cos ⁡ ( m θ ) − x m ( 2 ) sin ⁡ ( m θ ) x m ( 2 ) cos ⁡ ( m θ ) + x m ( 1 ) sin ⁡ ( m θ ) ] \text{RoPE}(x_m, m) = \begin{bmatrix} x_m^{(1)} \cos(m\theta) - x_m^{(2)} \sin(m\theta) \\ x_m^{(2)} \cos(m\theta) + x_m^{(1)} \sin(m\theta) \end{bmatrix} RoPE(xm,m)=[xm(1)cos(mθ)−xm(2)sin(mθ)xm(2)cos(mθ)+xm(1)sin(mθ)]

其中, θ = { θ 1 , θ 2 , ... , θ d / 2 } \theta = \{\theta_1, \theta_2, \ldots, \theta_{d/2}\} θ={θ1,θ2,...,θd/2}是一组可学习的频率参数。

2. 点积形式

RoPE通过旋转操作保持了位置感知的点积性质:
RoPE ( q m , m ) ⋅ RoPE ( k n , n ) = q m ⋅ k n cos ⁡ ( ( m − n ) θ ) + ( q m ⋅ k ~ n ) sin ⁡ ( ( m − n ) θ ) \text{RoPE}(q_m, m) \cdot \text{RoPE}(k_n, n) = q_m \cdot k_n \cos((m-n)\theta) + (q_m \cdot \tilde{k}_n) \sin((m-n)\theta) RoPE(qm,m)⋅RoPE(kn,n)=qm⋅kncos((m−n)θ)+(qm⋅k~n)sin((m−n)θ)

其中, k ~ n \tilde{k}_n k~n是 k n k_n kn的特定排列。

MRoPE的三维扩展公式

1. 三维位置编码分解

MRoPE将位置信息分解为时间 t t t、高度 h h h、宽度 w w w三个维度的旋转操作:
MRoPE ( x , t , h , w ) = RoPE t ( x ) ⊕ RoPE h ( x ) ⊕ RoPE w ( x ) \text{MRoPE}(x, t, h, w) = \text{RoPE}_t(x) \oplus \text{RoPE}_h(x) \oplus \text{RoPE}_w(x) MRoPE(x,t,h,w)=RoPEt(x)⊕RoPEh(x)⊕RoPEw(x)

其中, ⊕ \oplus ⊕表示三个维度的旋转操作的组合,通常通过张量拼接或加权求和实现。

2. 时间维度的绝对编码

Qwen2.5-VL引入绝对时间编码,将实际时间间隔映射为位置ID:
时间ID ( i ) = t 0 + t 1 − t 0 N ⋅ i ⋅ s \text{时间ID}(i) = t_0 + \frac{t_1 - t_0}{N} \cdot i \cdot s 时间ID(i)=t0+Nt1−t0⋅i⋅s

其中:

  • t 0 t_0 t0和 t 1 t_1 t1为视频的起始和结束时间,
  • N N N为总帧数,
  • i i i为当前帧索引,
  • s s s为可学习的缩放因子(代码中对应second_per_grid_t * 2)。

三维位置ID的生成公式

1. 时间网格的位置ID计算

代码中的时间位置ID生成对应公式:
时间ID ( t ) = t ⋅ Δ t ⋅ s \text{时间ID}(t) = t \cdot \Delta t \cdot s 时间ID(t)=t⋅Δt⋅s

其中:

  • t t t为时间块索引,
  • Δ t \Delta t Δt为每个时间块的秒数(second_per_grid_t),
  • s s s为缩放因子(代码中为2)。
2. 空间网格的位置ID计算

高度和宽度的位置ID通过网格索引生成:
高度ID ( h ) = h ( h = 0 , 1 , ... , H − 1 ) \text{高度ID}(h) = h \quad (h = 0, 1, \ldots, H-1) 高度ID(h)=h(h=0,1,...,H−1)
宽度ID ( w ) = w ( w = 0 , 1 , ... , W − 1 ) \text{宽度ID}(w) = w \quad (w = 0, 1, \ldots, W-1) 宽度ID(w)=w(w=0,1,...,W−1)

其中, H H H和 W W W分别为高度和宽度方向的网格数。

多模态融合的位置连续性公式

1. 视觉与文本的位置衔接

文本部分的起始位置ID为视觉部分的最大位置ID加1:
文本起始ID = max ⁡ ( 视觉时间ID , 视觉高度ID , 视觉宽度ID ) + 1 \text{文本起始ID} = \max(\text{视觉时间ID}, \text{视觉高度ID}, \text{视觉宽度ID}) + 1 文本起始ID=max(视觉时间ID,视觉高度ID,视觉宽度ID)+1

2. 整体位置ID序列

对于包含视觉和文本的混合序列,位置ID序列可表示为:
位置ID = [ 视觉ID 1 , 视觉ID 2 , ... , 视觉ID M , 文本起始ID , 文本起始ID + 1 , ... ] \text{位置ID} = [\text{视觉ID}_1, \text{视觉ID}_2, \ldots, \text{视觉ID}_M, \text{文本起始ID}, \text{文本起始ID}+1, \ldots] 位置ID=[视觉ID1,视觉ID2,...,视觉IDM,文本起始ID,文本起始ID+1,...]

其中, M M M为视觉token的数量。

基于大语言模型中的图像和视频的时间、高度和宽度维度,计算三维旋转位置编码索引。

复制代码
原理说明:
    每个嵌入序列包含视觉嵌入和文本嵌入,或仅包含文本嵌入。
    
    对于纯文本嵌入序列,旋转位置嵌入与现代大语言模型相同。
    示例:
        input_ids: [T T T T T],这里T代表文本。
        时间位置ID: [0, 1, 2, 3, 4]
        高度位置ID: [0, 1, 2, 3, 4]
        宽度位置ID: [0, 1, 2, 3, 4]
        
    对于视觉和文本混合嵌入序列,我们为视觉部分计算三维旋转位置嵌入,
    为文本部分计算一维旋转位置嵌入。
    示例:
        时间维度(Temporal):3个时间块,表示视频在时间上的不同片段。
        高度维度(Height):2个高度块,垂直划分每一帧。
        宽度维度(Width):2个宽度块,水平划分每一帧。
        我们还有一些重要参数:
        fps(每秒帧数):视频的帧率,设为1。这意味着每秒处理一帧。
        tokens_per_second:这是一个关键参数。它决定了概念上一秒视频间隔内包含多少"时间步"或"时间token"。
                            在这种情况下,我们每秒有25个token。因此,视频的每一秒将由25个不同的时间点表示。
                            它本质上定义了时间粒度。
        temporal_patch_size:构成一个时间块的帧数。这里是2帧。
        interval:时间位置ID的步长,计算为tokens_per_second * temporal_patch_size / fps。
                  在这种情况下,25 * 2 / 1 = 50。这意味着每个时间块的时间位置ID相差50。
        input_ids: [V V V V V V V V V V V V T T T T T],这里V代表视觉。
        视觉时间位置ID: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
        视觉高度位置ID: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
        视觉宽度位置ID: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
        文本时间位置ID: [101, 102, 103, 104, 105]
        文本高度位置ID: [101, 102, 103, 104, 105]
        文本宽度位置ID: [101, 102, 103, 104, 105]
        这里我们将文本的起始位置ID计算为视觉最大位置ID加1。

参数:
    input_ids (`torch.LongTensor`,形状为`(batch_size, sequence_length)`):
        输入序列在词汇表中的token索引。如果提供了注意力掩码,填充token将被忽略。
    image_grid_thw (`torch.LongTensor`,形状为`(num_images, 3)`,可选):
        大语言模型中每个图像特征的时间、高度和宽度维度。
    video_grid_thw (`torch.LongTensor`,形状为`(num_videos, 3)`,可选):
        大语言模型中每个视频特征的时间、高度和宽度维度。
    second_per_grid_ts (`torch.Tensor`,形状为`(num_videos)`,可选):
        3D位置ID中每个时间网格的时间间隔(以秒为单位)。
    attention_mask (`torch.Tensor`,形状为`(batch_size, sequence_length)`,可选):
        用于避免在填充token索引上执行注意力计算的掩码。掩码值选择为`[0, 1]`:

        - 1表示token**未被掩码**,
        - 0表示token**被掩码**。

返回:
    position_ids (`torch.LongTensor`,形状为`(3, batch_size, sequence_length)`)
    mrope_position_deltas (`torch.Tensor`,形状为`(batch_size)`)
py 复制代码
def get_rope_index_25(
    spatial_merge_size: Optional[int] = 2,
    input_ids: Optional[torch.LongTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:

    image_token_id = 151655
    video_token_id = 151656
    vision_start_token_id = 151652
    mrope_position_deltas = []
    if input_ids is not None and (
        image_grid_thw is not None or video_grid_thw is not None
    ):
        total_input_ids = input_ids
        if attention_mask is None:
            attention_mask = torch.ones_like(total_input_ids)
        position_ids = torch.ones(
            3,
            input_ids.shape[0],
            input_ids.shape[1],
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        image_index, video_index = 0, 0
        attention_mask = attention_mask.to(total_input_ids.device)
        for i, input_ids in enumerate(total_input_ids):
            input_ids = input_ids[attention_mask[i] == 1]
            image_nums, video_nums = 0, 0
            vision_start_indices = torch.argwhere(
                input_ids == vision_start_token_id
            ).squeeze(1)
            vision_tokens = input_ids[vision_start_indices + 1]
            image_nums = (vision_tokens == image_token_id).sum()
            video_nums = (vision_tokens == video_token_id).sum()
            input_tokens = input_ids.tolist()
            llm_pos_ids_list: list = []
            st = 0
            remain_images, remain_videos = image_nums, video_nums
            for _ in range(image_nums + video_nums):
                if image_token_id in input_tokens and remain_images > 0:
                    ed_image = input_tokens.index(image_token_id, st)
                else:
                    ed_image = len(input_tokens) + 1
                if video_token_id in input_tokens and remain_videos > 0:
                    ed_video = input_tokens.index(video_token_id, st)
                else:
                    ed_video = len(input_tokens) + 1
                if ed_image < ed_video:
                    t, h, w = (
                        image_grid_thw[image_index][0],
                        image_grid_thw[image_index][1],
                        image_grid_thw[image_index][2],
                    )
                    second_per_grid_t = 0
                    image_index += 1
                    remain_images -= 1
                    ed = ed_image

                else:
                    t, h, w = (
                        video_grid_thw[video_index][0],
                        video_grid_thw[video_index][1],
                        video_grid_thw[video_index][2],
                    )
                    if second_per_grid_ts is not None:
                        second_per_grid_t = second_per_grid_ts[video_index]
                    else:
                        second_per_grid_t = 1.0
                    video_index += 1
                    remain_videos -= 1
                    ed = ed_video
                llm_grid_t, llm_grid_h, llm_grid_w = (
                    t.item(),
                    h.item() // spatial_merge_size,
                    w.item() // spatial_merge_size,
                )
                text_len = ed - st

                st_idx = (
                    llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                )
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
                )

                range_tensor = torch.arange(llm_grid_t).view(-1, 1)
                expanded_range = range_tensor.expand(-1, llm_grid_h * llm_grid_w)

                time_tensor = expanded_range * second_per_grid_t * 2

                time_tensor_long = time_tensor.long()
                t_index = time_tensor_long.flatten()

                h_index = (
                    torch.arange(llm_grid_h)
                    .view(1, -1, 1)
                    .expand(llm_grid_t, -1, llm_grid_w)
                    .flatten()
                )
                w_index = (
                    torch.arange(llm_grid_w)
                    .view(1, 1, -1)
                    .expand(llm_grid_t, llm_grid_h, -1)
                    .flatten()
                )
                llm_pos_ids_list.append(
                    torch.stack([t_index, h_index, w_index]) + text_len + st_idx
                )
                st = ed + llm_grid_t * llm_grid_h * llm_grid_w

            if st < len(input_tokens):
                st_idx = (
                    llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                )
                text_len = len(input_tokens) - st
                llm_pos_ids_list.append(
                    torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx
                )

            llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
            position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(
                position_ids.device
            )
            mrope_position_deltas.append(
                llm_positions.max() + 1 - len(total_input_ids[i])
            )
        mrope_position_deltas = torch.tensor(
            mrope_position_deltas, device=input_ids.device
        ).unsqueeze(1)
        return position_ids, mrope_position_deltas
    else:
        if attention_mask is not None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            position_ids = (
                position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
            )
            max_position_ids = position_ids.max(0, keepdim=False)[0].max(
                -1, keepdim=True
            )[0]
            mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
        else:
            position_ids = (
                torch.arange(input_ids.shape[1], device=input_ids.device)
                .view(1, 1, -1)
                .expand(3, input_ids.shape[0], -1)
            )
            mrope_position_deltas = torch.zeros(
                [input_ids.shape[0], 1],
                device=input_ids.device,
                dtype=input_ids.dtype,
            )

        return position_ids, mrope_position_deltas
相关推荐
X.Cristiano1 个月前
多模态大模型 Qwen2.5-VL 的学习之旅
多模态·qwen2.5-vl