1. Building a GPT Mixture-of-Experts (MoE) Chat Model from Scratch with PyTorch
A Mixture-of-Experts (MoE) model is a variant of the Transformer architecture, exemplified by Switch Transformers. A gating network dynamically selects a small subset of "expert" sub-networks for each input, increasing model capacity through sparse activation while keeping compute efficient: even when the total parameter count grows very large, the cost of a single forward pass stays within a controllable range. Its core property is "high parameters, low compute". Unlike a dense model, which activates all of its parameters for every input, an MoE model activates only a small fraction of its parameters, and adding experts lets it absorb richer knowledge and generalize better. Models such as Mixtral 8x7B and the currently popular DeepSeek both use MoE architectures, which speaks to the architecture's potential.
The MoE architecture contrasts sharply with the traditional dense Transformer decoder. A standard Transformer decoder layer consists of multi-head self-attention (MultiHeadAttention) and a feed-forward network (FFN). This design is simple, stable, and easy to parallelize, and is widely used in models such as GPT and BART. However, its computation and parameter activation are dense: every input token activates all of the FFN's parameters, so compute cost grows linearly as the model scales.
The MoE architecture keeps the self-attention module but replaces the FFN with a mixture-of-experts module, the MoE layer. This module contains a lightweight gating network (Router) and n expert networks (Experts). The Router dynamically assigns each input token to its Top-K experts; each expert is typically shaped like an FFN, and experts that are not selected are skipped entirely, which yields sparse activation.
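To make the sparse-activation idea concrete, here is a toy routing sketch. All sizes are made up for illustration, and plain `nn.Linear` modules stand in for real FFN experts; the article's actual Router and Expert classes are built later:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, num_experts, top_k = 16, 8, 2
x = torch.randn(4, d_model)                      # 4 toy tokens

gate = nn.Linear(d_model, num_experts)           # toy gating network
experts = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_experts))

probs = F.softmax(gate(x), dim=-1)               # routing probabilities per token
top_p, top_i = torch.topk(probs, top_k, dim=-1)  # each token keeps its top-2 experts
top_p = top_p / top_p.sum(dim=-1, keepdim=True)  # renormalize the kept weights

out = torch.zeros_like(x)
ran = set()
for t in range(x.size(0)):                       # per token: run only 2 of 8 experts
    for slot in range(top_k):
        e = top_i[t, slot].item()
        ran.add(e)
        out[t] += top_p[t, slot] * experts[e](x[t])

print(out.shape)                                 # torch.Size([4, 16])
print(len(ran), "of", num_experts, "experts ran in total")
```

Each token paid for only `top_k` expert forward passes, regardless of how many experts exist; experts that no token selected were never computed at all.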

In an earlier article in this column, I walked through building a GPT Transformer chat model from scratch, using the traditional Transformer decoder architecture throughout; article link:
That article built the GPTModel network and its vocabulary from scratch. With only about 37 million parameters it can hardly be called a "large model", but the overall architecture is well worth studying. This article rebuilds that network as an MoE mixture-of-experts architecture; the training dataset and vocabulary are reused from the previous article and are not described again here.
Likewise, the computations and formulas behind scaled dot-product attention, multi-head attention, the lower-triangular causal mask, positional encoding, and so on are covered in the previous article. The final result of this article looks like this:

The main dependency versions used in the experiments:

```shell
torch==2.6.0
tensorboard==2.19.0
```
2. Building the GPTMoEModel Network Architecture
2.1 Implementing scaled dot-product and multi-head attention
The scaled dot-product attention and multi-head attention logic is identical to the previous article; the key parts are annotated in comments:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# Scaled dot-product attention
class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k

    def forward(self, q, k, v, attention_mask):
        ##
        # q: [batch_size, n_heads, len_q, d_k]
        # k: [batch_size, n_heads, len_k, d_k]
        # v: [batch_size, n_heads, len_v, d_v]
        # attn_mask: [batch_size, n_heads, seq_len, seq_len]
        ##
        # Score of each Q against each K; result is [batch_size, n_heads, len_q, len_k]
        scores = torch.matmul(q, k.transpose(-1, -2)) / np.sqrt(self.d_k)
        # Fill masked positions with a very large negative value so softmax drives
        # them to ~0 and they contribute nothing to q
        scores.masked_fill_(attention_mask, -1e9)
        attn = nn.Softmax(dim=-1)(scores)
        # After attention: [batch_size, n_heads, len_q, d_v]
        context = torch.matmul(attn, v)
        return context, attn

# Multi-head attention
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_k, d_v):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.w_q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.w_k = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.w_v = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
        self.layernorm = nn.LayerNorm(d_model)

    def forward(self, q, k, v, attention_mask):
        ##
        # q: [batch_size, seq_len, d_model]
        # k: [batch_size, seq_len, d_model]
        # v: [batch_size, seq_len, d_model]
        # attn_mask: [batch_size, seq_len, seq_len]
        ##
        # Keep the input for the residual connection
        residual, batch_size = q, q.size(0)
        # Project q, k, v, then split into heads
        # q: [batch_size, n_heads, len_q, d_k]
        q = self.w_q(q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # k: [batch_size, n_heads, len_k, d_k]
        k = self.w_k(k).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # v: [batch_size, n_heads, len_v(=len_k), d_v]
        v = self.w_v(v).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)
        # attn_mask: [batch_size, n_heads, seq_len, seq_len]
        attention_mask = attention_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)
        # Scaled dot-product attention, [batch_size, n_heads, len_q, d_v]
        context, attn = ScaledDotProductAttention(self.d_k)(q, k, v, attention_mask)
        # context: [batch_size, len_q, n_heads * d_v]
        context = context.transpose(1, 2).reshape(batch_size, -1, self.n_heads * self.d_v)
        # Project back to the original size
        output = self.fc(context)
        # Residual connection + LayerNorm
        return self.layernorm(output + residual), attn
```
2.2 Implementing the gating network (Router)
The gating network is a lightweight neural network. For each token, it predicts which experts the token should be routed to and assigns each selected expert a weight, used to blend the outputs of the chosen experts.
One problem with gating is expert imbalance: the router may keep sending samples to a few strong or well-initialized experts, so the remaining experts never get trained and the whole system degenerates to using only a handful of experts. To mitigate this, we add trainable noise during routing and introduce an auxiliary load-balancing loss; the loss here follows the approach used by the Mixtral model.
The implementation:
```python
# Gating network
class Router(nn.Module):
    def __init__(self, d_model, num_experts, top_k=2):
        super(Router, self).__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = nn.Linear(d_model, num_experts)
        # Noise used for load balancing
        self.noise_linear = nn.Linear(d_model, num_experts)

    def forward(self, x):
        logits = self.gate(x)
        # Add noise during training
        if self.training:
            noise = torch.randn_like(logits).to(x.device)
            noise = self.noise_linear(x) * noise
            noisy_logits = logits + noise
        else:
            noisy_logits = logits
        gates_prob = F.softmax(noisy_logits, dim=-1)
        # Top-k selection
        top_k_probs, top_k_indices = torch.topk(gates_prob, self.top_k, dim=-1)
        # Renormalize so the selected experts' weights sum to 1
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        # Load-balancing loss
        load_balancing_loss = self.compute_load_balancing_loss(gates_prob, top_k_indices)
        return top_k_probs, top_k_indices, load_balancing_loss

    def compute_load_balancing_loss(self, gates_prob, top_k_indices):
        """Load-balancing loss: num_experts * sum(mean routing prob per expert * fraction of tokens assigned to that expert)"""
        batch_size, seq_len, _ = gates_prob.shape
        # Mean routing probability for each expert
        router_prob_per_expert = gates_prob.mean(dim=(0, 1))
        # Fraction of token-slots actually assigned to each expert
        expert_mask = torch.zeros_like(gates_prob)
        expert_mask.scatter_(2, top_k_indices, 1)
        tokens_per_expert = expert_mask.float().mean(dim=(0, 1))
        # Auxiliary loss
        return self.num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)
```
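To see why this loss pushes toward balance, it helps to evaluate the same formula on two hand-made routing patterns (the tensors below are made up for illustration): perfectly uniform routing versus routing collapsed onto two experts. The collapsed case is penalized with a strictly larger loss:

```python
import torch

def load_balancing_loss(gates_prob, top_k_indices, num_experts):
    # Same formula as Router.compute_load_balancing_loss above
    router_prob_per_expert = gates_prob.mean(dim=(0, 1))
    expert_mask = torch.zeros_like(gates_prob)
    expert_mask.scatter_(2, top_k_indices, 1)
    tokens_per_expert = expert_mask.float().mean(dim=(0, 1))
    return num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)

E = 4  # toy setup: 4 experts, 4 tokens, top_k=2
# Perfectly balanced: uniform probabilities, every expert picked equally often
uniform = torch.full((1, 4, E), 0.25)
balanced_idx = torch.tensor([[[0, 1], [2, 3], [0, 2], [1, 3]]])
# Collapsed: every token routes to experts 0 and 1
collapsed = torch.tensor([[[0.45, 0.45, 0.05, 0.05]] * 4])
collapsed_idx = torch.tensor([[[0, 1]] * 4])

print(load_balancing_loss(uniform, balanced_idx, E).item())    # 2.0 (minimum for top_k=2)
print(load_balancing_loss(collapsed, collapsed_idx, E).item()) # 3.6, penalized
```

Since the loss is differentiable with respect to the routing probabilities, gradient descent nudges the router away from the collapsed pattern.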
2.3 Implementing the expert network
Each expert is essentially a feed-forward network; here it mimics a SwiGLU FFN.
```python
# Expert network
class Expert(nn.Module):
    def __init__(self, d_model, d_ff):
        super(Expert, self).__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w_out = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        # SwiGLU-style gating: silu(x @ W1) elementwise-multiplied by x @ W2
        return self.w_out(F.silu(self.w1(x)) * self.w2(x))
```
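A quick standalone sanity check of the SwiGLU shape flow (toy sizes of my choosing, not the article's 768/2048): the gate branch `silu(xW1)` multiplies the value branch `xW2` elementwise in the `d_ff` space before projecting back to `d_model`, and `silu(z)` is simply `z * sigmoid(z)`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, d_ff = 8, 32
w1 = nn.Linear(d_model, d_ff, bias=False)
w2 = nn.Linear(d_model, d_ff, bias=False)
w_out = nn.Linear(d_ff, d_model, bias=False)

x = torch.randn(2, 5, d_model)       # [batch, seq, d_model]
gated = F.silu(w1(x)) * w2(x)        # gate branch * value branch: [2, 5, 32]
y = w_out(gated)                     # project back to d_model: [2, 5, 8]
print(y.shape)                       # torch.Size([2, 5, 8])

# silu(z) = z * sigmoid(z)
z = torch.randn(4)
assert torch.allclose(F.silu(z), z * torch.sigmoid(z))
```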
2.4 Combining the Router and experts into the MoE layer
The MoE layer is composed of one gating Router plus multiple experts. The Router outputs the top-k expert IDs and their weights, each token is fed to its assigned experts, and the expert outputs are blended by those weights.
To keep the logic easy to follow, expert selection here uses a double loop with a per-expert mask, which cannot fully exploit GPU parallelism; for a more efficient version, refer to how the Mixtral model implements dispatch.
```python
# MoE layer
class MoELayer(nn.Module):
    def __init__(self, d_model, d_ff, num_experts=8, top_k=2):
        super(MoELayer, self).__init__()
        self.d_model = d_model
        self.num_experts = num_experts
        self.top_k = top_k
        # Gating router that decides which experts are activated
        self.router = Router(d_model, num_experts, top_k)
        # Create the experts
        self.experts = nn.ModuleList([
            Expert(d_model, d_ff) for _ in range(num_experts)
        ])
        # Layer Norm
        self.layernorm = nn.LayerNorm(d_model)

    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model]
        """
        residual = x
        batch_size, seq_len, d_model = x.shape
        # Routing decisions
        # gates: [batch_size, seq_len, top_k]
        # selected_experts: [batch_size, seq_len, top_k]
        gates, selected_experts, load_balancing_loss = self.router(x)
        # Initialize the output
        output = torch.zeros_like(x)
        # Apply each token's selected experts
        for i in range(self.top_k):
            # Expert index for this top-k slot
            expert_idx = selected_experts[:, :, i]  # [batch_size, seq_len]
            # Gate weight for this slot
            expert_gate = gates[:, :, i]  # [batch_size, seq_len]
            # Run each expert in turn
            for expert_id in range(self.num_experts):
                # Token positions that selected this expert
                mask = (expert_idx == expert_id).unsqueeze(-1)  # [batch_size, seq_len, 1]
                if mask.any():
                    # Tokens assigned to this expert (all others zeroed out)
                    expert_input = x * mask  # [batch_size, seq_len, d_model]
                    # Apply the expert
                    expert_output = self.experts[expert_id](expert_input)  # [batch_size, seq_len, d_model]
                    # Weighted output
                    weighted_output = expert_output * expert_gate.unsqueeze(-1) * mask
                    output += weighted_output
        # Residual connection + Layer Norm
        output = self.layernorm(output + residual)
        return output, load_balancing_loss
```
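For reference, here is a hedged sketch of the more GPU-friendly dispatch style used by Mixtral-like implementations: flatten all tokens, then for each expert gather only its tokens, run the expert once on that batch, and scatter the weighted results back with `index_add_`. The toy sizes and the plain `nn.Linear` stand-in experts are my own illustration, not the article's code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
d_model, num_experts, top_k = 8, 4, 2
experts = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(num_experts))
gate = nn.Linear(d_model, num_experts)

x = torch.randn(2, 3, d_model)
flat = x.view(-1, d_model)                              # [tokens, d_model]
probs = F.softmax(gate(flat), dim=-1)
weights, chosen = torch.topk(probs, top_k, dim=-1)      # [tokens, top_k]
weights = weights / weights.sum(dim=-1, keepdim=True)

out = torch.zeros_like(flat)
# one_hot -> [num_experts, top_k, tokens]: per expert, its (slot, token) pairs
expert_mask = F.one_hot(chosen, num_experts).permute(2, 1, 0)
for e in range(num_experts):
    slot, tok = torch.where(expert_mask[e])             # tokens routed to expert e
    if tok.numel() == 0:
        continue                                        # unselected expert: skipped entirely
    # Gather this expert's tokens, run the expert once, scatter weighted output back
    expert_out = experts[e](flat[tok]) * weights[tok, slot, None]
    out.index_add_(0, tok, expert_out)

out = out.view_as(x)
print(out.shape)  # torch.Size([2, 3, 8])
```

Each expert now sees one contiguous batch of its own tokens instead of a full zero-padded tensor, which is what makes the Mixtral-style loop faster in practice.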
2.5 Implementing the MoE decoder layer
This mirrors a traditional Transformer decoder layer; the only change is that the FFN is replaced by the MoE layer.
```python
# Decoder layer
class MoEDecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, d_k, d_v, num_experts=8, top_k=2):
        super(MoEDecoderLayer, self).__init__()
        # Multi-head attention
        self.attention = MultiHeadAttention(d_model, n_heads, d_k, d_v)
        # MoE
        self.pos_ffn = MoELayer(d_model, d_ff, num_experts, top_k)

    def forward(self, inputs, attention_mask):
        # Multi-head attention
        outputs, self_attn = self.attention(inputs, inputs, inputs, attention_mask)
        # MoE
        outputs, load_balancing_loss = self.pos_ffn(outputs)
        return outputs, self_attn, load_balancing_loss
```
2.6 Stacking MoE decoder layers into the MoE decoder
Multiple decoder layers are stacked into a feature-extraction chain. To make comparison with the previous article easy, positional encoding still follows the GPT-2 approach, and we again need a lower-triangular causal mask so the model cannot see future tokens.
The masking process looks like this:
```shell
Raw attention score matrix (no mask):
[[q1k1, q1k2, q1k3, q1k4],
 [q2k1, q2k2, q2k3, q2k4],
 [q3k1, q3k2, q3k3, q3k4],
 [q4k1, q4k2, q4k3, q4k4]]
Upper-triangular mask:
[[0, 1, 1, 1],
 [0, 0, 1, 1],
 [0, 0, 0, 1],
 [0, 0, 0, 0]]
Scores after applying the mask:
[[q1k1, -inf, -inf, -inf],
 [q2k1, q2k2, -inf, -inf],
 [q3k1, q3k2, q3k3, -inf],
 [q4k1, q4k2, q4k3, q4k4]]
```
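The mask above can be generated directly with `torch.triu`; a minimal standalone sketch of the masking step:

```python
import torch

seq_len = 4
# Ones above the diagonal mark "future" positions that must be hidden
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

scores = torch.randn(seq_len, seq_len)
masked = scores.masked_fill(mask, float("-inf"))  # future positions -> -inf
probs = masked.softmax(dim=-1)                    # each row attends only to the past
print(probs[0])  # the first token attends only to itself: [1., 0., 0., 0.]
```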
The implementation:

```python
# Positional encoding, following GPT-2's learned position embeddings
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_pos, device):
        super(PositionalEncoding, self).__init__()
        self.device = device
        self.pos_embedding = nn.Embedding(max_pos, d_model)

    def forward(self, inputs):
        seq_len = inputs.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=self.device)
        # [seq_len] -> [batch_size, seq_len]
        pos = pos.unsqueeze(0).expand_as(inputs)
        return self.pos_embedding(pos)

# Build the padding mask
def get_attn_pad_mask(attention_mask):
    batch_size, len_seq = attention_mask.size()
    attention_mask = attention_mask.data.eq(0).unsqueeze(1)
    # Attention scores are [batch_size, n_heads, len_q, len_q],
    # so expand to [batch_size, len_seq, len_seq]
    return attention_mask.expand(batch_size, len_seq, len_seq)

# Build the causal mask that hides future tokens
def get_attn_subsequence_mask(seq, device):
    # Attention scores are [batch_size, n_heads, len_seq, len_seq],
    # so generate [batch_size, len_seq, len_seq]
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    # Upper-triangular matrix of ones
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    subsequence_mask = subsequence_mask.to(device)
    return subsequence_mask

# Decoder
class MoEDecoder(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, d_k, d_v, vocab_size, max_pos, n_layers,
                 device, num_experts=8, top_k=2):
        super(MoEDecoder, self).__init__()
        self.device = device
        # Token embedding
        self.embedding = nn.Embedding(vocab_size, d_model)
        # Positional encoding
        self.pos_encoding = PositionalEncoding(d_model, max_pos, device)
        # Stack the MoE decoder layers
        self.layers = nn.ModuleList()
        for i in range(n_layers):
            self.layers.append(
                MoEDecoderLayer(
                    d_model, n_heads, d_ff, d_k, d_v,
                    num_experts, top_k
                )
            )

    def forward(self, inputs, attention_mask):
        # Embedding + positional encoding
        outputs = self.embedding(inputs) + self.pos_encoding(inputs)
        # Build the masks
        subsequence_mask = get_attn_subsequence_mask(inputs, self.device)
        if attention_mask is not None:
            attention_mask = get_attn_pad_mask(attention_mask)
            attention_mask = torch.gt((attention_mask + subsequence_mask), 0)
        else:
            attention_mask = subsequence_mask.bool()
        # Run every layer
        self_attns = []
        total_load_balancing_loss = 0.0
        for layer in self.layers:
            layer_output = layer(outputs, attention_mask)
            outputs, self_attn, load_balancing_loss = layer_output
            total_load_balancing_loss += load_balancing_loss
            self_attns.append(self_attn)
        return outputs, self_attns, total_load_balancing_loss
```
2.7 Wrapping the decoder into GPTMoEModel
Note that the loss must account for the load-balancing loss from earlier, so the overall loss is the sum of the language-modeling loss and the weighted load-balancing loss.
```python
# GPT MoE model
class GPTMoEModel(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, d_k, d_v, vocab_size, max_pos, n_layers,
                 device, num_experts=8, top_k=2, load_balancing_weight=0.01):
        super(GPTMoEModel, self).__init__()
        self.load_balancing_weight = load_balancing_weight
        # Decoder
        self.decoder = MoEDecoder(
            d_model, n_heads, d_ff, d_k, d_v, vocab_size, max_pos, n_layers,
            device, num_experts, top_k
        )
        # Project to vocabulary size
        self.projection = nn.Linear(d_model, vocab_size)

    def forward(self, inputs, attention_mask=None, targets=None):
        # Forward pass
        outputs, self_attns, load_balancing_loss = self.decoder(inputs, attention_mask)
        # Project to the vocabulary
        logits = self.projection(outputs)
        logits = logits.view(-1, logits.size(-1))
        if targets is not None:
            # Weighted load-balancing loss
            load_balancing_loss = load_balancing_loss * self.load_balancing_weight
            # Language-modeling loss
            lm_loss = F.cross_entropy(logits, targets.view(-1), ignore_index=0)
            # The MoE total loss is the LM loss plus the weighted balancing loss
            total_loss = lm_loss + load_balancing_loss
            return logits, self_attns, total_loss
        return logits, self_attns
```
2.8 The full network
All of the code above lives in model_moe.py.
```python
import torch
from model_moe import GPTMoEModel

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Model parameters
    model_param = {
        "d_model": 768,  # embedding size
        "d_ff": 2048,  # expert hidden size
        "d_k": 64,  # size of K
        "d_v": 64,  # size of V
        "n_layers": 6,  # number of decoder layers
        "n_heads": 8,  # number of attention heads
        "max_pos": 1800,  # positional-encoding length
        "device": device,  # device
        "vocab_size": 4825,  # vocabulary size, built in the previous article
        "num_experts": 8,  # 8 experts
        "top_k": 2,  # each token selects 2 experts
        "load_balancing_weight": 0.01  # load-balancing loss weight
    }
    model = GPTMoEModel(**model_param)
    total_params = sum(p.numel() for p in model.parameters())
    print(model)
    print("total_params: ", total_params)

if __name__ == '__main__':
    main()
```
Output:

```
GPTMoEModel(
  (decoder): MoEDecoder(
    (embedding): Embedding(4825, 768)
    (pos_encoding): PositionalEncoding(
      (pos_embedding): Embedding(1800, 768)
    )
    (layers): ModuleList(
      (0-5): 6 x MoEDecoderLayer(
        (attention): MultiHeadAttention(
          (w_q): Linear(in_features=768, out_features=512, bias=False)
          (w_k): Linear(in_features=768, out_features=512, bias=False)
          (w_v): Linear(in_features=768, out_features=512, bias=False)
          (fc): Linear(in_features=512, out_features=768, bias=False)
          (layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (pos_ffn): MoELayer(
          (router): Router(
            (gate): Linear(in_features=768, out_features=8, bias=True)
            (noise_linear): Linear(in_features=768, out_features=8, bias=True)
          )
          (experts): ModuleList(
            (0-7): 8 x Expert(
              (w1): Linear(in_features=768, out_features=2048, bias=False)
              (w2): Linear(in_features=768, out_features=2048, bias=False)
              (w_out): Linear(in_features=2048, out_features=768, bias=False)
            )
          )
          (layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (projection): Linear(in_features=768, out_features=4825, bias=True)
)
total_params:  244820281
```

The total parameter count comes to roughly 245 million (about 0.24B); compared with the network built in the previous article, it can hold considerably more knowledge.
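One way to see the "high parameters, low compute" property in numbers is to count the parameters actually active per token: with top_k=2, each token touches only 2 of the 8 experts in every layer. A rough bookkeeping sketch from the module shapes defined above (the split into "shared" vs "per-layer" terms is my own accounting):

```python
d_model, d_ff, n_heads, d_k, d_v = 768, 2048, 8, 64, 64
n_layers, num_experts, top_k = 6, 8, 2
vocab_size, max_pos = 4825, 1800

attn = 3 * d_model * (d_k * n_heads) + (n_heads * d_v) * d_model  # w_q/w_k/w_v + fc
ln = 2 * d_model                                                  # LayerNorm weight + bias
router = 2 * (d_model * num_experts + num_experts)                # gate + noise_linear
expert = 3 * d_model * d_ff                                       # w1, w2, w_out (no bias)

per_layer_total = attn + ln + router + num_experts * expert + ln
per_layer_active = attn + ln + router + top_k * expert + ln       # only 2 experts run

shared = vocab_size * d_model + max_pos * d_model \
         + d_model * vocab_size + vocab_size                      # embeddings + projection

total = shared + n_layers * per_layer_total
active = shared + n_layers * per_layer_active
print(f"total:  {total:,}")   # 244,820,281 parameters stored
print(f"active: {active:,}")  # 74,950,969 parameters used per token (~31% of total)
```

Scaling num_experts grows `total` almost linearly while leaving `active` (and so the per-token compute) nearly unchanged, which is exactly the MoE trade-off described in the introduction.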
3. Training the Model
The training set and training procedure are basically the same as in the previous article. As before, a few custom identity samples are appended to the training data to give the model its own persona:
```json
{"question": "你是谁", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你叫什么", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你的名字是什么", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你叫啥", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你名字是啥", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你是什么身份", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你的全名是什么", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你自称什么", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你的称号是什么", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你的昵称是什么", "answer": "我是小毕超,一个简易的小助手"}
```

3.1 Building the Dataset
qa_dataset.py:
```python
# -*- coding: utf-8 -*-
from torch.utils.data import Dataset
import torch
import json
import numpy as np

class QADataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []
        if data_path:
            with open(data_path, "r", encoding='utf-8') as f:
                for line in f:
                    if not line or line == "":
                        continue
                    json_line = json.loads(line)
                    question = json_line["question"]
                    answer = json_line["answer"]
                    self.data.append({
                        "question": question,
                        "answer": answer
                    })
        print("data load , size:", len(self.data))

    def preprocess(self, question, answer):
        encode, att_mask = self.tokenizer.encode(question, answer, max_length=self.max_length, pad_to_max_length=True)
        input_ids = encode[:-1]
        att_mask = att_mask[:-1]
        labels = encode[1:]
        return input_ids, att_mask, labels

    def __getitem__(self, index):
        item_data = self.data[index]
        input_ids, att_mask, labels = self.preprocess(**item_data)
        return {
            "input_ids": torch.LongTensor(np.array(input_ids)),
            "attention_mask": torch.LongTensor(np.array(att_mask)),
            "labels": torch.LongTensor(np.array(labels))
        }

    def __len__(self):
        return len(self.data)
```
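The shift in `preprocess` is what makes this a next-token-prediction task: `input_ids` drops the last token and `labels` drops the first, so position i is trained to predict token i+1. A toy illustration (the token IDs below are made up, not from the article's vocabulary):

```python
# Hypothetical encoded Q+A sequence as token IDs
encode = [1, 15, 27, 2, 15, 27, 33, 2]

input_ids = encode[:-1]   # [1, 15, 27, 2, 15, 27, 33]
labels = encode[1:]       # [15, 27, 2, 15, 27, 33, 2]

# Each input position is trained to predict the token one step ahead
for x, y in zip(input_ids, labels):
    print(f"input {x} -> target {y}")
```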
3.2 Training
```python
# -*- coding: utf-8 -*-
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tokenizer import Tokenizer
from model_moe import GPTMoEModel
from qa_dataset import QADataset
from tqdm import tqdm
import time, sys, os

def train_model(model, train_loader, val_loader, optimizer,
                device, num_epochs, model_output_dir, writer):
    batch_step = 0
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        time1 = time.time()
        model.train()
        for index, data in enumerate(tqdm(train_loader, file=sys.stdout, desc="Train Epoch: " + str(epoch))):
            input_ids = data['input_ids'].to(device, dtype=torch.long)
            attention_mask = data['attention_mask'].to(device, dtype=torch.long)
            labels = data['labels'].to(device, dtype=torch.long)
            optimizer.zero_grad()
            outputs, dec_self_attns, loss = model(input_ids, attention_mask, labels)
            loss.backward()
            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            writer.add_scalar('Loss/train', loss, batch_step)
            batch_step += 1
            # Print the loss every 50 steps
            if index % 50 == 0 or index == len(train_loader) - 1:
                time2 = time.time()
                tqdm.write(
                    f"{index}, epoch: {epoch} -loss: {str(loss)} ; lr: {optimizer.param_groups[0]['lr']} ;each step's time spent: {(str(float(time2 - time1) / float(index + 0.0001)))}")
        # Validation
        model.eval()
        val_loss = validate_model(model, device, val_loader)
        writer.add_scalar('Loss/val', val_loss, epoch)
        print(f"val loss: {val_loss} , epoch: {epoch}")
        # Save the best model
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_path = os.path.join(model_output_dir, "best.pt")
            print("Save Best Model To ", best_model_path, ", epoch: ", epoch)
            torch.save(model.state_dict(), best_model_path)
        # Save the latest model
        last_model_path = os.path.join(model_output_dir, "last.pt")
        print("Save Last Model To ", last_model_path, ", epoch: ", epoch)
        torch.save(model.state_dict(), last_model_path)

def validate_model(model, device, val_loader):
    running_loss = 0.0
    with torch.no_grad():
        for _, data in enumerate(tqdm(val_loader, file=sys.stdout, desc="Validation Data")):
            input_ids = data['input_ids'].to(device, dtype=torch.long)
            attention_mask = data['attention_mask'].to(device, dtype=torch.long)
            labels = data['labels'].to(device, dtype=torch.long)
            outputs, dec_self_attns, loss = model(input_ids, attention_mask, labels)
            running_loss += loss.item()
    return running_loss / len(val_loader)

def main():
    train_json_path = "data/train.json"  # training set
    val_json_path = "data/val.json"  # validation set
    vocab_path = "data/vocab.json"  # vocabulary file
    max_length = 120  # maximum sequence length
    epochs = 15  # number of epochs
    batch_size = 128  # batch size
    lr = 1e-4  # learning rate
    model_output_dir = "output"  # model save directory
    logs_dir = "logs"  # log directory
    # Device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Load the tokenizer
    tokenizer = Tokenizer(vocab_path)
    # Model parameters
    model_param = {
        "d_model": 768,  # embedding size
        "d_ff": 2048,  # expert hidden size
        "d_k": 64,  # size of K
        "d_v": 64,  # size of V
        "n_layers": 6,  # number of decoder layers
        "n_heads": 8,  # number of attention heads
        "max_pos": 1800,  # positional-encoding length
        "device": device,  # device
        "vocab_size": tokenizer.get_vocab_size(),  # vocabulary size
        "num_experts": 8,  # 8 experts
        "top_k": 2,  # each token selects 2 experts
        "load_balancing_weight": 0.01  # load-balancing loss weight
    }
    model = GPTMoEModel(**model_param)
    print("Start Load Train Data...")
    train_params = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": 4,
    }
    training_set = QADataset(train_json_path, tokenizer, max_length)
    training_loader = DataLoader(training_set, **train_params)
    print("Start Load Validation Data...")
    val_params = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": 4,
    }
    val_set = QADataset(val_json_path, tokenizer, max_length)
    val_loader = DataLoader(val_set, **val_params)
    # TensorBoard logging
    writer = SummaryWriter(logs_dir)
    # Optimizer
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)
    model = model.to(device)
    # Start training
    print("Start Training...")
    train_model(
        model=model,
        train_loader=training_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        device=device,
        num_epochs=epochs,
        model_output_dir=model_output_dir,
        writer=writer
    )
    writer.close()

if __name__ == '__main__':
    main()
```
Training in progress:

After training, inspect the loss curves with tensorboard:

Over the 15 training epochs, the validation loss kept falling for the first 9 epochs and began rising from epoch 10, which suggests overfitting; a follow-up improvement would be to add some dropout to the network for random deactivation.
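As one hedged sketch of that dropout fix (this ExpertWithDropout variant is my illustration, not code from this article), dropout could be inserted after the gated activation inside each expert:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical variant of the article's Expert with dropout for regularization
class ExpertWithDropout(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w_out = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)  # randomly zeroes activations during training

    def forward(self, x):
        return self.w_out(self.dropout(F.silu(self.w1(x)) * self.w2(x)))

expert = ExpertWithDropout(768, 2048)
expert.eval()  # dropout becomes a no-op in eval mode
y = expert(torch.randn(2, 10, 768))
print(y.shape)  # torch.Size([2, 10, 768])
```

Because the training loop already calls model.train() and model.eval() at the right times, dropout would switch on and off automatically.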
4. Testing the Model with Predictions
```python
import torch
from model_moe import GPTMoEModel
from tokenizer import Tokenizer

def generate(model, tokenizer, text, max_length, device):
    input, att_mask = tokenizer.encode(text)
    input = torch.tensor(input, dtype=torch.long, device=device).unsqueeze(0)
    stop = False
    input_len = len(input[0])
    while not stop:
        if len(input[0]) - input_len > max_length:
            next_symbol = tokenizer.sep_token
            input = torch.cat(
                [input.detach(), torch.tensor([[next_symbol]], dtype=input.dtype, device=device)], -1)
            break
        projected, self_attns = model(input)
        # Greedy decoding: take the highest-scoring token at the last position
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_word = prob.data[-1]
        next_symbol = next_word
        if next_symbol == tokenizer.sep_token:
            stop = True
        input = torch.cat(
            [input.detach(), torch.tensor([[next_symbol]], dtype=input.dtype, device=device)], -1)
    decode = tokenizer.decode(input[0].tolist())
    decode = decode[len(text):]
    return "".join(decode)

def main():
    model_path = "output/last.pt"
    vocab_path = "data/vocab.json"  # vocabulary file
    max_length = 120  # maximum generation length
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Load the tokenizer
    tokenizer = Tokenizer(vocab_path)
    # Model parameters
    model_param = {
        "d_model": 768,  # embedding size
        "d_ff": 2048,  # expert hidden size
        "d_k": 64,  # size of K
        "d_v": 64,  # size of V
        "n_layers": 6,  # number of decoder layers
        "n_heads": 8,  # number of attention heads
        "max_pos": 1800,  # positional-encoding length
        "device": device,  # device
        "vocab_size": tokenizer.get_vocab_size(),  # vocabulary size
        "num_experts": 8,  # 8 experts
        "top_k": 2,  # each token selects 2 experts
        "load_balancing_weight": 0.01  # load-balancing loss weight
    }
    model = GPTMoEModel(**model_param)
    model.load_state_dict(torch.load(model_path))
    model.to(device)
    # Switch to eval mode so the Router adds no routing noise during inference
    model.eval()
    while True:
        text = input("请输入:")
        if not text:
            continue
        if text == "q":
            break
        res = generate(model, tokenizer, text, max_length, device)
        print("AI: ", res)

if __name__ == '__main__':
    main()
```
Prediction results:

5. Summary
This article ran only a small experiment with the MoE architecture, and there is plenty of room for optimization: switching to RoPE rotary positional encoding, adding RMSNorm, trying more advanced routing strategies, adding dropout, and so on. You can keep iterating on and improving this design from here.