Qwen2.5-VL Model Architecture Explained: The Model

Qwen2.5-VL Model Architecture Explained: The Data Processing Part (companion post)

Model architecture diagram:

Visual Encoder

python
class Qwen2_5_VisionTransformerPretrainedModel(Qwen2_5_VLPreTrainedModel):
    config: Qwen2_5_VLVisionConfig
    _no_split_modules = ["Qwen2_5_VLVisionBlock"]

    def __init__(self, config, *inputs, **kwargs) -> None:
        super().__init__(config, *inputs, **kwargs)
        self.spatial_merge_size = config.spatial_merge_size
        self.patch_size = config.patch_size
        self.fullatt_block_indexes = config.fullatt_block_indexes
        self.window_size = config.window_size
        self.spatial_merge_unit = self.spatial_merge_size * self.spatial_merge_size

        self.patch_embed = Qwen2_5_VisionPatchEmbed(
            patch_size=config.patch_size,
            temporal_patch_size=config.temporal_patch_size,
            in_channels=config.in_channels,
            embed_dim=config.hidden_size,
        )

        head_dim = config.hidden_size // config.num_heads
        self.rotary_pos_emb = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)

        self.blocks = nn.ModuleList([Qwen2_5_VLVisionBlock(config) for _ in range(config.depth)])
        self.merger = Qwen2_5_VLPatchMerger(
            dim=config.out_hidden_size,
            context_dim=config.hidden_size,
            spatial_merge_size=config.spatial_merge_size,
        )
        self.gradient_checkpointing = False

    def rot_pos_emb(self, grid_thw):
        pos_ids = []
        for t, h, w in grid_thw:
            hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
            hpos_ids = hpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            hpos_ids = hpos_ids.permute(0, 2, 1, 3)
            hpos_ids = hpos_ids.flatten()

            wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
            wpos_ids = wpos_ids.reshape(
                h // self.spatial_merge_size,
                self.spatial_merge_size,
                w // self.spatial_merge_size,
                self.spatial_merge_size,
            )
            wpos_ids = wpos_ids.permute(0, 2, 1, 3)
            wpos_ids = wpos_ids.flatten()
            pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
        pos_ids = torch.cat(pos_ids, dim=0)
        max_grid_size = grid_thw[:, 1:].max()
        rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
        rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
        return rotary_pos_emb

    def get_window_index(self, grid_thw):
        window_index: list = []
        cu_window_seqlens: list = [0]
        window_index_id = 0
        vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size

        for grid_t, grid_h, grid_w in grid_thw:
            llm_grid_h, llm_grid_w = (
                grid_h // self.spatial_merge_size,
                grid_w // self.spatial_merge_size,
            )
            index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
            pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
            pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
            num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
            num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
            index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
            index_padded = index_padded.reshape(
                grid_t,
                num_windows_h,
                vit_merger_window_size,
                num_windows_w,
                vit_merger_window_size,
            )
            index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
                grid_t,
                num_windows_h * num_windows_w,
                vit_merger_window_size,
                vit_merger_window_size,
            )
            seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
            index_padded = index_padded.reshape(-1)
            index_new = index_padded[index_padded != -100]
            window_index.append(index_new + window_index_id)
            cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
            cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
            window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
        window_index = torch.cat(window_index, dim=0)

        return window_index, cu_window_seqlens

    def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
        """
        Args:
            hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
                The final hidden states of the model.
            grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
                The temporal, height and width of feature shape of each image in LLM.

        Returns:
            `torch.Tensor`: hidden_states.
        """
        hidden_states = self.patch_embed(hidden_states)
        rotary_pos_emb = self.rot_pos_emb(grid_thw)
        window_index, cu_window_seqlens = self.get_window_index(grid_thw)
        cu_window_seqlens = torch.tensor(
            cu_window_seqlens,
            device=hidden_states.device,
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_window_seqlens = torch.unique_consecutive(cu_window_seqlens)

        seq_len, _ = hidden_states.size()
        hidden_states = hidden_states.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        hidden_states = hidden_states[window_index, :, :]
        hidden_states = hidden_states.reshape(seq_len, -1)
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len // self.spatial_merge_unit, self.spatial_merge_unit, -1)
        rotary_pos_emb = rotary_pos_emb[window_index, :, :]
        rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1)
        emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
        position_embeddings = (emb.cos(), emb.sin())

        cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
            dim=0,
            # Select dtype based on the following factors:
            #  - FA2 requires that cu_seqlens_q must have dtype int32
            #  - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
            # See https://github.com/huggingface/transformers/pull/34852 for more information
            dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
        )
        cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)

        for layer_num, blk in enumerate(self.blocks):
            if layer_num in self.fullatt_block_indexes:
                cu_seqlens_now = cu_seqlens
            else:
                cu_seqlens_now = cu_window_seqlens

            hidden_states = blk(
                hidden_states,
                cu_seqlens=cu_seqlens_now,
                position_embeddings=position_embeddings,
                **kwargs,
            )

        hidden_states = self.merger(hidden_states)
        reverse_indices = torch.argsort(window_index)
        hidden_states = hidden_states[reverse_indices, :]

        return hidden_states

Step 1: Patch Embedding

Input images or video frames (already pre-split into fixed-size spatio-temporal patches) are linearly projected into tokens of a common dimension.

python
class Qwen2_5_VisionPatchEmbed(nn.Module):
    def __init__(
        self,
        patch_size: int = 14,
        temporal_patch_size: int = 2,
        in_channels: int = 3,
        embed_dim: int = 1280,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size
        self.temporal_patch_size = temporal_patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim

        kernel_size = [temporal_patch_size, patch_size, patch_size]
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        target_dtype = self.proj.weight.dtype
        hidden_states = hidden_states.view(
            -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
        )
        hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
        return hidden_states
  • input shape: [num_images*grid_w*grid_h, channel * temporal_patch_size * patch_size * patch_size]
  • output shape: [num_images*grid_h*grid_w, embed_dim]
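
As a quick shape check, here is a minimal sketch with made-up sizes (assuming the default patch_size=14, temporal_patch_size=2, in_channels=3, embed_dim=1280; a single image is typically repeated along the temporal axis to fill temporal_patch_size):

python
import torch

# Hypothetical example: one 224x224 RGB image -> grid_h = grid_w = 224 // 14 = 16,
# and each flattened patch carries 3 * 2 * 14 * 14 = 1176 values.
patch_embed = Qwen2_5_VisionPatchEmbed(patch_size=14, temporal_patch_size=2,
                                       in_channels=3, embed_dim=1280)
pixel_values = torch.randn(16 * 16, 3 * 2 * 14 * 14)  # [grid_h * grid_w, 1176]
tokens = patch_embed(pixel_values)
print(tokens.shape)                                    # torch.Size([256, 1280])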

Step 2: Rotary Position Embedding (RoPE)

Rotary position embedding (RoPE) in Qwen2.5-VL provides spatial position information for the visual tokens. Two pieces are involved:

  1. Qwen2_5_VisionRotaryEmbedding: generates the base RoPE frequency table (freqs).
  2. The rot_pos_emb method (part of the vision encoder class): builds per-patch position IDs from the image/video grid layout (grid_thw) and the spatial merge strategy (spatial_merge_size), then indexes the precomputed RoPE table to obtain the actual position encodings.

The Qwen2_5_VisionRotaryEmbedding class:

The math behind RoPE (key formulas)

For a vector of dimension $d$ (with $d$ even), split it into $d/2$ two-dimensional subspaces. For position $m$, define the rotation angle on the $i$-th subspace (dimensions $2i$ and $2i+1$) as

$\theta_i = m \cdot \frac{1}{\theta^{2i/d}}, \quad \text{where } \theta = 10000.$

RoPE then transforms the pair $(x_{2i}, x_{2i+1})$ by

$\begin{pmatrix} \cos\theta_i & -\sin\theta_i \\ \sin\theta_i & \cos\theta_i \end{pmatrix} \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix}$

For efficient computation, the implementation precomputes:

  • freqs $= [\, m\,\omega_0,\ m\,\omega_1,\ \ldots,\ m\,\omega_{d/2-1} \,]$ with $\omega_i = \theta^{-2i/d}$
  • At runtime, torch.cos(freqs) and torch.sin(freqs) are used to build the rotation.
python
class Qwen2_5_VisionRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, dim: int, theta: float = 10000.0) -> None:
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

    def forward(self, seqlen: int) -> torch.Tensor:
        seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(seq, self.inv_freq)
        return freqs

This is a standard RoPE frequency generator, but it only produces the base frequencies for a 1D sequence (the 2D encoding is obtained later by combining (x, y) coordinates).

  • dim: the RoPE dimension (usually half of head_dim, since RoPE acts on complex pairs, i.e. every two dimensions form one rotation pair).
  • theta=10000.0: the base frequency scale of RoPE (same as in the original Transformer).
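
A minimal usage sketch (sizes are illustrative; with hidden_size=1280 and num_heads=16 the head dimension is 80, so the module is built with dim = 40):

python
import torch

head_dim = 80                                   # illustrative head dimension
rope = Qwen2_5_VisionRotaryEmbedding(head_dim // 2)
freqs = rope(32)                                # one frequency row per position 0..31
print(freqs.shape)                              # torch.Size([32, 20])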

The rot_pos_emb(self, grid_thw) method:

python
def rot_pos_emb(self, grid_thw):
    pos_ids = []
    for t, h, w in grid_thw:
        # 1. Build height-direction position IDs
        hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
        hpos_ids = hpos_ids.reshape(
            h // self.spatial_merge_size,
            self.spatial_merge_size,
            w // self.spatial_merge_size,
            self.spatial_merge_size,
        )
        hpos_ids = hpos_ids.permute(0, 2, 1, 3)
        hpos_ids = hpos_ids.flatten()

        # 2. Build width-direction position IDs
        wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
        wpos_ids = wpos_ids.reshape(
            h // self.spatial_merge_size,
            self.spatial_merge_size,
            w // self.spatial_merge_size,
            self.spatial_merge_size,
        )
        wpos_ids = wpos_ids.permute(0, 2, 1, 3)
        wpos_ids = wpos_ids.flatten()

        # 3. Stack h/w and repeat t times (handle the temporal dimension)
        pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
    
    pos_ids = torch.cat(pos_ids, dim=0)
    
    # 4. Precompute the RoPE table up to the largest grid size
    max_grid_size = grid_thw[:, 1:].max()
    rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
    
    # 5. Index the RoPE table with pos_ids
    rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
    return rotary_pos_emb

Input: grid_thw

  • Shape (B, 3) or (num_images, 3); each row [t, h, w] says that an image/video clip is split into h × w spatial patches across t frames (the temporal dimension).

Key parameter: self.spatial_merge_size

  • Used for spatial merging; e.g. spatial_merge_size=2 means every 2×2 block of original patches is merged into one token.
  • So the actual token count = (h // merge) × (w // merge) × t.

Step-by-step:

  • 1. Build the height position IDs (hpos_ids)
    • torch.arange(h).unsqueeze(1).expand(-1, w): builds an h × w grid whose rows are [0, 0, ..., 0], [1, 1, ..., 1], ...
    • reshape(h//M, M, w//M, M): partitions the grid into (h//M × w//M) blocks of size M×M (M = spatial_merge_size).
    • permute(0, 2, 1, 3): reorders the axes to [block_row, block_col, in_block_h, in_block_w].
    • flatten(): after flattening, the order is block by block, then row/column inside each block, giving each merged token its "representative height position".
  • 2. The width position IDs (wpos_ids) are built the same way
    • Build the h × w column-index grid.
    • The same reshape + permute + flatten gives each merged token its "representative width position".
  • 3. Stack h and w, then repeat t times
    • torch.stack([hpos_ids, wpos_ids], dim=-1): each token gets an (h_id, w_id) coordinate.
    • .repeat(t, 1): with t frames in the temporal dimension, every frame shares the same spatial positions.
  • 4. Precompute the RoPE table
    • max_grid_size = grid_thw[:, 1:].max(): the largest h or w across all images.
    • self.rotary_pos_emb(max_grid_size): calls the forward of Qwen2_5_VisionRotaryEmbedding and produces a (max_grid_size, dim//2) frequency table.
  • 5. Index and flatten
    • rotary_pos_emb_full[pos_ids]:
      • pos_ids has shape (N, 2), where N = total number of tokens.
      • Indexing yields (N, 2, dim//2): each token gets two RoPE vectors (one from h, one from w).
    • .flatten(1): flattened to (N, dim), concatenating the h and w RoPE into the full 2D position encoding.
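
To make the reordering concrete, here is a small hand-checkable sketch (toy values: h = w = 4, spatial_merge_size = 2):

python
import torch

M = 2                                   # spatial_merge_size (toy value)
h = w = 4
hpos = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos = torch.arange(w).unsqueeze(0).expand(h, -1)
hpos = hpos.reshape(h // M, M, w // M, M).permute(0, 2, 1, 3).flatten()
wpos = wpos.reshape(h // M, M, w // M, M).permute(0, 2, 1, 3).flatten()
print(torch.stack([hpos, wpos], dim=-1)[:4].tolist())
# [[0, 0], [0, 1], [1, 0], [1, 1]] -> the four patches of the top-left 2x2 block come first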

Step 3: Building the Window Attention Indices

This is the index-building machinery for windowed attention in the Qwen2.5-VL vision encoder, used to support variable-resolution images with sparse/blocked attention (e.g. for the ViT merger and FlashAttention-2 style variants). Its goal is to group the irregularly sized visual token grids (coming from different images) into fixed-size windows and produce:

  1. window_index: global token indices describing which tokens belong to each window;
  2. cu_window_seqlens: cumulative token lengths per window (used for efficient attention computation, e.g. FlashAttention's cu_seqlens interface).
python
def get_window_index(self, grid_thw):
      window_index: list = []
      cu_window_seqlens: list = [0]
      window_index_id = 0
      vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size

      for grid_t, grid_h, grid_w in grid_thw:
          llm_grid_h, llm_grid_w = (
              grid_h // self.spatial_merge_size,
              grid_w // self.spatial_merge_size,
          )
          index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
          pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
          pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
          num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
          num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
          index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
          index_padded = index_padded.reshape(
              grid_t,
              num_windows_h,
              vit_merger_window_size,
              num_windows_w,
              vit_merger_window_size,
          )
          index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
              grid_t,
              num_windows_h * num_windows_w,
              vit_merger_window_size,
              vit_merger_window_size,
          )
          seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
          index_padded = index_padded.reshape(-1)
          index_new = index_padded[index_padded != -100]
          window_index.append(index_new + window_index_id)
          cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
          cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
          window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
      window_index = torch.cat(window_index, dim=0)

      return window_index, cu_window_seqlens

Input: grid_thw

  • Shape (B, 3) or (num_images, 3); each row [t, h, w] says that an image/video clip is split into h × w spatial patches across t frames (the temporal dimension).

Key parameters:

python
vit_merger_window_size = self.window_size // self.spatial_merge_size // self.patch_size
  • self.patch_size: the patch size the original image is divided into (e.g. 14).
  • self.spatial_merge_size: the spatial merge factor (e.g. 2), meaning every M×M original patches are merged into 1 token.
  • self.window_size: the window size in original pixel units (e.g. 112 pixels).

Therefore:

vit_merger_window_size = the side length, in merged tokens, of each attention window on the merged token grid.
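
For example, with the values above (window_size = 112, spatial_merge_size = 2, patch_size = 14), each window covers a 4 × 4 block of merged tokens:

python
window_size, spatial_merge_size, patch_size = 112, 2, 14
vit_merger_window_size = window_size // spatial_merge_size // patch_size
print(vit_merger_window_size)  # 4 -> each window is 4 x 4 merged tokens = 8 x 8 patches = 112 x 112 pixels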

Step-by-step:

  • 1. Compute the merged token grid size

    python
    llm_grid_h, llm_grid_w = grid_h // self.spatial_merge_size, grid_w // self.spatial_merge_size
    • After spatial merging, each image becomes t × llm_grid_h × llm_grid_w visual tokens.
  • 2. Generate token indices for the current image (local for now)

    python
    index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
    • index is a (t, H, W) tensor; the value at each position is that token's local index within the current image (starting from 0).
  • 3. Compute padding to align with window boundaries

    python
    pad_h = vit_merger_window_size - llm_grid_h % vit_merger_window_size
    pad_w = vit_merger_window_size - llm_grid_w % vit_merger_window_size
    • Purpose: pad the H × W token grid so it is divisible by vit_merger_window_size and can be split into complete windows.
    • -100 is used as the padding value (later used to mask out the invalid tokens).
  • 4. Compute the number of windows

    python
    num_windows_h = (llm_grid_h + pad_h) // vit_merger_window_size
    num_windows_w = (llm_grid_w + pad_w) // vit_merger_window_size
    • Total number of windows = t × num_windows_h × num_windows_w.
  • 5. Pad and rearrange into window format

    python
    index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)
    index_padded = index_padded.reshape(
        grid_t,
        num_windows_h,
        vit_merger_window_size,
        num_windows_w,
        vit_merger_window_size,
    )
    index_padded = index_padded.permute(0, 1, 3, 2, 4).reshape(
        grid_t,
        num_windows_h * num_windows_w,
        vit_merger_window_size,
        vit_merger_window_size,
    )
    • reshape + permute: rearranges (t, H, W) into (t, num_win_h, win_h, num_win_w, win_w), then merges the two window dimensions.
    • Final shape: (t, num_windows, win_h, win_w).
  • 6. Count the actual (non-padding) tokens in each window

    python
    seqlens = (index_padded != -100).sum([2, 3]).reshape(-1)
    • seqlens: length t × num_windows; each element is the number of valid tokens in that window.
  • 7. Extract the global indices of the valid tokens

    python
    index_padded = index_padded.reshape(-1)
    index_new = index_padded[index_padded != -100]
    window_index.append(index_new + window_index_id)
    • index_new: the list of local indices of all valid tokens in the current image, in window order.
    • + window_index_id: converts them to global token indices (across all images).
    • window_index_id is incremented after each image so indices never collide.
  • 8. Build the cumulative lengths (cu_seqlens)

    python
    cu_seqlens_tmp = seqlens.cumsum(0) * self.spatial_merge_unit + cu_window_seqlens[-1]
    cu_window_seqlens.extend(cu_seqlens_tmp.tolist())
    • seqlens counts merged tokens per window; multiplying by self.spatial_merge_unit (= spatial_merge_size²) converts this into counts of pre-merge patch tokens, which is the sequence the attention blocks actually process.
  • 9. Update the global index offset

    python
    window_index_id += (grid_t * llm_grid_h * llm_grid_w).item()
    • Prepares the token index offset for the next image.
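
A toy run-through (a hypothetical single image whose merged grid is 6 × 6, window side 4, spatial_merge_unit 4) shows what the outputs look like when the grid does not divide evenly into windows:

python
import torch
import torch.nn.functional as F

grid_t, llm_grid_h, llm_grid_w, win = 1, 6, 6, 4
index = torch.arange(grid_t * llm_grid_h * llm_grid_w).reshape(grid_t, llm_grid_h, llm_grid_w)
pad_h, pad_w = win - llm_grid_h % win, win - llm_grid_w % win            # 2, 2
index_padded = F.pad(index, (0, pad_w, 0, pad_h), "constant", -100)      # (1, 8, 8)
index_padded = index_padded.reshape(grid_t, 2, win, 2, win).permute(0, 1, 3, 2, 4)
seqlens = (index_padded != -100).sum([3, 4]).reshape(-1)
print(seqlens.tolist())                    # [16, 8, 8, 4] valid merged tokens per window
print((seqlens.cumsum(0) * 4).tolist())    # [64, 96, 128, 144] -> cu_window_seqlens[1:]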

Step 4: VisionBlock

python
class Qwen2_5_VLVisionBlock(GradientCheckpointingLayer):
    def __init__(self, config, attn_implementation: str = "sdpa") -> None:
        super().__init__()
        self.norm1 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
        self.norm2 = Qwen2RMSNorm(config.hidden_size, eps=1e-6)
        self.attn = Qwen2_5_VLVisionAttention(config=config)
        self.mlp = Qwen2_5_VLMLP(config, bias=True)

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        hidden_states = hidden_states + self.attn(
            self.norm1(hidden_states),
            cu_seqlens=cu_seqlens,
            rotary_pos_emb=rotary_pos_emb,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
        return hidden_states

RMSNorm (Root Mean Square Layer Normalization)

RMSNorm (Root Mean Square Layer Normalization) is a normalization method that replaces standard LayerNorm. It is structurally simpler and computationally cheaper than LayerNorm and works well in large language models.

python
class Qwen2RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps: float = 1e-6) -> None:
        """
        Qwen2RMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

The math behind RMSNorm

The core idea of RMSNorm: only rescale the activations (scale), with no shift and no mean-centering.

For a vector $x \in \mathbb{R}^d$, RMSNorm is defined as

$\mathrm{RMSNorm}(x) = \frac{x}{\mathrm{RMS}(x)} \odot \gamma$

where:

  • $\mathrm{RMS}(x) = \sqrt{\frac{1}{d} \sum_{i=1}^{d} x_i^2}$ is the root mean square
  • $\gamma \in \mathbb{R}^d$ is the learnable scale parameter (self.weight)
  • $\odot$ denotes element-wise multiplication (Hadamard product)

Note: there is no mean subtraction (no centering), which is the main difference from standard LayerNorm.
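
A quick numerical check (a minimal sketch; the weight is initialized to ones, so the output equals x / RMS(x)):

python
import torch

norm = Qwen2RMSNorm(hidden_size=4)
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
rms = x.pow(2).mean(-1, keepdim=True).add(1e-6).sqrt()   # sqrt(7.5) ≈ 2.7386
print(norm(x))    # ≈ [[0.3651, 0.7303, 1.0954, 1.4606]]
print(x / rms)    # same values: RMSNorm only rescales by 1 / RMS(x)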

Self-Attention (Vision Attention)

This is the self-attention mechanism of the vision encoder in the Qwen2.5-VL multimodal model, designed to efficiently process the variable-length visual token sequences produced by high-resolution images (or video frames). It supports multiple attention backends (e.g. eager, FlashAttention-2) and integrates rotary position embedding (RoPE).

python
class Qwen2_5_VLVisionAttention(nn.Module):
    def __init__(self, config: Qwen2_5_VLVisionConfig) -> None:
        super().__init__()
        self.dim = config.hidden_size
        self.num_heads = config.num_heads
        self.head_dim = self.dim // self.num_heads
        self.num_key_value_groups = 1  # needed for eager attention
        self.qkv = nn.Linear(self.dim, self.dim * 3, bias=True)
        self.proj = nn.Linear(self.dim, self.dim)
        self.scaling = self.head_dim**-0.5
        self.config = config
        self.attention_dropout = 0.0
        self.is_causal = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        cu_seqlens: torch.Tensor,
        rotary_pos_emb: Optional[torch.Tensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
        **kwargs,
    ) -> torch.Tensor:
        seq_length = hidden_states.shape[0]
        query_states, key_states, value_states = (
            self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
        )
        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

        query_states = query_states.transpose(0, 1).unsqueeze(0)
        key_states = key_states.transpose(0, 1).unsqueeze(0)
        value_states = value_states.transpose(0, 1).unsqueeze(0)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        if self.config._attn_implementation == "flash_attention_2":
            # Flash Attention 2: Use cu_seqlens for variable length attention
            max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
            attn_output, _ = attention_interface(
                self,
                query_states,
                key_states,
                value_states,
                attention_mask=None,
                scaling=self.scaling,
                dropout=0.0 if not self.training else self.attention_dropout,
                cu_seq_lens_q=cu_seqlens,
                cu_seq_lens_k=cu_seqlens,
                max_length_q=max_seqlen,
                max_length_k=max_seqlen,
                is_causal=False,
                **kwargs,
            )
        else:
            # Other implementations: Process each chunk separately
            lengths = cu_seqlens[1:] - cu_seqlens[:-1]
            splits = [
                torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
            ]

            attn_outputs = [
                attention_interface(
                    self,
                    q,
                    k,
                    v,
                    attention_mask=None,
                    scaling=self.scaling,
                    dropout=0.0 if not self.training else self.attention_dropout,
                    is_causal=False,
                    **kwargs,
                )[0]
                for q, k, v in zip(*splits)
            ]
            attn_output = torch.cat(attn_outputs, dim=1)

        attn_output = attn_output.reshape(seq_length, -1).contiguous()
        attn_output = self.proj(attn_output)
        return attn_output
  • Step 1: QKV projection + head split

    python
    query_states, key_states, value_states = (
        self.qkv(hidden_states)
        .reshape(seq_length, 3, self.num_heads, -1)
        .permute(1, 0, 2, 3)
        .unbind(0)
    )
    • self.qkv(hidden_states): (N, 3*D)
    • reshape(N, 3, H, D/H): (N, 3, num_heads, head_dim)
    • permute(1, 0, 2, 3): (3, N, H, head_dim)
    • unbind(0): splits into three (N, H, head_dim) tensors
  • Step 2: Apply rotary position embedding (RoPE)

    python
    cos, sin = position_embeddings
    query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)

    Rotary position embedding (RoPE) as applied in the vision model: the core idea is to inject position information into the attention mechanism's query and key vectors through a vector rotation.

    python
    def apply_rotary_pos_emb_vision(
        q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor]:
        orig_q_dtype = q.dtype
        orig_k_dtype = k.dtype
        q, k = q.float(), k.float()
        cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
        q_embed = (q * cos) + (rotate_half(q) * sin)
        k_embed = (k * cos) + (rotate_half(k) * sin)
        q_embed = q_embed.to(orig_q_dtype)
        k_embed = k_embed.to(orig_k_dtype)
        return q_embed, k_embed
    python
    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    The goal of RoPE: make attention aware of the relative positions of tokens.

    For two positions $m$ and $n$, RoPE guarantees

    $\mathrm{Attention}(q_m, k_n) = f(q, k, m - n)$

    i.e. the score depends only on the relative position $m - n$, which matters in both vision and language tasks. (A small numerical check of this property appears right after this step list.)

    Implementation: rotation in the complex plane

    • View the vector $x \in \mathbb{R}^d$ ($d$ even) as $d/2$ complex numbers:

    $x_{\mathrm{complex}} = (x_0 + i x_{d/2},\ x_1 + i x_{d/2+1},\ \ldots,\ x_{d/2-1} + i x_{d-1})$

    • For position $m$, the rotation angle is $\theta_i = m \cdot \theta^{-2i/d}$.
    • The rotation is $x_{\mathrm{rot}} = x \cdot e^{i\theta} = x \cdot (\cos\theta + i\sin\theta)$.
    • Expanding into real and imaginary parts:

    $\begin{pmatrix} x_i^{\mathrm{new}} \\ x_{i+d/2}^{\mathrm{new}} \end{pmatrix} = \begin{pmatrix} \cos\theta_i & -\sin\theta_i \\ \sin\theta_i & \cos\theta_i \end{pmatrix} \begin{pmatrix} x_i \\ x_{i+d/2} \end{pmatrix}$

    rotate_half: the vector operation implementing the 90° rotation

    python
    def rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]      # first half: x0, x1, ..., x_{d/2-1}
        x2 = x[..., x.shape[-1] // 2 :]      # second half: x_{d/2}, ..., x_{d-1}
        return torch.cat((-x2, x1), dim=-1)  # concatenate as [-x_{d/2}, ..., -x_{d-1}, x0, ..., x_{d/2-1}]

    Applying the RoPE formula

    python
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    • First half (index $i$):

    $q_i^{\mathrm{new}} = q_i \cos\theta_i - q_{i+d/2} \sin\theta_i$

    • Second half (index $i + d/2$):

    $q_{i+d/2}^{\mathrm{new}} = q_i \sin\theta_i + q_{i+d/2} \cos\theta_i$

  • Step 3: Reshape to fit the attention interface

    python
    query_states = query_states.transpose(0, 1).unsqueeze(0)  # (1, H, N, head_dim)
    key_states   = key_states.transpose(0, 1).unsqueeze(0)    # (1, H, N, head_dim)
    value_states = value_states.transpose(0, 1).unsqueeze(0)  # (1, H, N, head_dim)
  • Step 4: Two code paths for attention (the key part)

    • Path A: FlashAttention-2 (recommended)

      python
      if self.config._attn_implementation == "flash_attention_2":
          max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
          attn_output, _ = attention_interface(
              self,
              query_states,
              key_states,
              value_states,
              cu_seq_lens_q=cu_seqlens,
              cu_seq_lens_k=cu_seqlens,
              max_length_q=max_seqlen,
              max_length_k=max_seqlen,
              is_causal=False,
              ...
          )
      • cu_seqlens: tells FlashAttention the token boundaries of each image.
      • max_seqlen: the longest sub-sequence length (used by the kernel for memory allocation).
      • Advantages:
        • the attention matrix is never materialized (memory scales with sequence length, not O(N²))
        • supports variable-length sequences
        • very fast
    • Path B: Eager / other implementations (fallback)

      python
      else:
          lengths = cu_seqlens[1:] - cu_seqlens[:-1]  
          splits = [torch.split(tensor, lengths.tolist(), dim=2) for tensor in (q, k, v)]
          attn_outputs = [
              attention_interface(self, q_i, k_i, v_i, is_causal=False, ...)[0]
              for q_i, k_i, v_i in zip(*splits)
          ]
          attn_output = torch.cat(attn_outputs, dim=1)
      • Manual splitting: splits the (1, H, N, D) tensors into multiple (1, H, L_i, D) chunks according to lengths.
      • Attention is computed for each image separately.
      • The results are concatenated back to (1, H, N, D).
      • Drawback: higher memory use (the full attention matrix per image is materialized) and slower.
  • Step 5: Output projection

    python
    attn_output = attn_output.reshape(seq_length, -1).contiguous()  # (N, D)
    attn_output = self.proj(attn_output)  # (N, D)
    • reshape(seq_length, -1): (N, H * head_dim) = (N, D)
    • contiguous(): ensures the memory layout is contiguous for the following operations
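
As promised above, here is a small standalone numerical check (a sketch using the same rotate_half convention; dim and positions are arbitrary) that the post-RoPE dot product depends only on the relative position:

python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)

dim = 8
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))

def rope(x, pos):
    freqs = pos * inv_freq
    emb = torch.cat((freqs, freqs), dim=-1)
    return x * emb.cos() + rotate_half(x) * emb.sin()

q, k = torch.randn(dim), torch.randn(dim)
# Shifting both positions by the same offset leaves <q_m, k_n> unchanged.
print(torch.dot(rope(q, 3.0), rope(k, 1.0)))
print(torch.dot(rope(q, 13.0), rope(k, 11.0)))   # same value up to float error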

Multi-Layer Perceptron (MLP) Module

The MLP module in Qwen2.5-VL uses the gated MLP (a.k.a. SwiGLU-style) structure that is widely adopted in modern large language models such as LLaMA, Qwen, and Mistral.

python
class Qwen2_5_VLMLP(nn.Module):
    def __init__(self, config, bias: bool = False):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
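
A minimal usage sketch (the sizes below are illustrative assumptions, not necessarily the real config values):

python
import torch
from types import SimpleNamespace

cfg = SimpleNamespace(hidden_size=1280, intermediate_size=3420, hidden_act="silu")  # assumed values
mlp = Qwen2_5_VLMLP(cfg, bias=True)
x = torch.randn(4, 1280)
print(mlp(x).shape)   # torch.Size([4, 1280]): down_proj(act(gate_proj(x)) * up_proj(x))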

Qwen2.5 LM Decoder

python
class Qwen2_5_VLTextModel(Qwen2_5_VLPreTrainedModel):
    config: Qwen2_5_VLTextConfig

    def __init__(self, config: Qwen2_5_VLTextConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [Qwen2_5_VLDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)
        self.has_sliding_layers = "sliding_attention" in self.config.layer_types

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Union[tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # torch.jit.trace() doesn't support cache objects in the output
        if use_cache and past_key_values is None and not torch.jit.is_tracing():
            past_key_values = DynamicCache(config=self.config)

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        # the hard coded `3` is for temporal, height and width.
        if position_ids is None:
            position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
        elif position_ids.ndim == 2:
            position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)

        # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
        # where each dim indicates visual spatial positions for temporal/height/width grids.
        # There are two scenarios when FA2-like packed masking might be activated.
        # 1. User specifically passed packed `position_ids` and no attention mask.
        #    In this case we expect the user to create correct position ids for all 3 grids
        #    and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
        # 2. User runs forward with no attention mask and no position ids. In this case, position ids
        #    are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
        #    prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
        #    text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
        if position_ids.ndim == 3 and position_ids.shape[0] == 4:
            text_position_ids = position_ids[0]
            position_ids = position_ids[1:]
        else:
            # If inputs are not packed (usual 3D positions), do not prepare mask from position_ids
            text_position_ids = None

        # It may already have been prepared by e.g. `generate`
        if not isinstance(causal_mask_mapping := attention_mask, dict):
            # Prepare mask arguments
            mask_kwargs = {
                "config": self.config,
                "input_embeds": inputs_embeds,
                "attention_mask": attention_mask,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "position_ids": text_position_ids,
            }
            # Create the masks
            causal_mask_mapping = {
                "full_attention": create_causal_mask(**mask_kwargs),
            }
            # The sliding window alternating layers are not always activated depending on the config
            if self.has_sliding_layers:
                causal_mask_mapping["sliding_attention"] = create_sliding_window_causal_mask(**mask_kwargs)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=causal_mask_mapping[decoder_layer.attention_type],
                position_ids=text_position_ids,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                cache_position=cache_position,
                position_embeddings=position_embeddings,
                **kwargs,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(
                v for v in [hidden_states, past_key_values, all_hidden_states, all_self_attns] if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

Step 1: Embedding

Maps discrete integer indices (token IDs) to dense continuous vectors (embeddings).

python
nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
  • config.vocab_size
    • Role: the vocabulary size, i.e. the total number of distinct tokens the model can represent.
    • Example: if the vocabulary has 151,936 tokens (as in the Qwen series), then vocab_size = 151936.
    • Embedding matrix shape: the first dimension is vocab_size; each token owns one row vector.
  • config.hidden_size
    • Role: the dimension of each token embedding (also called the embedding dimension or hidden size).
    • Example: if hidden_size = 4096, each token is mapped to a 4096-dimensional vector.
    • Consistent with the model: this usually equals the Transformer hidden size (d_model).
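
A tiny illustration (toy sizes) of what nn.Embedding does, including the padding row:

python
import torch
import torch.nn as nn

embed_tokens = nn.Embedding(num_embeddings=10, embedding_dim=8, padding_idx=0)  # toy vocab of 10
input_ids = torch.tensor([[3, 5, 0, 7]])          # (batch=1, seq_len=4); 0 is the pad token
print(embed_tokens(input_ids).shape)              # torch.Size([1, 4, 8])
print(embed_tokens(input_ids)[0, 2].abs().sum())  # 0.0 -> the padding row is all zeros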

Step 2: Rotary Position Embedding (RoPE) Module

The rotary position embedding (RoPE) module of Qwen2.5-VL (a multimodal LLM), specialized for the non-standard position IDs used by the vision-language model.

python
class Qwen2_5_VLRotaryEmbedding(nn.Module):
    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: Qwen2_5_VLTextConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        # In contrast to other models, Qwen2_5_VL has different position ids for the grids
        # So we expand the inv_freq to shape (3, ...)
        inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
        position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
  • 1. Expand inv_freq to match the multi-dimensional position encoding

    python
    inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
    • self.inv_freq: original shape (dim // 2,)
    • After indexing with [None, None, :, None]: (1, 1, dim//2, 1)
    • .expand(3, position_ids.shape[1], -1, 1) expands it to (3, bs, dim//2, 1)
  • 2. Handle position_ids

    python
    position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)
    • After indexing with [:, :, None, :]: (3, bs, 1, seq_len)
  • 3. Compute the frequencies freqs

    python
    device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
    with torch.autocast(device_type=device_type, enabled=False):  # Force float32
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
    • autocast is disabled so the RoPE math runs in float32, avoiding precision loss (especially under bfloat16 training).
    • Matrix multiplication: inv_freq_expanded @ position_ids_expanded
      • (3, bs, dim//2, 1) @ (3, bs, 1, seq_len) broadcasts to (3, bs, dim//2, seq_len)
      • .transpose(2, 3) gives (3, bs, seq_len, dim//2)
    • Each of the 3 coordinate dimensions gets its own RoPE frequencies.
  • 4. Build cos and sin

    python
    emb = torch.cat((freqs, freqs), dim=-1)  # (3, bs, seq_len, dim)
    cos = emb.cos() * self.attention_scaling
    sin = emb.sin() * self.attention_scaling
    • freqs is concatenated with itself to reach the full dimension dim (RoPE assumes paired dimensions: [cos θ, sin θ]).
    • attention_scaling is applied.
    • Final cos and sin shape: (3, bs, seq_len, dim).
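
A shape-only sketch of the broadcast (sizes and the rope theta value are illustrative assumptions):

python
import torch

bs, seq_len, half_dim = 2, 5, 64                    # head_dim = 128 -> 64 frequencies (assumed)
inv_freq = 1.0 / (1000000.0 ** (torch.arange(half_dim).float() / half_dim))
position_ids = torch.arange(seq_len).view(1, 1, -1).expand(3, bs, -1)           # (3, bs, seq_len)

inv_freq_expanded = inv_freq[None, None, :, None].expand(3, bs, -1, 1)          # (3, bs, 64, 1)
position_ids_expanded = position_ids[:, :, None, :].float()                     # (3, bs, 1, seq_len)
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(2, 3)             # (3, bs, seq_len, 64)
emb = torch.cat((freqs, freqs), dim=-1)
print(emb.cos().shape)                              # torch.Size([3, 2, 5, 128])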

Step 3: Decoder Layer (Qwen2_5_VLDecoderLayer)

Qwen2_5_VLDecoderLayer is one decoder layer of Qwen2.5-VL. It implements the standard Transformer decoder structure, adapted for multimodal input (text + image/video) and long-context processing.

python
class Qwen2_5_VLDecoderLayer(GradientCheckpointingLayer):
    def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        if config.use_sliding_window and config._attn_implementation != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )
        self.self_attn = Qwen2_5_VLAttention(config, layer_idx)

        self.mlp = Qwen2MLP(config)
        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.attention_type = config.layer_types[layer_idx]

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_values (`Cache`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            position_embeddings (`tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
                into the model
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
            **kwargs,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        return outputs
  • 1. Attention sublayer (Pre-Norm + Residual)

    python
    residual = hidden_states
    hidden_states = self.input_layernorm(hidden_states)
    
    # Self Attention
    hidden_states, self_attn_weights = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        output_attentions=output_attentions,
        use_cache=use_cache,
        cache_position=cache_position,
        position_embeddings=position_embeddings,
        **kwargs,
    )
    hidden_states = residual + hidden_states
    • LayerNorm first: input_layernorm applies RMS normalization to the input.
    • Multimodal attention call: internally uses M-RoPE for the 3D positions, GQA to reduce KV memory, and sliding windows to limit the attention range.
    • Residual connection: the original input is added back to the attention output.
  • 2. MLP sublayer (Pre-Norm + Residual)

    python
    residual = hidden_states
    hidden_states = self.post_attention_layernorm(hidden_states)
    hidden_states = self.mlp(hidden_states)
    hidden_states = residual + hidden_states
    • Second LayerNorm: post_attention_layernorm normalizes the attention output.
    • MLP structure: Qwen2 uses a SwiGLU activation (stronger than ReLU/GELU).
    • Residual connection: the attention output is added back to the MLP output.
  • 3. Output

    python
    outputs = (hidden_states,)
    
    if output_attentions:
        outputs += (self_attn_weights,)
    
    return outputs

RMSNorm (Root Mean Square Layer Normalization)

The decoder layers reuse the same Qwen2RMSNorm module described above for the vision blocks: scale-only normalization with no mean-centering and a learnable per-channel weight.

Multi-Head Attention Module (Qwen2_5_VLAttention)

Qwen2_5_VLAttention is the multi-head attention module of Qwen2.5-VL, designed for multimodal input (text + image/video) and built around multimodal rotary position embedding (M-RoPE).

python
class Qwen2_5_VLAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".
    """

    def __init__(self, config: Qwen2_5_VLTextConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.is_causal = True
        self.attention_dropout = config.attention_dropout
        self.rope_scaling = config.rope_scaling
        self.scaling = self.head_dim**-0.5

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.sliding_window = config.sliding_window if config.layer_types[layer_idx] == "sliding_attention" else None

        self.rotary_emb = Qwen2_5_VLRotaryEmbedding(config=config)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,  # necessary, but kept here for BC
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_multimodal_rotary_pos_emb(
            query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
        )

        if past_key_values is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

        attention_interface: Callable = eager_attention_forward
        if self.config._attn_implementation != "eager":
            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling,
            sliding_window=self.sliding_window,
            position_ids=position_ids,  # pass positions for FA2
            **kwargs,
        )

        attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

Step 1: Linear projections for Q/K/V

python
query_states = self.q_proj(hidden_states)  # (B, L, num_heads * head_dim)
key_states = self.k_proj(hidden_states)    # (B, L, num_kv_heads * head_dim)
value_states = self.v_proj(hidden_states)

Step 2: Reshape into multi-head format

python
query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)  # (B, H, L, D_h)
key_states = key_states.view(...).transpose(1, 2)   # (B, KV_H, L, D_h)
value_states = ...                                  # (B, KV_H, L, D_h)
  • Standard BHSD layout (Batch, Head, Seq_Len, HeadDim)

Step 3: Apply multimodal RoPE

python
cos, sin = position_embeddings
query_states, key_states = apply_multimodal_rotary_pos_emb(
    query_states, key_states, cos, sin, self.rope_scaling["mrope_section"]
)

apply_multimodal_rotary_pos_emb implements the multimodal rotary position embedding (M-RoPE) that Qwen2.5-VL uses for mixed text + image/video input.

python
def apply_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """Applies Rotary Position Embedding with Multimodal Sections to the query and key tensors (https://qwenlm.github.io/blog/qwen2-vl/).

    Explanation:
        Multimodal 3D rotary position embedding is an extension to 1D rotary position embedding. The input embedding
        sequence contains vision (images / videos) embedding and text embedding or just contains text embedding. For
        vision embedding part, we apply rotary position embedding on temporal, height and width dimension separately.
        Here we split the channel dimension to 3 chunks for the temporal, height and width rotary position embedding.
        For text embedding part, we just apply 1D rotary position embedding. The three rotary position index (temporal,
        height and width) of text embedding is always the same, so the text embedding rotary position embedding has no
        difference with modern LLMs.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`):
            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
            used to pass offsetted position ids when working with a KV-cache.
        mrope_section(`List(int)`):
            Multimodal rope section is for channel dimension of temporal, height and width in rope calculation.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    mrope_section = mrope_section * 2
    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )
    sin = torch.cat([m[i % 3] for i, m in enumerate(sin.split(mrope_section, dim=-1))], dim=-1).unsqueeze(
        unsqueeze_dim
    )

    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
  • 1. Expand mrope_section

    python
    mrope_section = mrope_section * 2  # [T_dim, H_dim, W_dim] → [T_dim, H_dim, W_dim, T_dim, H_dim, W_dim]

    head_dim 切分为 6 段,对应:

    Temporal 前半段、Height 前半段、Width 前半段、Temporal 后半段、Height 后半段、Width 后半段

  • 2. 对 cossin 沿 head_dim 切分并重组

    python 复制代码
    cos = torch.cat([m[i % 3] for i, m in enumerate(cos.split(mrope_section, dim=-1))], dim=-1).unsqueeze(unsqueeze_dim)
    • cos.split(mrope_section, dim=-1): splits cos (shape [3, B, L, D]) along head_dim (D) into 6 chunks of sizes [T_dim, H_dim, W_dim, T_dim, H_dim, W_dim]; each chunk still carries the leading grid dimension of size 3.
    • m[i % 3]: selects which grid each chunk is taken from:
      • chunks 0 and 3 (the channels reserved for time) take the temporal grid's RoPE (first and second halves)
      • chunks 1 and 4 take the height grid's RoPE (first and second halves)
      • chunks 2 and 5 take the width grid's RoPE (first and second halves)
  • 3. Apply the rotation to Query and Key

    python
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
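
A small shape sketch of the split-and-interleave (the value mrope_section = [16, 24, 24] for head_dim = 128 is an assumption used only for illustration):

python
import torch

mrope_section = [16, 24, 24]                   # temporal / height / width channel counts (assumed)
head_dim = 2 * sum(mrope_section)              # 128
cos = torch.randn(3, 1, 10, head_dim)          # (3 grids, bs=1, seq_len=10, head_dim)

chunks = cos.split(mrope_section * 2, dim=-1)  # 6 chunks of sizes [16, 24, 24, 16, 24, 24]
mixed = torch.cat([m[i % 3] for i, m in enumerate(chunks)], dim=-1)  # chunk i taken from grid i % 3
print([c.shape[-1] for c in chunks])           # [16, 24, 24, 16, 24, 24]
print(mixed.shape)                             # torch.Size([1, 10, 128]) -> one cos value per head channel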

Step 4: KV cache update (at inference time)

python
if past_key_values is not None:
    cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
    key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
  • The current K/V are appended to the cache for the next generation step.
  • cos/sin are passed along to support dynamic RoPE (e.g. YaRN updating the position encoding when the cache grows).

Step 5: Select the attention backend

python
attention_interface = eager_attention_forward  # default PyTorch implementation
if self.config._attn_implementation != "eager":
    attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  • Supported backends:
    • "eager": the standard implementation (O(L²) memory)
    • "sdpa": torch.nn.functional.scaled_dot_product_attention (automatically picks FlashAttention / memory-efficient kernels)
    • "flash_attention_2": forces FlashAttention-2 (fastest)
  • Switched via the config field config._attn_implementation.
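
In practice the backend is usually chosen once when loading the model; a sketch (checkpoint name and dtype are illustrative):

python
import torch
from transformers import Qwen2_5_VLForConditionalGeneration

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct",            # example checkpoint
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",  # or "sdpa" / "eager"
)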

Step 6: Call the attention function

python
attn_output, attn_weights = attention_interface(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask,
    dropout=...,
    scaling=self.scaling,
    sliding_window=self.sliding_window,
    position_ids=position_ids,
    **kwargs,
)
python
# "sdpa":torch.nn.functional.scaled_dot_product_attention
attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=attention_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
        **sdpa_kwargs,
    )
  • is_causal (bool): automatically applies a lower-triangular causal mask (each token can only attend to itself and earlier positions).

Step 7: Output projection

python
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()  # (B, L, H * D_h)
attn_output = self.o_proj(attn_output)  # (B, L, D)
  • Merges the head dimension back together and projects to hidden_size.

Multi-Layer Perceptron (MLP) Module

The decoder's MLP (Qwen2MLP) follows the same gated, SwiGLU-style structure as the Qwen2_5_VLMLP shown earlier for the vision encoder: down_proj(act_fn(gate_proj(x)) * up_proj(x)), with sizes taken from the text config.