实习07-混合大模型的学习

1 看架构(Configuration 配置文件)

首先,拿到代码,看配置文件里面的架构选型,其中 configuration.py 文件里面包含了模型每层的选型,以下是 layer 构建的代码:

python 复制代码
    @property
    def layers_block_type(self):
        """
        Returns a list of block types for each layer.
        
        Block types:
        - 'emb_adapter' : Embedding Adapter
        - 'qwen_attention': Qwen3 Attention with RoPE
        - 'qwen_mlp': Qwen3 MLP with gate/up/down structure
        - 'nq_adapter' : Nemotron + Qwen3 Adapter
        - 'mamba': Nemotron Mamba2 layer
        - 'mlp': Nemotron MLP
        - 'attention': Nemotron Attention
        
        """
        block_types = []
        for char in self.hybrid_override_pattern:
            if char == '&':
                block_types.append('qwen_attention')
            elif char == '^':
                block_types.append('qwen_mlp')
            elif char == 'M':
                block_types.append('mamba')
            elif char == '-':
                block_types.append('mlp')
            elif char == '*':
                block_types.append('attention')
            elif char == '!':
                block_types.append('nq_adapter')
            elif char == '#':
                block_types.append('emb_adapter')
            else:
                raise ValueError(f"Unknown pattern character: {char}")
        return block_types
  • 输入: 模式字符串;
  • 输出: 返回的是一个 layer 集合;

然后我们再看 hybrid_override_pattern 的内容:

python 复制代码
hybrid_override_pattern="# &^ &^ &^ &^ &^ &^ &^ &^ &^ &^ ! M-M-M-M*-M-M-M-M*-M-",  # Default 42-layer pattern
  • # = emb_adapter(嵌入适配器)
  • & = qwen_attention(Qwen 注意力)
  • ^ = qwen_mlp(Qwen MLP)
  • ! = nq_adapter(中间适配器)
  • M / - / * = Nemotron 相关层(Mamba、MLP、Attention)

2 看逻辑(QwenNemotronModel 的 forward)

我们只看关键部分:

第一部分:默认配置 & 边界条件检查(不看)

python 复制代码
 #========== 1. 参数默认值配置 ==========
 # 是否输出注意力权重:使用传入值 or 配置文件默认值
 output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
 # 是否输出所有隐藏层状态
 output_hidden_states = (
     output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
 )
 # 是否使用缓存:推理开启、训练关闭(关键!训练时use_cache=False)
 use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
 # 返回格式:字典 or 元组
 return_dict = return_dict if return_dict is not None else self.config.use_return_dict

 # ========== 2. 输入合法性检查 ==========
 # XOR 判断:必须只提供 input_ids 或 inputs_embeds 其中一个
 if (input_ids is None) ^ (inputs_embeds is not None):
     raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

第二部分:词嵌入

根据 Tokenizer 分词 chunk 后得到的 input_ids ([batch_size, max_length]),传给 Embedding 层(Shape 为 [vocab_size, hidden_size]),得到 [batch_size, max_length, hidden_size]:

python 复制代码
 # ========== 3. 词嵌入层 ==========
 # 如果没给嵌入向量,就用 token ID 查嵌入表得到
 if inputs_embeds is None:
     inputs_embeds = self.embed_tokens(input_ids)

第三部分:开启缓存之类的(不看)

python 复制代码
 # ========== 4. 训练冲突处理:梯度检查点 ≠ 缓存 ==========
 # 梯度检查点(省显存)和 KV 缓存(加速推理)互斥,训练时强制关闭缓存
 if self.gradient_checkpointing and self.training and use_cache:
     logger.warning_once(
         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
     )
     use_cache = False

 # ========== 5. 缓存初始化检查 ==========
 # 想开启缓存但没传缓存对象 → 警告
 if use_cache and cache_params is None:
     logger.warning_once(
         "QwenNemotronModel requires an initialized `HybridQwenNemotronDynamicCache` to return a cache. "
         "None was provided, so no cache will be returned."
     )

 # 初始隐藏状态 = 词嵌入向量
 hidden_states = inputs_embeds

 # ========== 6. 位置与缓存索引 ==========
 # 缓存位置:默认从 0 到 序列长度-1
 if cache_position is None:
     cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
 # 位置ID = 缓存位置(RoPE用)
 if position_ids is None:
     position_ids = cache_position.unsqueeze(0)

第四部分:生成位置编码

python 复制代码
# ========== 7. 生成 RoPE 旋转位置编码(给Qwen注意力用) ==========
position_embeddings = self.rotary_emb(hidden_states, position_ids)
  • 根据 Embedding 的输出作为输入,算出 cos 和 sin 旋转矩阵,所以 position_embeddings 应该是一个形似 (cos, sin) 的元组;

第五部分:创建 Full Attention 和 Mamba 的掩码矩阵

这是第一点和传统 Decoder 不一样的地方:

  • 前者是用于 Full Attention 的位置编码,是一个矩阵,所以需要输入为 inputs_embedsmax_length 构建矩阵,目的是通过下三角进行掩码,输出为一个 N x N 的矩阵;
  • 后者是一个单行向量,只 padding 掩码 (没有因果掩码!),因为 Mamba 天生就是串行、单向、看不到未来的!速度极快。
python 复制代码
 # ========== 8. 创建两种掩码:分别给 Attention 和 Mamba 使用 ==========
 causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)  # 因果掩码
 mamba_mask = self._update_mamba_mask(attention_mask, cache_position)                   # Mamba专用掩码

第六部分:逐层循环 Loop

python 复制代码
 for layer_idx, layer in enumerate(self.layers):
     # 如果需要输出隐藏态,先把上一层结果存起来
     if output_hidden_states:
         all_hidden_states = all_hidden_states + (hidden_states,)
  • 存储上一层的隐藏状态;

第七部分:"梯度检查点" 换取 "低显存占用"

模型在前向传播的时候,会有很多中间结果,如果都存储起来会导致显存爆炸;所以我们会采用 "懒加载" 的方式,不存中间结果,当在反向传播的时候,再重新算一遍;(时间换空间)

python 复制代码
if isinstance(layer, QwenDecoderLayer):
   # 训练 + 梯度检查点:使用torch激活检查点省显存
   if self.gradient_checkpointing and self.training:
       hidden_states = self._gradient_checkpointing_func(
           layer.__call__,
           hidden_states,
           causal_mask,
           position_embeddings,
           cache_params,
           cache_position,
       )
   else:
       # 正常前向:传入 隐藏态、因果掩码、RoPE、缓存
       hidden_states = layer(
           hidden_states,
           attention_mask=causal_mask,
           position_embeddings=position_embeddings,
           cache_params=cache_params,
           cache_position=cache_position,
       )

然后我们看看 QwenDecoderLayer 是怎么做的:

python 复制代码
def forward(
    self,
    hidden_states: torch.Tensor,        # 输入:上一层的输出向量
    attention_mask: Optional[torch.Tensor] = None,  # 因果掩码
    position_ids: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # RoPE
    cache_params: Optional[HybridQwenNemotronDynamicCache] = None,  # KV缓存
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> torch.Tensor:

    # ===================== 第一部分:自注意力模块 =====================
    residual = hidden_states  # 保存残差(shortcut)
    hidden_states = self.input_layernorm(hidden_states)  # 先归一化(Pre-LN)
    
    # 进入 Qwen3 自注意力(带RoPE、QK归一化、KV缓存)
    hidden_states, _ = self.self_attn(
        hidden_states=hidden_states,
        attention_mask=attention_mask,
        position_embeddings=position_embeddings,
        cache_params=cache_params,
        cache_position=cache_position,
        **kwargs,
    )
    
    hidden_states = residual + hidden_states  # 残差连接:注意力输出 + 原始输入


    # ===================== 第二部分:MLP 前馈网络 =====================
    residual = hidden_states  # 再次保存残差
    hidden_states = self.post_attention_layernorm(hidden_states)  # 第二次归一化
    
    # MLP 两层线性变换 + 激活
    hidden_states = self.mlp(hidden_states)
    
    hidden_states = residual + hidden_states  # 第二次残差连接

    return hidden_states  # 返回本层最终输出
  • 输入做一次归一化,局部的输出也会做一次归一化;
  • 局部输出必残差一次;
txt 复制代码
输入
  │
  ├─→ 残差1
  │
  └─→ LayerNorm → Self-Attention → + ←─ 残差1
          │
          ├─→ 残差2
          │
          └─→ LayerNorm → MLP → + ←─ 残差2
                           │
                         输出

第八部分:MiddleAdapter 中间层适配器
MiddleAdapter 就做了一个非线性变换,别无其它。

python 复制代码
class MiddleAdapter(nn.Module):
    def __init__(self, config, layer_idx):
        self.pre_norm = RMSNorm(...)
        self.linear1 = Linear(...)
        self.act1 = SiLU()
        self.linear2 = Linear(...)

    def forward(self, x):
        residual = x
        x = self.pre_norm(x)
        x = self.linear1(x)
        x = self.act1(x)
        x = self.linear2(x)
        return x + residual
  • 激活函数:选用的是 swish 函数,我的理解是增强 QwenDecoder 输出的特征表示;
  • 为什么用 swish ? 因为 ReLU 函数不支持负值,导致信息丢失;Sigmoid 容易导致梯度爆炸;GELU 计算比较慢。而 Swish 平滑、快、非单调(具有选择性抑制的作用),支持负数;

激活函数复习:
https://blog.csdn.net/weixin_57128596/article/details/157550165?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522f1c56c3dfac289d37914aecf7dc1ff61%2522%252C%2522scm%2522%253A%252220140713.130102334.pc%255Fblog.%2522%257D&request_id=f1c56c3dfac289d37914aecf7dc1ff61&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2blogfirst_rank_ecpm_v1~rank_v31_ecpm-1-157550165-null-null.nonecase&utm_term=%E6%BF%80%E6%B4%BB%E5%87%BD%E6%95%B0&spm=1018.2226.3001.4450

第九部分:Nemotron 模块

QwenDecoder 同级,在最顶层接口 QwenNemotronModel_build_layers() 里面(初始化的时候就做好了)

python 复制代码
# 最顶层接口
class QwenNemotronModel(QwenNemotronPreTrainedModel):
    """
    The bare Qwen3-Nemotron hybrid model outputting raw hidden-states without any specific head on top.
    """
    
    def __init__(self, config: QwenNemotronConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        
        # Embedding layer (Qwen3 style)
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        
        # Build layers based on hybrid_override_pattern
        self.layers = nn.ModuleList()
        # 构建模型结构
        self._build_layers(config)
        ......
        ......

然后,我们看 Nemotron 模块前向传播的逻辑:

python 复制代码
......
......
# ============== 分支C:Nemotron 模块(Mamba/Attention/MLP) ==============
else:
    block_type = layer.block_type  # 获取当前Nemotron模块类型
    # 不同模块用不同掩码
    if block_type == "mamba":
        layer_mask = mamba_mask       # Mamba用自己的掩码
    elif block_type == "attention":
        layer_mask = causal_mask      # Attention用标准因果掩码
    elif block_type == "mlp":
        layer_mask = None             # MLP不需要掩码
    else:
        raise ValueError(f"Invalid block_type: {self.block_type}")
    
    # 梯度检查点 or 正常前向
    if self.gradient_checkpointing and self.training:
        hidden_states = self._gradient_checkpointing_func(
            layer.__call__,
            hidden_states,
            cache_params,
            cache_position,
            layer_mask,
        )
    else:
        hidden_states = layer(
            hidden_states,
            cache_params=cache_params,
            cache_position=cache_position,
            attention_mask=layer_mask,
        )

第九部分(补充):Nemotron 模块前向传播逻辑

核心代码如下所示,很明显能看到,是一个 mambaattentionmlp 三选一的结构;数据的走向还是比较传统的,先归一化一下,然后非线性弄一下,最后接一个残差;

python 复制代码
def forward(
    self,
    hidden_states: torch.Tensor,
    cache_params: Optional[HybridQwenNemotronDynamicCache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # with torch.cuda.stream(torch.cuda.default_stream(hidden_states.device)):
    # * Use torch.cuda.stream() to avoid NaN issues when using multiple GPUs
  
    residual = hidden_states
    # 归一化
    hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
    if self.residual_in_fp32:
        residual = residual.to(torch.float32)
        
		# 非线性变换:三选一
    if self.block_type == "mamba":
        hidden_states = self.mixer(
            hidden_states, cache_params=cache_params, cache_position=cache_position
        )
    elif self.block_type == "attention":
        hidden_states = self.mixer(
            hidden_states, cache_position=cache_position,
        )
        hidden_states = hidden_states[0]
    elif self.block_type == "mlp":
        hidden_states = self.mixer(hidden_states)
    else:
            raise ValueError(f"Invalid block_type: {self.block_type}")
    
    # 残差
    hidden_states = residual + hidden_states
    return hidden_states

第十部分:Namotron(Mamaba) 模块在 forward 的核心流程:

python 复制代码
def cuda_kernels_forward(
    self,
    hidden_states: torch.Tensor,            # 输入特征 (batch, seq_len, hidden_size)
    cache_params: Optional[HybridQwenNemotronDynamicCache] = None,  # 推理缓存(conv + ssm)
    cache_position: Optional[torch.LongTensor] = None,              # 当前在序列的位置
    attention_mask: Optional[torch.Tensor] = None,                  # padding 掩码(Mamba 只用这个)
):
    # ==============================================
    # 1. 预处理:把 padding 位置的特征抹零 + 输入投影
    # ==============================================
    hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
    projected_states = self.in_proj(hidden_states)  # 【😊第一部分:核心投影】把输入映射成 Mamba 需要的 5 部分

    batch_size, seq_len, _ = hidden_states.shape
    groups_time_state_size = self.n_groups * self.ssm_state_size

    # 计算一下投影后的各部分维度(不用深究,就是切分用)
    d_mlp = (
        projected_states.shape[-1]
        - 2 * self.intermediate_size
        - 2 * self.n_groups * self.ssm_state_size
        - self.num_heads
    ) // 2

    # ==============================================
    # 分支 A:推理阶段(增量生成,一次只来一个 token)
    # ==============================================
    if cache_params is not None and cache_position is not None and cache_position[0] > 0:
        # 【😊第一部分:核心投影】把投影结果切分成 5 份:Mamba 核心五件套
        _, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
            [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
        )
        
        # --------------------
        # 2. 因果卷积(局部上下文信息,左边几个Token信息糅合在一起)【😊第二部分:因果卷积,得到局部特征】
        # --------------------
        hidden_states_B_C = causal_conv1d_update(
            hidden_states_B_C,
            cache_params.conv_states[self.layer_idx],  # 读取缓存
            self.conv1d.weight.squeeze(1),
            self.conv1d.bias,
            self.activation,
        )
        # 卷积输出的内容:【 局部信息特征 x | SSM参数 B(当前输入,对状态有多强的影响) | SSM参数 C(当前状态,如何输出成结果) ]
        
        # 再切:得到 真正的输入x、B、C(SSM 三大参数)
        hidden_states, B, C = torch.split(
            hidden_states_B_C,
            [self.intermediate_size, groups_time_state_size, groups_time_state_size],
            dim=-1,
        )
        # B: 要把多少信息写入记忆?
        # C: 怎么从记忆里提取信息,输出当前结果
        
        # --------------------
        # 3. SSM 核心(Mamba 灵魂)【😊第三部分:抓全局依赖】
        # --------------------
        A = -torch.exp(self.A_log.float())  # A 是对数初始化,转成实数(记忆遗忘率)
        A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
        dt = dt[:, :, None].expand(-1, -1, self.head_dim) # 时间步
        dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
        D = self.D[:, None, ...].expand(-1, self.head_dim)
        
        B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
        C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
        hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
        
        # 【核心】状态空间更新:从历史状态生成当前输出
        hidden_states = selective_state_update(
            cache_params.ssm_states[self.layer_idx],
            hidden_states_reshaped,
            dt, A, B, C, D,
            z=None,
            dt_bias=dt_bias,
            dt_softplus=True,
        )

        # 形状恢复 + 门控归一化
        hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
        hidden_states = self.norm(hidden_states, gate)

        # 4. 输出投影,变回模型维度
        out = self.out_proj(hidden_states)[:, None, ...]

    # ==============================================
    # 分支 B:训练阶段 / 首次推理(整段序列输入)
    # ==============================================
    else:
        A = -torch.exp(self.A_log.float())
        dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
        
        # --------------------
        # 训练:直接调用 CUDA 融合核函数(超快)
        # --------------------
        if self.training and cache_params is None:
            out = mamba_split_conv1d_scan_combined(
                projected_states,
                self.conv1d.weight.squeeze(1),
                self.conv1d.bias,
                self.dt_bias,
                A,
                D=self.D,
                chunk_size=self.chunk_size,
                activation=self.activation,
                rmsnorm_weight=self.norm.weight,
                rmsnorm_eps=self.norm.variance_epsilon,
                outproj_weight=self.out_proj.weight,
                outproj_bias=self.out_proj.bias,
                headdim=self.head_dim,
                ngroups=self.n_groups,
                norm_before_gate=False,
                return_final_states=False,
                **dt_limit_kwargs,
            )
        
        # --------------------
        # 非训练(推理初始化)
        # --------------------
        else:
            # 切分投影
            _, _, gate, hidden_states_B_C, dt = projected_states.split(
                [d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
            )

            # 2. 因果卷积
            if cache_params is not None:
                hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
                conv_states = nn.functional.pad(
                    hidden_states_B_C_transposed,
                    (cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
                )
                cache_params.update_conv_state(
                    layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True
                )
            
            # 调用 CUDA 卷积核
            if self.activation not in ["silu", "swish"]:
                hidden_states_B_C = self.act(
                    self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
                )
            else:
                hidden_states_B_C = causal_conv1d_fn(
                    x=hidden_states_B_C.transpose(1, 2),
                    weight=self.conv1d.weight.squeeze(1),
                    bias=self.conv1d.bias,
                    activation=self.activation,
                ).transpose(1, 2)
            
            hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
            
            # 切分出 x, B, C
            hidden_states, B, C = torch.split(
                hidden_states_B_C,
                [self.intermediate_size, groups_time_state_size, groups_time_state_size],
                dim=-1,
            )

            # 3. SSM 块扫描(整段序列扫一遍)
            scan_output, ssm_state = mamba_chunk_scan_combined(
                hidden_states.view(batch_size, seq_len, -1, self.head_dim),
                dt,
                A,
                B.view(batch_size, seq_len, self.n_groups, -1),
                C.view(batch_size, seq_len, self.n_groups, -1),
                chunk_size=self.chunk_size,
                D=self.D,
                z=None,
                dt_bias=self.dt_bias,
                dt_softplus=True,
                **dt_limit_kwargs,
            )
            
            # 保存 SSM 状态到缓存
            if ssm_state is not None and cache_params is not None:
                cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)
            
            scan_output = scan_output.view(batch_size, seq_len, -1)

            # 门控归一化
            scan_output = self.norm(scan_output, gate)

            # 4. 最终输出
            out = self.out_proj(scan_output)
    
    return out

第十一部分:Namotron 模块中 Attention 的实现:

python 复制代码
class NemotronFlashAttention2(NemotronAttention):
    """
    FlashAttention2 加速版的 Nemotron 注意力
    继承自普通 NemotronAttention,权重不变,只重写 forward 实现
    作用:GPU 上速度更快、显存占用更低
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # FlashAttention 版本兼容处理
        # 2.1 以下版本的因果掩码是 top-left 对齐,2.1+ 改成了 bottom-right
        # 这里标记一下版本,避免掩码错误
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,       # 输入特征 [batch, seq_len, hidden_size]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_params: Optional[HybridQwenNemotronDynamicCache] = None,  # KV缓存
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        # 输入形状:batch_size, sequence_length, hidden_size
        bsz, q_len, _ = hidden_states.size()

        # ========== 1. QKV 线性投影 ==========
        # 把输入分别投影成 Q、K、V
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # ========== 2. 形状拆分(给多头注意力用) ==========
        # FlashAttention 要求形状:[batch, seq_len, n_heads, head_dim]
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # ========== 3. KV 缓存更新(推理加速) ==========
        if cache_params is not None:
            key_states, value_states = cache_params.update(key_states, value_states, self.layer_idx)

        # ========== 4. GQA 分组注意力:K/V 头重复 ==========
        # 多个 Q 头共享一组 K/V 头,需要把 K/V 重复到和 Q 一样多
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # dropout 概率:推理时=0,训练时=配置值
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # ========== 5. 精度处理(非常重要!GPU 必须半精度) ==========
        # 如果输入是 float32,强制转回 float16/bfloat16,否则 FlashAttention 会报错/变慢
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # 自动获取正确的 GPU 精度
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            # 警告 + 强制转换精度
            logger.warning_once(
                f"输入是 float32,FlashAttention2 需要半精度,已强制转为 {target_dtype}."
            )
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # ========== 6. 最终形状转换,适配 FlashAttention ==========
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # ========== 7. 调用 FlashAttention2 核心计算!==========
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            sliding_window=getattr(self.config, "sliding_window", None),
            is_causal=self.is_causal,           # 因果注意力(看不到未来)
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        # ========== 8. 多头结果拼接 ==========
        # [batch, seq_len, n_heads, head_dim] → [batch, seq_len, hidden_size]
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()

        # ========== 9. 输出投影 ==========
        attn_output = self.o_proj(attn_output)

        # 不需要输出注意力权重,就设为 None(省显存)
        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, cache_params

第十二部分:Namotron 模块中 MLP 的实现:

python 复制代码
class NemotronMLP(nn.Module):
    """Nemotron-style MLP with relu2 activation."""
    
    def __init__(self, config: QwenNemotronConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size 
        self.intermediate_size = config.intermediate_size
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        self.act_fn = ACT2FN[config.mlp_hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act_fn(self.up_proj(x)))
相关推荐
华清远见IT开放实验室2 小时前
AI 算法核心知识清单(深度实战版1)
人工智能·python·深度学习·学习·算法·机器学习·ai
_李小白2 小时前
【OSG学习笔记】Day 40: EventCallback(事件回调)
笔记·学习
世人万千丶2 小时前
开源鸿蒙跨平台Flutter开发:步数统计应用
学习·flutter·华为·开源·harmonyos·鸿蒙
爱宇阳2 小时前
Supabase Self-Hosting with Docker 学习笔记
笔记·学习·docker
盟接之桥2 小时前
盟接之桥®说制造:从“制造”到“智造”,以品类品牌重塑制造业的生态未来
大数据·网络·人工智能·学习·制造
迷你可可小生2 小时前
图像视觉面经学习(一)
图像处理·人工智能·python·学习
自信150413057592 小时前
重生之从0开始学习c++之类与对象(中)
c++·学习
AI_零食2 小时前
开源鸿蒙跨平台Flutter开发:快递单号批量查询应用
学习·flutter·华为·开源·harmonyos·鸿蒙
四谎真好看2 小时前
Redis学习笔记(高级篇2)
redis·笔记·学习·学习笔记