1 看架构(Configuration 配置文件)
首先,拿到代码,看配置文件里面的架构选型,其中 configuration.py 文件里面包含了模型每层的选型,以下是 layer 构建的代码:
python
@property
def layers_block_type(self):
    """Map ``hybrid_override_pattern`` to a per-layer block-type list.

    Pattern legend:
        '#' -> 'emb_adapter'     (embedding adapter)
        '&' -> 'qwen_attention'  (Qwen3 attention with RoPE)
        '^' -> 'qwen_mlp'        (Qwen3 gate/up/down MLP)
        '!' -> 'nq_adapter'      (Nemotron + Qwen3 adapter)
        'M' -> 'mamba'           (Nemotron Mamba2 layer)
        '-' -> 'mlp'             (Nemotron MLP)
        '*' -> 'attention'       (Nemotron attention)

    Raises:
        ValueError: on the first character outside the legend.
    """
    # Table-driven dispatch instead of an if/elif ladder.
    char_to_block = {
        '&': 'qwen_attention',
        '^': 'qwen_mlp',
        'M': 'mamba',
        '-': 'mlp',
        '*': 'attention',
        '!': 'nq_adapter',
        '#': 'emb_adapter',
    }
    block_types = []
    for char in self.hybrid_override_pattern:
        if char not in char_to_block:
            raise ValueError(f"Unknown pattern character: {char}")
        block_types.append(char_to_block[char])
    return block_types
- 输入:模式字符串 hybrid_override_pattern;
- 输出:一个与层数等长的列表,每个元素是对应层的 block 类型字符串;
然后我们再看 hybrid_override_pattern 的内容:
python
hybrid_override_pattern="# &^ &^ &^ &^ &^ &^ &^ &^ &^ &^ ! M-M-M-M*-M-M-M-M*-M-", # Default 42-layer pattern
# 符号含义:`#` = emb_adapter(嵌入适配器);`&` = qwen_attention(Qwen 注意力);`^` = qwen_mlp(Qwen MLP);`!` = nq_adapter(中间适配器);`M` / `-` / `*` = Nemotron 相关层(分别为 Mamba、MLP、Attention)。(注意:按上面 layers_block_type 的解析逻辑,模式串中的空格会触发 ValueError——原文展示时可能为了可读性插入了空格,实际配置中应无空格。)
2 看逻辑(QwenNemotronModel 的 forward)
我们只看关键部分:
第一部分:默认配置 & 边界条件检查(不看)
python
# ========== 1. Resolve argument defaults ==========
# Output attention weights: an explicit argument wins over the config default.
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
# Output every layer's hidden states?
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
# KV cache: enabled for inference, forced off while training (use_cache=False when self.training).
use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
# Return format: ModelOutput dict vs. plain tuple.
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# ========== 2. Input validation ==========
# XOR check: exactly one of input_ids / inputs_embeds must be provided.
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
第二部分:词嵌入
根据 Tokenizer 分词 chunk 后得到的 input_ids ([batch_size, max_length]),传给 Embedding 层(Shape 为 [vocab_size, hidden_size]),得到 [batch_size, max_length, hidden_size]:
python
# ========== 3. Token embedding ==========
# If no embedding tensor was supplied, look the token IDs up in the embedding table.
if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)
第三部分:开启缓存之类的(不看)
python
# ========== 4. Training conflict: gradient checkpointing vs. cache ==========
# Gradient checkpointing (saves memory) and the KV cache (speeds up inference)
# are mutually exclusive, so caching is forced off during training.
if self.gradient_checkpointing and self.training and use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
)
use_cache = False
# ========== 5. Cache initialization check ==========
# Caching was requested but no cache object was passed in -> warn once.
if use_cache and cache_params is None:
logger.warning_once(
"QwenNemotronModel requires an initialized `HybridQwenNemotronDynamicCache` to return a cache. "
"None was provided, so no cache will be returned."
)
# Initial hidden state = the token embeddings.
hidden_states = inputs_embeds
# ========== 6. Position / cache indices ==========
# Cache positions: default to 0 .. seq_len-1.
if cache_position is None:
cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
# position_ids = cache positions (consumed by RoPE).
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
第四部分:生成位置编码
python
# ========== 7. Rotary position embeddings (RoPE) for the Qwen attention layers ==========
# Presumably returns a (cos, sin) pair derived from position_ids — matches the note below; confirm against rotary_emb.
position_embeddings = self.rotary_emb(hidden_states, position_ids)
- 以 Embedding 的输出作为输入,算出 cos 和 sin 旋转矩阵,所以 position_embeddings 应该是一个形如 (cos, sin) 的元组;
第五部分:创建 Full Attention 和 Mamba 的掩码矩阵
这是第一点和传统 Decoder 不一样的地方:
- 前者是用于 Full Attention 的因果掩码,是一个矩阵:以 inputs_embeds 的 max_length 构建,通过下三角进行掩码,输出为一个 N x N 的矩阵;
- 后者是一个单行向量,只做 padding 掩码(没有因果掩码!),因为 Mamba 天生就是串行、单向、看不到未来的,速度极快。
python
# ========== 8. Build the two masks: one for attention, one for Mamba ==========
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position) # causal (look-back only) mask
mamba_mask = self._update_mamba_mask(attention_mask, cache_position) # padding-only mask for Mamba
第六部分:逐层循环 Loop
python
for layer_idx, layer in enumerate(self.layers):
# If all hidden states were requested, stash the previous layer's output first.
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
- 存储上一层的隐藏状态;
第七部分:"梯度检查点" 换取 "低显存占用"
模型在前向传播的时候,会有很多中间结果,如果都存储起来会导致显存爆炸;所以我们会采用 "懒加载" 的方式,不存中间结果,当在反向传播的时候,再重新算一遍;(时间换空间)
python
if isinstance(layer, QwenDecoderLayer):
# Training with gradient checkpointing: recompute activations in backward to save memory.
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
causal_mask,
position_embeddings,
cache_params,
cache_position,
)
else:
# Plain forward: hidden states, causal mask, RoPE, cache.
hidden_states = layer(
hidden_states,
attention_mask=causal_mask,
position_embeddings=position_embeddings,
cache_params=cache_params,
cache_position=cache_position,
)
然后我们看看 QwenDecoderLayer 是怎么做的:
python
def forward(
    self,
    hidden_states: torch.Tensor,  # output of the previous layer
    attention_mask: Optional[torch.Tensor] = None,  # causal mask
    position_ids: Optional[torch.LongTensor] = None,
    position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # RoPE (cos, sin)
    cache_params: Optional[HybridQwenNemotronDynamicCache] = None,  # KV cache
    use_cache: Optional[bool] = False,
    cache_position: Optional[torch.LongTensor] = None,
    **kwargs,
) -> torch.Tensor:
    """Pre-LN transformer block: self-attention then MLP, each wrapped in a residual.

    Returns the layer's output hidden states; shape matches the input.
    """
    # ----- Sub-block 1: self-attention (Pre-LN) -----
    shortcut = hidden_states
    normed = self.input_layernorm(hidden_states)
    # Qwen3 self-attention (RoPE, QK-norm, KV cache handled inside).
    attn_out, _ = self.self_attn(
        hidden_states=normed,
        attention_mask=attention_mask,
        position_embeddings=position_embeddings,
        cache_params=cache_params,
        cache_position=cache_position,
        **kwargs,
    )
    hidden_states = shortcut + attn_out  # residual connection #1

    # ----- Sub-block 2: feed-forward MLP (Pre-LN) -----
    shortcut = hidden_states
    normed = self.post_attention_layernorm(hidden_states)
    hidden_states = shortcut + self.mlp(normed)  # residual connection #2
    return hidden_states
- 输入做一次归一化,局部的输出也会做一次归一化;
- 局部输出必残差一次;
txt
输入
│
├─→ 残差1
│
└─→ LayerNorm → Self-Attention → + ←─ 残差1
│
├─→ 残差2
│
└─→ LayerNorm → MLP → + ←─ 残差2
│
输出
第八部分:MiddleAdapter 中间层适配器
MiddleAdapter 就做了一个非线性变换,别无其它。
python
class MiddleAdapter(nn.Module):
    """Residual adapter between the Qwen stack and the Nemotron stack.

    Applies pre-norm -> linear -> SiLU -> linear, then adds the input back
    (a plain bottleneck adapter with a residual connection).

    NOTE(review): the `...` constructor arguments are schematic placeholders
    from the original write-up; the real dimensions come from `config`.
    """

    def __init__(self, config, layer_idx):
        # Bug fix: nn.Module.__init__ must run before any submodule is
        # assigned, otherwise `self.pre_norm = ...` raises
        # "cannot assign module before Module.__init__() call".
        super().__init__()
        self.pre_norm = RMSNorm(...)
        self.linear1 = Linear(...)
        self.act1 = SiLU()
        self.linear2 = Linear(...)

    def forward(self, x):
        residual = x
        x = self.pre_norm(x)
        x = self.linear1(x)
        x = self.act1(x)
        x = self.linear2(x)
        return x + residual
- 激活函数:选用的是 Swish(SiLU)函数,我的理解是增强 QwenDecoder 输出的特征表示;
- 为什么用 Swish?因为 ReLU 对负值直接置零,导致信息丢失;Sigmoid 容易导致梯度消失;GELU 计算比较慢。而 Swish 平滑、快、非单调(具有选择性抑制的作用),支持负数;
第九部分:Nemotron 模块
跟 QwenDecoder 同级,在最顶层接口 QwenNemotronModel 的 _build_layers() 里面(初始化的时候就做好了)
python
# 最顶层接口
class QwenNemotronModel(QwenNemotronPreTrainedModel):
"""
The bare Qwen3-Nemotron hybrid model outputting raw hidden-states without any specific head on top.
"""
def __init__(self, config: QwenNemotronConfig):
"""Build the hybrid model skeleton: the embedding table plus the pattern-driven layer stack."""
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
# Token embedding table (Qwen3 style); the pad row is nn.Embedding's padding_idx.
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
# Layer container; actual layers are materialized from hybrid_override_pattern.
self.layers = nn.ModuleList()
# Build the per-layer modules (implementation of _build_layers not shown here).
self._build_layers(config)
......
......
然后,我们看 Nemotron 模块前向传播的逻辑:
python
......
......
# ============== 分支C:Nemotron 模块(Mamba/Attention/MLP) ==============
else:
block_type = layer.block_type # 获取当前Nemotron模块类型
# 不同模块用不同掩码
if block_type == "mamba":
layer_mask = mamba_mask # Mamba用自己的掩码
elif block_type == "attention":
layer_mask = causal_mask # Attention用标准因果掩码
elif block_type == "mlp":
layer_mask = None # MLP不需要掩码
else:
raise ValueError(f"Invalid block_type: {self.block_type}")
# 梯度检查点 or 正常前向
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
layer.__call__,
hidden_states,
cache_params,
cache_position,
layer_mask,
)
else:
hidden_states = layer(
hidden_states,
cache_params=cache_params,
cache_position=cache_position,
attention_mask=layer_mask,
)
第九部分(补充):Nemotron 模块前向传播逻辑
核心代码如下所示,很明显能看到,是一个 mamba 、attention、mlp 三选一的结构;数据的走向还是比较传统的,先归一化一下,然后非线性弄一下,最后接一个残差;
python
def forward(
    self,
    hidden_states: torch.Tensor,
    cache_params: Optional[HybridQwenNemotronDynamicCache] = None,
    cache_position: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Nemotron block: pre-norm, one mixer (mamba | attention | mlp), residual add."""
    # (The original kept a commented-out torch.cuda.stream() guard here,
    # reportedly to avoid NaN issues on multi-GPU runs.)
    skip = hidden_states

    # Normalize in the norm weight's dtype.
    hidden_states = self.norm(hidden_states.to(dtype=self.norm.weight.dtype))
    if self.residual_in_fp32:
        skip = skip.to(torch.float32)

    # Dispatch to this block's mixer.
    kind = self.block_type
    if kind == "mamba":
        hidden_states = self.mixer(
            hidden_states, cache_params=cache_params, cache_position=cache_position
        )
    elif kind == "attention":
        attn_result = self.mixer(
            hidden_states, cache_position=cache_position,
        )
        hidden_states = attn_result[0]  # attention returns a tuple; keep the output only
    elif kind == "mlp":
        hidden_states = self.mixer(hidden_states)
    else:
        raise ValueError(f"Invalid block_type: {self.block_type}")

    # Residual connection.
    return skip + hidden_states
第十部分:Nemotron(Mamba)模块在 forward 的核心流程:
python
def cuda_kernels_forward(
self,
hidden_states: torch.Tensor, # input features; original notes say (batch, seq_len, hidden_size)
cache_params: Optional[HybridQwenNemotronDynamicCache] = None, # inference cache (conv + ssm states)
cache_position: Optional[torch.LongTensor] = None, # current position(s) within the sequence
attention_mask: Optional[torch.Tensor] = None, # padding mask (the only mask Mamba uses)
):
# Mamba2 mixer forward pass using fused CUDA kernels.
# ==============================================
# 1. Preprocess: zero out padded positions, then project the input
# ==============================================
hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask)
projected_states = self.in_proj(hidden_states) # [Part 1: core projection] map input into the 5 pieces Mamba needs
batch_size, seq_len, _ = hidden_states.shape
groups_time_state_size = self.n_groups * self.ssm_state_size
# Width of the first two projected chunks (only used for splitting below)
d_mlp = (
projected_states.shape[-1]
- 2 * self.intermediate_size
- 2 * self.n_groups * self.ssm_state_size
- self.num_heads
) // 2
# ==============================================
# Branch A: decoding step (incremental generation, one token at a time)
# ==============================================
if cache_params is not None and cache_position is not None and cache_position[0] > 0:
# [Part 1: core projection] split the projection into Mamba's five pieces
_, _, gate, hidden_states_B_C, dt = projected_states.squeeze(1).split(
[d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
)
# --------------------
# 2. Causal convolution (mixes in a few tokens of local left context) [Part 2: local features]
# --------------------
hidden_states_B_C = causal_conv1d_update(
hidden_states_B_C,
cache_params.conv_states[self.layer_idx], # read cached conv state
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.activation,
)
# Conv output layout: [ local features x | SSM param B (how strongly input writes state) | SSM param C (how state is read out) ]
# Split again into the true input x plus B and C (the SSM parameters)
hidden_states, B, C = torch.split(
hidden_states_B_C,
[self.intermediate_size, groups_time_state_size, groups_time_state_size],
dim=-1,
)
# B: how much information gets written into memory
# C: how to read the memory out into the current result
# --------------------
# 3. SSM core (the heart of Mamba) [Part 3: global dependencies]
# --------------------
A = -torch.exp(self.A_log.float()) # A is stored in log space; exponentiate (state decay rate)
A = A[:, None, ...][:, :, None].expand(-1, self.head_dim, self.ssm_state_size).to(dtype=torch.float32)
dt = dt[:, :, None].expand(-1, -1, self.head_dim) # time step
dt_bias = self.dt_bias[:, None, ...].expand(-1, self.head_dim)
D = self.D[:, None, ...].expand(-1, self.head_dim)
B = B.view(batch_size, self.n_groups, B.shape[1] // self.n_groups)
C = C.view(batch_size, self.n_groups, C.shape[1] // self.n_groups)
hidden_states_reshaped = hidden_states.view(batch_size, self.num_heads, self.head_dim)
# [Core] state-space update: produce the current output from the cached state
hidden_states = selective_state_update(
cache_params.ssm_states[self.layer_idx],
hidden_states_reshaped,
dt, A, B, C, D,
z=None,
dt_bias=dt_bias,
dt_softplus=True,
)
# Restore shape + gated normalization
hidden_states = hidden_states.view(batch_size, self.num_heads * self.head_dim)
hidden_states = self.norm(hidden_states, gate)
# 4. Output projection back to the model dimension
out = self.out_proj(hidden_states)[:, None, ...]
# ==============================================
# Branch B: training / first inference call (the whole sequence at once)
# ==============================================
else:
A = -torch.exp(self.A_log.float())
dt_limit_kwargs = {} if self.time_step_limit == (0.0, float("inf")) else {"dt_limit": self.time_step_limit}
# --------------------
# Training: a single fused CUDA kernel does conv + scan + norm + out-proj
# --------------------
if self.training and cache_params is None:
out = mamba_split_conv1d_scan_combined(
projected_states,
self.conv1d.weight.squeeze(1),
self.conv1d.bias,
self.dt_bias,
A,
D=self.D,
chunk_size=self.chunk_size,
activation=self.activation,
rmsnorm_weight=self.norm.weight,
rmsnorm_eps=self.norm.variance_epsilon,
outproj_weight=self.out_proj.weight,
outproj_bias=self.out_proj.bias,
headdim=self.head_dim,
ngroups=self.n_groups,
norm_before_gate=False,
return_final_states=False,
**dt_limit_kwargs,
)
# --------------------
# Not training (inference prefill)
# --------------------
else:
# Split the projection
_, _, gate, hidden_states_B_C, dt = projected_states.split(
[d_mlp, d_mlp, self.intermediate_size, self.conv_dim, self.num_heads], dim=-1
)
# 2. Causal convolution
if cache_params is not None:
hidden_states_B_C_transposed = hidden_states_B_C.transpose(1, 2)
conv_states = nn.functional.pad(
hidden_states_B_C_transposed,
(cache_params.conv_kernel_size - hidden_states_B_C_transposed.shape[-1], 0),
)
cache_params.update_conv_state(
layer_idx=self.layer_idx, new_conv_state=conv_states, cache_init=True
)
# Plain conv1d fallback for non-SiLU activations; otherwise the fused CUDA conv kernel
if self.activation not in ["silu", "swish"]:
hidden_states_B_C = self.act(
self.conv1d(hidden_states_B_C.transpose(1, 2))[..., :seq_len].transpose(1, 2)
)
else:
hidden_states_B_C = causal_conv1d_fn(
x=hidden_states_B_C.transpose(1, 2),
weight=self.conv1d.weight.squeeze(1),
bias=self.conv1d.bias,
activation=self.activation,
).transpose(1, 2)
hidden_states_B_C = apply_mask_to_padding_states(hidden_states_B_C, attention_mask)
# Split out x, B, C
hidden_states, B, C = torch.split(
hidden_states_B_C,
[self.intermediate_size, groups_time_state_size, groups_time_state_size],
dim=-1,
)
# 3. Chunked SSM scan over the whole sequence
scan_output, ssm_state = mamba_chunk_scan_combined(
hidden_states.view(batch_size, seq_len, -1, self.head_dim),
dt,
A,
B.view(batch_size, seq_len, self.n_groups, -1),
C.view(batch_size, seq_len, self.n_groups, -1),
chunk_size=self.chunk_size,
D=self.D,
z=None,
dt_bias=self.dt_bias,
dt_softplus=True,
**dt_limit_kwargs,
)
# Save the final SSM state into the cache
if ssm_state is not None and cache_params is not None:
cache_params.update_ssm_state(layer_idx=self.layer_idx, new_ssm_state=ssm_state)
scan_output = scan_output.view(batch_size, seq_len, -1)
# Gated normalization
scan_output = self.norm(scan_output, gate)
# 4. Final output projection
out = self.out_proj(scan_output)
return out
第十一部分:Nemotron 模块中 Attention 的实现:
python
class NemotronFlashAttention2(NemotronAttention):
    """
    FlashAttention-2 backed Nemotron attention.

    Inherits all weights from the plain NemotronAttention and only overrides
    `forward`: same math, routed through the fused FlashAttention-2 kernel
    for higher speed and lower GPU memory use.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # FlashAttention version compatibility: versions < 2.1 align the causal
        # mask top-left, 2.1+ align it bottom-right. Record which convention
        # this install uses so the kernel applies the mask correctly.
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        hidden_states: torch.Tensor,  # [batch, seq_len, hidden_size]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cache_params: Optional[HybridQwenNemotronDynamicCache] = None,  # KV cache
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Run GQA attention through FlashAttention-2.

        Returns:
            (attn_output, attn_weights, cache_params); attn_weights is always
            None because the FlashAttention kernel never materializes the
            attention matrix.
        """
        bsz, q_len, _ = hidden_states.size()

        # ---- 1. QKV linear projections ----
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # ---- 2. Split into heads ----
        # Q stays [batch, seq, n_heads, head_dim] (FlashAttention layout);
        # K/V go to [batch, n_kv_heads, seq, head_dim] for the cache update
        # and are transposed back in step 6.
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # ---- 3. KV cache update (decode-time speed-up) ----
        if cache_params is not None:
            key_states, value_states = cache_params.update(key_states, value_states, self.layer_idx)

        # ---- 4. GQA: repeat K/V heads to match the query head count ----
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Dropout only during training.
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # ---- 5. Dtype handling: FlashAttention-2 requires half precision ----
        # float32 inputs are cast down to float16/bfloat16, otherwise the
        # kernel errors out or degrades.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            # Pick the half-precision dtype appropriate for this setup.
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype
            # Warn, then force the cast.
            logger.warning_once(
                f"输入是 float32,FlashAttention2 需要半精度,已强制转为 {target_dtype}."
            )
            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # ---- 6. Back to [batch, seq, heads, head_dim] for the kernel ----
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        # ---- 7. FlashAttention-2 kernel ----
        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            dropout=dropout_rate,
            sliding_window=getattr(self.config, "sliding_window", None),
            is_causal=self.is_causal,  # causal attention (no peeking at future tokens)
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        # ---- 8. Merge heads: [b, s, h, d] -> [b, s, hidden] ----
        attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim).contiguous()

        # ---- 9. Output projection ----
        attn_output = self.o_proj(attn_output)

        # Bug fix: the original assigned attn_weights only when
        # output_attentions was False, so output_attentions=True hit a
        # NameError at the return. FlashAttention cannot return the weights,
        # so they are unconditionally None.
        attn_weights = None
        return attn_output, attn_weights, cache_params
第十二部分:Nemotron 模块中 MLP 的实现:
python
class NemotronMLP(nn.Module):
    """Nemotron feed-forward block: up-projection, activation, down-projection.

    A plain two-layer MLP (no Qwen-style gating) whose activation is resolved
    by name from ``config.mlp_hidden_act`` via the ACT2FN registry.
    """

    def __init__(self, config: QwenNemotronConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        # Two linear maps: hidden -> intermediate -> hidden.
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
        # Activation looked up by name (the original docstring notes relu2).
        self.act_fn = ACT2FN[config.mlp_hidden_act]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        expanded = self.up_proj(x)
        activated = self.act_fn(expanded)
        return self.down_proj(activated)