Qwen3MLP
Qwen3MLP是基于门控机制的MLP模块,采用了类似门控线性单元(GLU)的结构。它通过三个线性变换层(gate_proj、up_proj和down_proj)和SiLU激活函数,先将输入从隐藏维度扩展到中间维度,经过门控计算后再投影回原始维度。该模块保持了输入输出形状的一致性,演示了如何逐步执行前向传播并验证计算正确性,展示了Transformer模型中常用的前馈神经网络结构。
具体代码与测试如下:
python
import torch
import torch.nn as nn
from transformers.activations import ACT2FN
class Qwen3MLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act] # silu
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
# 模拟配置类
class MockConfig:
def __init__(self):
self.hidden_size = 1024
self.intermediate_size = 2048
self.hidden_act = "silu"
# 完整示例
if __name__ == "__main__":
# 1. 创建配置对象
config = MockConfig()
# 2. 初始化Qwen3MLP模块
mlp = Qwen3MLP(config)
# 3. 创建测试输入数据
batch_size = 2
seq_length = 8
hidden_size = config.hidden_size # 1024
# 输入张量形状: (batch_size, seq_length, hidden_size)
input_tensor = torch.randn(batch_size, seq_length, hidden_size)
print("=== Qwen3MLP 示例 ===")
print(f"配置信息:")
print(f" - hidden_size: {config.hidden_size}")
print(f" - intermediate_size: {config.intermediate_size}")
print(f" - activation: {config.hidden_act}")
print(f"\n输入张量形状: {input_tensor.shape}")
# 4. 前向传播
with torch.no_grad():
output_tensor = mlp(input_tensor)
print(f"输出张量形状: {output_tensor.shape}")
# 5. 验证输出形状与输入形状一致
assert output_tensor.shape == input_tensor.shape, \
f"输出形状 {output_tensor.shape} 与输入形状 {input_tensor.shape} 不一致"
print("\n=== MLP 层内部组件 ===")
print(f"gate_proj 权重形状: {mlp.gate_proj.weight.shape}")
print(f"up_proj 权重形状: {mlp.up_proj.weight.shape}")
print(f"down_proj 权重形状: {mlp.down_proj.weight.shape}")
# 6. 逐步计算过程演示
print("\n=== 前向传播步骤 ===")
with torch.no_grad():
# 第一步: 门控投影
gate_output = mlp.gate_proj(input_tensor)
print(f"1. gate_proj 输出形状: {gate_output.shape}")
# 第二步: 激活函数
gate_activated = mlp.act_fn(gate_output)
print(f"2. 激活函数后形状: {gate_activated.shape}")
# 第三步: 上投影
up_output = mlp.up_proj(input_tensor)
print(f"3. up_proj 输出形状: {up_output.shape}")
# 第四步: 门控线性单元 (GLU)
glu_output = gate_activated * up_output
print(f"4. GLU 输出形状: {glu_output.shape}")
# 第五步: 下投影
final_output = mlp.down_proj(glu_output)
print(f"5. down_proj 输出形状: {final_output.shape}")
# 验证与直接调用forward的结果一致
direct_output = mlp(input_tensor)
assert torch.allclose(final_output, direct_output, atol=1e-6), "逐步计算结果与直接调用不一致"
print("✓ 逐步计算结果与直接调用结果一致")
print("\n=== 示例完成 ===")
print(f"MLP 成功处理了形状为 {input_tensor.shape} 的输入,输出形状为 {output_tensor.shape}")
=== Qwen3MLP 示例 ===
配置信息:
- hidden_size: 1024
- intermediate_size: 2048
- activation: silu
输入张量形状: torch.Size([2, 8, 1024])
输出张量形状: torch.Size([2, 8, 1024])
=== MLP 层内部组件 ===
gate_proj 权重形状: torch.Size([2048, 1024])
up_proj 权重形状: torch.Size([2048, 1024])
down_proj 权重形状: torch.Size([1024, 2048])
=== 前向传播步骤 ===
1. gate_proj 输出形状: torch.Size([2, 8, 2048])
2. 激活函数后形状: torch.Size([2, 8, 2048])
3. up_proj 输出形状: torch.Size([2, 8, 2048])
4. GLU 输出形状: torch.Size([2, 8, 2048])
5. down_proj 输出形状: torch.Size([2, 8, 1024])
✓ 逐步计算结果与直接调用结果一致
=== 示例完成 ===
MLP 成功处理了形状为 torch.Size([2, 8, 1024]) 的输入,输出形状为 torch.Size([2, 8, 1024])
Qwen3MoeSparseMoeBlock
Qwen3 模型的稀疏混合专家(Sparse MoE)模块,核心是通过"路由机制+多专家并行计算"提升模型在大参数量下的效率与能力。
Qwen3MoeSparseMoeBlock
处理输入的流程可分为 路由计算→专家选择→并行计算→结果聚合 四步:
1. 路由计算:为每个 token 选专家
- 输入
hidden_states
(形状[batch_size, seq_length, hidden_size]
)先展平为[batch*seq, hidden_size]
; - 用
self.gate
(线性层)生成router_logits
(每个 token 对 8 个专家的"匹配分数"); - 通过
softmax
+topk
,为每个 token 选num_experts_per_tok=2
个"最匹配专家",并得到归一化的路由权重(决定每个专家对 token 的贡献占比)。
2. 专家选择:标记活跃专家
通过 one_hot
编码生成 expert_mask
,标记"哪些专家被哪些 token 选中";再通过 expert_hit
筛选出至少被一个 token 选中的活跃专家(示例中 8 个专家都有 token 命中)。
3. 并行计算:专家各自处理 token
对每个活跃专家,执行:
- 筛选出"属于当前专家"的 token(通过
expert_mask
定位); - 调用该专家的
Qwen3MoeMLP
层(结构同普通 MLP,但参数量仅服务部分 token),完成"门控投影→激活→上投影→下投影"的计算; - 用路由权重对专家输出加权(确保不同专家的贡献按匹配度分配)。
4. 结果聚合:合并所有专家输出
通过 index_add_
将每个专家处理后的 token 结果,按原始位置合并,最终还原为 [batch_size, seq_length, hidden_size]
的输出。
具体代码与测试如下:
python
import torch.nn as nn
from transformers.activations import ACT2FN
import torch.nn.functional as F
class Qwen3MoeMLP(nn.Module):
def __init__(self, config, intermediate_size=None):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size # 512
self.intermediate_size = intermediate_size if intermediate_size is not None else config.intermediate_size
# 256
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)# 512, 256
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False) # 512, 256
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False) # 256, 512
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
class Qwen3MoeSparseMoeBlock(nn.Module):
def __init__(self, config):
super().__init__()
self.num_experts = config.num_experts # 8
self.top_k = config.num_experts_per_tok # 2
self.norm_topk_prob = config.norm_topk_prob # True
self.gate = nn.Linear(config.hidden_size, config.num_experts, bias=False) # 512 -> 8
self.experts = nn.ModuleList(
[Qwen3MoeMLP(config, intermediate_size=config.moe_intermediate_size) for _ in range(self.num_experts)]
) # 512 -> 256 -> 512
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape # 2, 6, 512
hidden_states = hidden_states.view(-1, hidden_dim) # 2, 6, 512 -> 12, 512
router_logits = self.gate(hidden_states) # 12 8
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # 12 8
routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1) # 12 2
if self.norm_topk_prob:
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
routing_weights = routing_weights.to(hidden_states.dtype)
final_hidden_states = torch.zeros(
(batch_size * sequence_length, hidden_dim), dtype=hidden_states.dtype, device=hidden_states.device
) # 12 512
expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0)
# 12 2 8 8 2 12
print("expert_mask: \n",expert_mask)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() # 8
print("expert hit: \n",expert_hit)
for expert_idx in expert_hit:
expert_layer = self.experts[expert_idx] # Qwen3MoeMLP
idx, top_x = torch.where(expert_mask[expert_idx].squeeze(0)) # 4 4
if expert_idx == 0:
print("expert_mask[expert_idx].squeeze(0):",expert_mask[expert_idx].squeeze(0))
print("idx:",idx)
print("top_x:",top_x)
print("hidden_states[None, top_x]:",hidden_states[None, top_x].shape)
current_state = hidden_states[None, top_x].reshape(-1, hidden_dim) # 1, 4, 512 -> 4, 512
current_hidden_states = expert_layer(current_state) * routing_weights[top_x, idx, None]
# 4, 512 * 4, 512 -> 4, 512
final_hidden_states.index_add_(0, top_x, current_hidden_states.to(hidden_states.dtype))
final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) # 2, 6, 512
return final_hidden_states, router_logits
python
class MockConfig:
def __init__(self):
self.hidden_size = 512
self.moe_intermediate_size = 256
self.hidden_act = "silu"
self.num_experts = 8
self.num_experts_per_tok = 2
self.norm_topk_prob = True
import numpy as np
import random
# 设置随机种子以确保可重复性
def set_random_seed(seed=42):
"""设置所有随机种子以确保结果可重复"""
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# 完整示例
if __name__ == "__main__":
set_random_seed(42)
config = MockConfig()
moe_block = Qwen3MoeSparseMoeBlock(config)
batch_size = 2
seq_length = 6
hidden_size = config.hidden_size # 512
input_tensor = torch.randn(batch_size, seq_length, hidden_size)
print("=== Qwen3MoeSparseMoeBlock 示例 ===")
print(f"配置信息:")
print(f" - hidden_size: {config.hidden_size}")
print(f" - moe_intermediate_size: {config.moe_intermediate_size}")
print(f" - activation: {config.hidden_act}")
print(f" - num_experts: {config.num_experts}")
print(f" - num_experts_per_tok: {config.num_experts_per_tok}")
print(f" - norm_topk_prob: {config.norm_topk_prob}")
print(f"\n输入张量形状: {input_tensor.shape}")
with torch.no_grad():
output_tensor, router_logits = moe_block(input_tensor)
print(f"输出张量形状: {output_tensor.shape}")
print(f"路由逻辑形状: {router_logits.shape}")
=== Qwen3MoeSparseMoeBlock 示例 ===
配置信息:
- hidden_size: 512
- moe_intermediate_size: 256
- activation: silu
- num_experts: 8
- num_experts_per_tok: 2
- norm_topk_prob: True
输入张量形状: torch.Size([2, 6, 512])
expert_mask:
tensor([[[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]],
[[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0]],
[[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],
[[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],
[0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0]],
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
[[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]],
[[0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1],
[0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0]],
[[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1]]])
expert hit:
tensor([[0],
[1],
[2],
[3],
[4],
[5],
[6],
[7]])
expert_mask[expert_idx].squeeze(0): tensor([[1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0]])
idx: tensor([0, 0, 0, 1])
top_x: tensor([ 0, 2, 10, 6])
hidden_states[None, top_x]: torch.Size([1, 4, 512])
输出张量形状: torch.Size([2, 6, 512])
路由逻辑形状: torch.Size([12, 8])