基于 PyTorch 完全从零手搓 GPT 混合专家 (MOE) 对话模型

一、基于 PyTorch 从零手搓 GPT 混合专家 (MOE) 对话模型

混合专家模型(MOE)是一种 Transformer 神经网络架构的变种,如 Switch Transformers 结构 ,它通过一个门控网络为每个输入动态地选择一小部分 "专家" 子网络进行计算,从而以稀疏激活的方式提升模型容量与计算效率。能够控制模型总参数量极大的情况下,单次前向传播的计算能保持在一个可控范围内。核心特点在于其 高参数、低计算 的稀疏性。与稠密模型在处理每个输入时激活所有参数不同,MOE模型仅激活总参数的一小部分 ,并且能够随着专家的增加容纳更加丰富的知识和更强的泛化能力。像 Mixtral 8*7B 以及 现在比较火爆的 DeepSeek 都是采用的 MOE 架构,足以证明 MOE 架构的强大潜力。

MOE 架构与传统的密集型Transformer Decoder 架构形成了鲜明对比。普通 Transformer Decoder 层通常由多头自注意力机制 MultiHeadAttention 和前馈神经网络FFN构成。这种设计简洁、稳定、易于并行化,在 GPT、BART 等模型中都广泛应用。其计算与参数激活是全量的,即每个输入 token 都会激活整个 FFN 层的所有参数,这样有个缺点就是模型扩展时计算成本线性增长。

MOE 架构则保留了自注意力模块,但将前馈神经网络FFN替换为了 专家混合 模块,也就是 MOE 层。该模块包含一个轻量级的路由门控网络 Routern 个专家网络 Experts。其中 Router 负责为每个输入 token 动态分配至 Top-K 个专家网络,专家网络通常和前馈神经网络FFN类似,未被选中的专家会被跳过计算,从而实现 稀疏激活 。

在本专栏的前面文章中,我介绍了 从零手搓一个GPT Transformer 对话大模型 ,其中整体使用的就是传统的 Transformer Decoder 架构,文章地址:

基于 PyTorch 从零手搓一个GPT Transformer 对话大模型

在这篇文章中,从零构建了 GPTModel 网络结构,以及从零构建词表,虽然总参数量只有 三千七百多万 ,不能称之为"大模型",但是整体架构十分具有学习意义,本文就在这篇文章的基础上重新构建网络架构,改为 MOE 混合专家架构所使用的训练数据集和词表就不再重复说明,直接都复用上篇文章的内容。

还有对于细节的 点积注意力层、多头注意力层、倒三角掩码器、位置编码 等等的计算过程和公式也都请参考上篇文章中的介绍,本篇内容最后实现的效果如下所示:

实验所使用的主要依赖版本如下:

shell 复制代码
torch==2.6.0
tensorboard==2.19.0

二、搭建 GPTMoEModel 网络架构

2.1 实现(点积计算、多头注意力机制 )

点积计算、多头注意力机制 实现逻辑和上篇文章中一致,如下所示,其中关键部分都做了注释说明:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
import numpy as np

# 点积计算
class ScaledDotProductAttention(nn.Module):
    def __init__(self, d_k):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k

    def forward(self, q, k, v, attention_mask):
        ##
        # q: [batch_size, n_heads, len_q, d_k]
        # k: [batch_size, n_heads, len_k, d_k]
        # v: [batch_size, n_heads, len_v, d_v]
        # attn_mask: [batch_size, n_heads, seq_len, seq_len]
        ##
        # 计算每个Q与K的分数,计算出来的大小是 [batch_size, n_heads, len_q, len_q]
        scores = torch.matmul(q, k.transpose(-1, -2)) / np.sqrt(self.d_k)
        # 把被mask的地方置为无限小,softmax之后基本就是0,也就对q不起作用
        scores.masked_fill_(attention_mask, -1e9)
        attn = nn.Softmax(dim=-1)(scores)
        # 注意力后的大小 [batch_size, n_heads, len_q, d_v]
        context = torch.matmul(attn, v)
        return context, attn


# 多头注意力机制
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_k, d_v):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_k
        self.d_v = d_v
        self.w_q = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.w_k = nn.Linear(d_model, d_k * n_heads, bias=False)
        self.w_v = nn.Linear(d_model, d_v * n_heads, bias=False)
        self.fc = nn.Linear(n_heads * d_v, d_model, bias=False)
        self.layernorm = nn.LayerNorm(d_model)

    def forward(self, q, k, v, attention_mask):
        ##
        # q: [batch_size, seq_len, d_model]
        # k: [batch_size, seq_len, d_model]
        # v: [batch_size, seq_len, d_model]
        # attn_mask: [batch_size, seq_len, seq_len]
        ##
        # 记录原始值, 后续计算残差
        residual, batch_size = q, q.size(0)
        # 先映射 q、k、v, 然后后分头;
        # q: [batch_size, n_heads, len_q, d_k]
        q = self.w_q(q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # k: [batch_size, n_heads, len_k, d_k]
        k = self.w_k(k).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # v: [batch_size, n_heads, len_v(=len_k), d_v]
        v = self.w_v(v).view(batch_size, -1, self.n_heads, self.d_v).transpose(1, 2)
        # attn_mask : [batch_size, n_heads, seq_len, seq_len]
        attention_mask = attention_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)
        # 点积注意力分数计算,  [batch_size, n_heads, len_q, d_v]
        context, attn = ScaledDotProductAttention(self.d_k)(q, k, v, attention_mask)
        # context: [batch_size, len_q, n_heads * d_v]
        context = context.transpose(1, 2).reshape(batch_size, -1, self.n_heads * self.d_v)
        # 还原为原始大小
        output = self.fc(context)
        # LN + 残差计算
        return self.layernorm(output + residual), attn
        

2.2 实现门控网络Router

门控网络就是一个轻量级的神经网络,它的作用:对每一个 token,预测其应被分配给哪些专家,并为每个选中的专家分配一个权重,用于加权融合多个专家的输出。

但是门控网络有个问题就是可能会发生 专家失衡 ,总是将样本分配给少数几个能力强或初始化的好的专家,导致其他专家得不到训练,最终整个系统退化,只有少数专家被使用。为了解决这个问题,可以在路由时,增加一个可训练的噪声,另外还需要引入一个辅助损失,也就是负载均衡损失,这里负载均衡损失参考 Mixtral 模型的做法。

实现逻辑如下:

python 复制代码
# 门控网络
class Router(nn.Module):
    def __init__(self, d_model, num_experts, top_k=2):
        super(Router, self).__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.gate = nn.Linear(d_model, num_experts)
        # 用于负载均衡的噪声
        self.noise_linear = nn.Linear(d_model, num_experts)

    def forward(self, x):
        logits = self.gate(x)

        # 训练时添加噪声
        if self.training:
            noise = torch.randn_like(logits).to(x.device)
            noise = self.noise_linear(x) * noise
            noisy_logits = logits + noise
        else:
            noisy_logits = logits

        gates_prob = F.softmax(noisy_logits, dim=-1)
        # Top-k 选择
        top_k_probs, top_k_indices = torch.topk(gates_prob, self.top_k, dim=-1)
        # 归一化,确保被选中的专家的权重之和为1
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)
        # 负载均衡损失
        load_balancing_loss = self.compute_load_balancing_loss(gates_prob, top_k_indices)
        return top_k_probs, top_k_indices, load_balancing_loss

    def compute_load_balancing_loss(self, gates_prob, top_k_indices):
        """ 负载均衡损失:num_experts * sum ( 每个专家的平均概率 * 每个专家选中的概率 )"""
        batch_size, seq_len, _ = gates_prob.shape

        # 计算每个专家的平均概率
        router_prob_per_expert = gates_prob.mean(dim=(0, 1))

        # 计算每个专家理想被分配到的概率
        expert_mask = torch.zeros_like(gates_prob)
        expert_mask.scatter_(2, top_k_indices, 1)
        tokens_per_expert = expert_mask.float().mean(dim=(0, 1))

        # 辅助损失
        return self.num_experts * torch.sum(tokens_per_expert * router_prob_per_expert)

2.3 实现专家网络

每个专家相当于是一个前馈神经网络, 这里模拟SwiGLU FFN

python 复制代码
# 专家网络
class Expert(nn.Module):
    def __init__(self, d_model, d_ff):
        super(Expert, self).__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_model, d_ff, bias=False)
        self.w_out = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x):
        return self.w_out(F.silu(self.w1(x)) * self.w2(x))

2.4 整合Router和专家层,实现 MOE 层

包括一个 门控Router,和多个专家组成。Router 输出 top-k 专家 ID 和权重,然后将 token 输入到对应专家;然后加权融合输出

这里为了可以更加利于理解,在做专家选择时,用的双重循环 + 逐专家判断,可能无法高效的利用GPU的并行计算,后续可以参考 Mixtral 模型的写法更高效的运行。

python 复制代码
# MOE层
class MoELayer(nn.Module):
    def __init__(self, d_model, d_ff, num_experts=8, top_k=2):
        super(MoELayer, self).__init__()
        self.d_model = d_model
        self.num_experts = num_experts
        self.top_k = top_k
        # 门控路由,决定哪些专家被激活
        self.router = Router(d_model, num_experts, top_k)
        # 创建多个专家
        self.experts = nn.ModuleList([
            Expert(d_model, d_ff) for _ in range(num_experts)
        ])
        # Layer Norm
        self.layernorm = nn.LayerNorm(d_model)

    def forward(self, x):
        """
        x: [batch_size, seq_len, d_model]
        """
        residual = x

        batch_size, seq_len, d_model = x.shape
        # 获取路由决策
        # gates: [batch_size, seq_len, top_k]
        # selected_experts: [batch_size, seq_len, top_k]
        gates, selected_experts, load_balancing_loss = self.router(x)

        # 初始化输出
        output = torch.zeros_like(x)

        # 对每个token应用选中的专家
        for i in range(self.top_k):
            # 获取当前专家索引
            expert_idx = selected_experts[:, :, i]  # [batch_size, seq_len]
            # 获取当前权重
            expert_gate = gates[:, :, i]  # [batch_size, seq_len]
            # 对每个专家进行计算
            for expert_id in range(self.num_experts):
                # 找出选择了当前专家的token位置
                mask = (expert_idx == expert_id).unsqueeze(-1)  # [batch_size, seq_len, 1]
                if mask.any():
                    # 获取分配给当前专家的tokens
                    expert_input = x * mask  # [batch_size, seq_len, d_model]
                    # 应用专家
                    expert_output = self.experts[expert_id](expert_input)  # [batch_size, seq_len, d_model]
                    # 加权输出
                    weighted_output = expert_output * expert_gate.unsqueeze(-1) * mask
                    output += weighted_output
        # 残差连接和Layer Norm
        output = self.layernorm(output + residual)
        return output, load_balancing_loss

2.5 实现 MOE 解码层

和传统的 Transformer Decoder Layer 类似,只需将 前馈网络FFN 换成 MOE 层。

python 复制代码
# 解码层
class MoEDecoderLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, d_k, d_v, num_experts=8, top_k=2):
        super(MoEDecoderLayer, self).__init__()
        # 多头注意力层
        self.attention = MultiHeadAttention(d_model, n_heads, d_k, d_v)
        # MoE
        self.pos_ffn = MoELayer(d_model, d_ff, num_experts, top_k)

    def forward(self, inputs, attention_mask):
        # 多头注意力
        outputs, self_attn = self.attention(inputs, inputs, inputs, attention_mask)
        # MoE
        outputs, load_balancing_loss = self.pos_ffn(outputs)
        return outputs, self_attn, load_balancing_loss

2.6 堆积MOE解码层,实现 MOE 解码器

将多个解码层堆叠,形成一个特征提取链。为了便于和上篇文章做效果对比,这里位置编码依然使用 GPT2 的做法,同样也需要一个倒三角掩码器,防止模型看到未来的信息。

掩码过程如下所示:

shell 复制代码
原始注意力分数矩阵(无掩码):
[[q1k1, q1k2, q1k3, q1k4],
 [q2k1, q2k2, q3k3, q3k4],
 [q3k1, q3k2, q3k3, q3k4],
 [q4k1, q4k2, q4k3, q4k4]]

上三角掩码器:
[[0, 1, 1, 1],
 [0, 0, 1, 1],
 [0, 0, 0, 1],
 [0, 0, 0, 0]]

应用掩码后的分数矩阵:
[[q1k1, -inf, -inf, -inf],
 [q2k1, q2k2, -inf, -inf],
 [q3k1, q3k2, q3k3, -inf],
 [q4k1, q4k2, q4k3, q4k4]]

实现逻辑如下:

python 复制代码
# 位置编码,这里使用GPT2的做法
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_pos, device):
        super(PositionalEncoding, self).__init__()
        self.device = device
        self.pos_embedding = nn.Embedding(max_pos, d_model)

    def forward(self, inputs):
        seq_len = inputs.size(1)
        pos = torch.arange(seq_len, dtype=torch.long, device=self.device)
        # [seq_len] -> [batch_size, seq_len]
        pos = pos.unsqueeze(0).expand_as(inputs)
        return self.pos_embedding(pos)
        
# 获取pad掩码器
def get_attn_pad_mask(attention_mask):
    batch_size, len_seq = attention_mask.size()
    attention_mask = attention_mask.data.eq(0).unsqueeze(1)
    # 注意力分数的大小是 [batch_size, n_heads, len_q, len_q]
    # 所以这里要转换成 [batch_size, len_seq, len_seq] 大小
    return attention_mask.expand(batch_size, len_seq, len_seq)

# 获取倒三角掩码器,防止模型看到未来的信息
def get_attn_subsequence_mask(seq, device):
    # 注意力分数的大小是 [batch_size, n_heads, len_seq, len_seq]
    # 所以这里要生成 [batch_size, len_seq, len_seq] 大小
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    # 生成一个上三角矩阵
    subsequence_mask = np.triu(np.ones(attn_shape), k=1)
    subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    subsequence_mask = subsequence_mask.to(device)
    return subsequence_mask

# 解码器
class MoEDecoder(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, d_k, d_v, vocab_size, max_pos, n_layers,
                 device, num_experts=8, top_k=2):
        super(MoEDecoder, self).__init__()
        self.device = device
        # 将Token转为向量
        self.embedding = nn.Embedding(vocab_size, d_model)
        # 位置编码
        self.pos_encoding = PositionalEncoding(d_model, max_pos, device)

        # 创建MOE层
        self.layers = nn.ModuleList()
        for i in range(n_layers):
            self.layers.append(
                MoEDecoderLayer(
                    d_model, n_heads, d_ff, d_k, d_v,
                    num_experts, top_k
                )
            )

    def forward(self, inputs, attention_mask):
        # 嵌入和位置编码
        outputs = self.embedding(inputs) + self.pos_encoding(inputs)

        # 生成掩码
        subsequence_mask = get_attn_subsequence_mask(inputs, self.device)
        if attention_mask is not None:
            attention_mask = get_attn_pad_mask(attention_mask)
            attention_mask = torch.gt((attention_mask + subsequence_mask), 0)
        else:
            attention_mask = subsequence_mask.bool()

        # 计算每一层的结果
        self_attns = []
        total_load_balancing_loss = 0.0
        for layer in self.layers:
            layer_output = layer(outputs, attention_mask)
            outputs, self_attn, load_balancing_loss = layer_output
            total_load_balancing_loss += load_balancing_loss
            self_attns.append(self_attn)

        return outputs, self_attns, total_load_balancing_loss

2.7 整合解码器,实现 GPTMoEModel

这里需要注意,损失函数要考虑前面的负载均衡损失,因此整体的损失应该是两者之和。

python 复制代码
# GPT MOE模型
class GPTMoEModel(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, d_k, d_v, vocab_size, max_pos, n_layers,
                 device, num_experts=8, top_k=2, load_balancing_weight=0.01):
        super(GPTMoEModel, self).__init__()
        self.load_balancing_weight = load_balancing_weight
        # 解码器
        self.decoder = MoEDecoder(
            d_model, n_heads, d_ff, d_k, d_v, vocab_size, max_pos, n_layers,
            device, num_experts, top_k
        )
        # 映射为词表大小
        self.projection = nn.Linear(d_model, vocab_size)

    def forward(self, inputs, attention_mask=None, targets=None):
        # 前向传播
        outputs, self_attns, load_balancing_loss = self.decoder(inputs, attention_mask)
        # 投影到词表
        logits = self.projection(outputs)
        logits = logits.view(-1, logits.size(-1))
        if targets is not None:
            # 负载均衡损失
            load_balancing_loss = load_balancing_loss * self.load_balancing_weight
            # 任务损失
            lm_loss = F.cross_entropy(logits, targets.view(-1), ignore_index=0)
            # MOE架构的总损失是任务损失和负载均衡损失的加权和
            total_loss = lm_loss + load_balancing_loss
            return logits, self_attns, total_loss
        return logits, self_attns

2.8 整体网络架构

以上整体网络代码放在 model_moe.py 中。

python 复制代码
import torch
from model_moe import GPTMoEModel

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # 模型参数
    model_param = {
        "d_model": 768,  # 嵌入层大小
        "d_ff": 2048,  # 这是为专家网络大小
        "d_k": 64,  # K 的大小
        "d_v": 64,  # V 的大小
        "n_layers": 6,  # 解码层的数量
        "n_heads": 8,  # 多头注意力的头数
        "max_pos": 1800,  # 位置编码的长度
        "device": device,  # 设备
        "vocab_size": 4825,  # 词表大小,上篇文章中构建的词表大小
        "num_experts": 8,  # 8个专家
        "top_k": 2,  # 每个token选择2个专家
        "load_balancing_weight": 0.01  # 负载均衡损失权重
    }
    model = GPTMoEModel(**model_param)
    total_params = sum(p.numel() for p in model.parameters())
    print(model)
    print("total_params: ", total_params)


if __name__ == '__main__':
    main()

执行输出:

json 复制代码
GPTMoEModel(
  (decoder): MoEDecoder(
    (embedding): Embedding(4825, 768)
    (pos_encoding): PositionalEncoding(
      (pos_embedding): Embedding(1800, 768)
    )
    (layers): ModuleList(
      (0-5): 6 x MoEDecoderLayer(
        (attention): MultiHeadAttention(
          (w_q): Linear(in_features=768, out_features=512, bias=False)
          (w_k): Linear(in_features=768, out_features=512, bias=False)
          (w_v): Linear(in_features=768, out_features=512, bias=False)
          (fc): Linear(in_features=512, out_features=768, bias=False)
          (layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (pos_ffn): MoELayer(
          (router): Router(
            (gate): Linear(in_features=768, out_features=8, bias=True)
            (noise_linear): Linear(in_features=768, out_features=8, bias=True)
          )
          (experts): ModuleList(
            (0-7): 8 x Expert(
              (fc1): Linear(in_features=768, out_features=2048, bias=False)
              (fc2): Linear(in_features=2048, out_features=768, bias=False)
              (activation): ReLU()
            )
          )
          (layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
      )
    )
  )
  (projection): Linear(in_features=768, out_features=4825, bias=True)
)
total_params:  173028409

整体参数量为 1.73亿0.17B 大小,相比上篇文章构建的网络,能容纳更多的知识。

三、模型训练

这里训练集和训练过程基本和上篇文章一致,同时训练数据集中同样增加一些自定义的模型特色内容,追加几条身份的数据在里面:

json 复制代码
{"question": "你是谁", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你叫什么", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你的名字是什么", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你叫啥", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你名字是啥", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你是什么身份", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你的全名是什么", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你自称什么", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你的称号是什么", "answer": "我是小毕超,一个简易的小助手"}
{"question": "你的昵称是什么", "answer": "我是小毕超,一个简易的小助手"}

3.1 构建 Dataset

qa_dataset.py

python 复制代码
# -*- coding: utf-8 -*-
from torch.utils.data import Dataset
import torch
import json
import numpy as np


class QADataset(Dataset):
    def __init__(self, data_path, tokenizer, max_length) -> None:
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = []
        if data_path:
            with open(data_path, "r", encoding='utf-8') as f:
                for line in f:
                    if not line or line == "":
                        continue
                    json_line = json.loads(line)
                    question = json_line["question"]
                    answer = json_line["answer"]
                    self.data.append({
                        "question": question,
                        "answer": answer
                    })
        print("data load , size:", len(self.data))

    def preprocess(self, question, answer):
        encode, att_mask = self.tokenizer.encode(question, answer, max_length=self.max_length, pad_to_max_length=True)
        input_ids = encode[:-1]
        att_mask = att_mask[:-1]
        labels = encode[1:]
        return input_ids, att_mask, labels

    def __getitem__(self, index):
        item_data = self.data[index]
        input_ids, att_mask, labels = self.preprocess(**item_data)
        return {
            "input_ids": torch.LongTensor(np.array(input_ids)),
            "attention_mask": torch.LongTensor(np.array(att_mask)),
            "labels": torch.LongTensor(np.array(labels))
        }

    def __len__(self):
        return len(self.data)

3.2 训练

python 复制代码
# -*- coding: utf-8 -*-
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tokenizer import Tokenizer
from model_moe import GPTMoEModel
from qa_dataset import QADataset
from tqdm import tqdm
import time, sys, os


def train_model(model, train_loader, val_loader, optimizer,
                device, num_epochs, model_output_dir, writer):
    batch_step = 0
    best_val_loss = float('inf')
    for epoch in range(num_epochs):
        time1 = time.time()
        model.train()
        for index, data in enumerate(tqdm(train_loader, file=sys.stdout, desc="Train Epoch: " + str(epoch))):
            input_ids = data['input_ids'].to(device, dtype=torch.long)
            attention_mask = data['attention_mask'].to(device, dtype=torch.long)
            labels = data['labels'].to(device, dtype=torch.long)
            optimizer.zero_grad()
            outputs, dec_self_attns, loss = model(input_ids, attention_mask, labels)
            loss.backward()
            # 梯度裁剪
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()
            writer.add_scalar('Loss/train', loss, batch_step)
            batch_step += 1
            # 50轮打印一次 loss
            if index % 50 == 0 or index == len(train_loader) - 1:
                time2 = time.time()
                tqdm.write(
                    f"{index}, epoch: {epoch} -loss: {str(loss)} ; lr: {optimizer.param_groups[0]['lr']} ;each step's time spent: {(str(float(time2 - time1) / float(index + 0.0001)))}")
        # 验证
        model.eval()
        val_loss = validate_model(model, device, val_loader)
        writer.add_scalar('Loss/val', val_loss, epoch)
        print(f"val loss: {val_loss} , epoch: {epoch}")
        # 保存最优模型
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_path = os.path.join(model_output_dir, "best.pt")
            print("Save Best Model To ", best_model_path, ", epoch: ", epoch)
            torch.save(model.state_dict(), best_model_path)
        # 保存当前模型
        last_model_path = os.path.join(model_output_dir, "last.pt")
        print("Save Last Model To ", last_model_path, ", epoch: ", epoch)
        torch.save(model.state_dict(), last_model_path)


def validate_model(model, device, val_loader):
    running_loss = 0.0
    with torch.no_grad():
        for _, data in enumerate(tqdm(val_loader, file=sys.stdout, desc="Validation Data")):
            input_ids = data['input_ids'].to(device, dtype=torch.long)
            attention_mask = data['attention_mask'].to(device, dtype=torch.long)
            labels = data['labels'].to(device, dtype=torch.long)
            outputs, dec_self_attns, loss = model(input_ids, attention_mask, labels)
            running_loss += loss.item()
    return running_loss / len(val_loader)


def main():
    train_json_path = "data/train.json"  # 训练集
    val_json_path = "data/val.json"  # 验证集
    vocab_path = "data/vocab.json"  # 词表位置
    max_length = 120  # 最大长度
    epochs = 15  # 迭代周期
    batch_size = 128  # 训练一个批次的大小
    lr = 1e-4  # 学习率
    model_output_dir = "output"  # 模型保存目录
    logs_dir = "logs"  # 日志记录目标
    # 设备
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # 加载分词器
    tokenizer = Tokenizer(vocab_path)
    # 模型参数
    model_param = {
        "d_model": 768,  # 嵌入层大小
        "d_ff": 2048,  # 专家网络大小
        "d_k": 64,  # K 的大小
        "d_v": 64,  # V 的大小
        "n_layers": 6,  # 解码层的数量
        "n_heads": 8,  # 多头注意力的头数
        "max_pos": 1800,  # 位置编码的长度
        "device": device,  # 设备
        "vocab_size": tokenizer.get_vocab_size(),  # 词表大小
        "num_experts" :8,  # 8个专家
        "top_k" : 2,  # 每个token选择2个专家
        "load_balancing_weight" : 0.01  # 负载均衡损失权重
    }
    model = GPTMoEModel(**model_param)
    print("Start Load Train Data...")
    train_params = {
        "batch_size": batch_size,
        "shuffle": True,
        "num_workers": 4,
    }
    training_set = QADataset(train_json_path, tokenizer, max_length)
    training_loader = DataLoader(training_set, **train_params)
    print("Start Load Validation Data...")
    val_params = {
        "batch_size": batch_size,
        "shuffle": False,
        "num_workers": 4,
    }
    val_set = QADataset(val_json_path, tokenizer, max_length)
    val_loader = DataLoader(val_set, **val_params)
    # 日志记录
    writer = SummaryWriter(logs_dir)
    # 优化器
    optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)
    model = model.to(device)
    # 开始训练
    print("Start Training...")
    train_model(
        model=model,
        train_loader=training_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        device=device,
        num_epochs=epochs,
        model_output_dir=model_output_dir,
        writer=writer
    )
    writer.close()


if __name__ == '__main__':
    main()

训练过程:

训练结果后使用 tensorboard 查看下 loss 趋势:

在训练 15epochs 情况下,验证集的 loss,在前 9epochs一直处于下降趋势,第10epochs开始上升,考虑出现过拟合情况,后续优化可以在网络中加入部分 dropout 来随机失活。

四、模型预测使用测试

python 复制代码
import torch

from model_moe import GPTMoEModel
from tokenizer import Tokenizer


def generate(model, tokenizer, text, max_length, device):
    input, att_mask = tokenizer.encode(text)
    input = torch.tensor(input, dtype=torch.long, device=device).unsqueeze(0)
    stop = False
    input_len = len(input[0])
    while not stop:
        if len(input[0]) - input_len > max_length:
            next_symbol = tokenizer.sep_token
            input = torch.cat(
                [input.detach(), torch.tensor([[next_symbol]], dtype=input.dtype, device=device)], -1)
            break
        projected, self_attns = model(input)
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]
        next_word = prob.data[-1]
        next_symbol = next_word
        if next_symbol == tokenizer.sep_token:
            stop = True
        input = torch.cat(
            [input.detach(), torch.tensor([[next_symbol]], dtype=input.dtype, device=device)], -1)
    decode = tokenizer.decode(input[0].tolist())
    decode = decode[len(text):]
    return "".join(decode)


def main():
    model_path = "output/last.pt"
    vocab_path = "data/vocab.json"  # 词表位置
    max_length = 120  # 最大长度
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # 加载分词器
    tokenizer = Tokenizer(vocab_path)
    # 模型参数
    model_param = {
        "d_model": 768,  # 嵌入层大小
        "d_ff": 2048,  # 专家网络大小
        "d_k": 64,  # K 的大小
        "d_v": 64,  # V 的大小
        "n_layers": 6,  # 解码层的数量
        "n_heads": 8,  # 多头注意力的头数
        "max_pos": 1800,  # 位置编码的长度
        "device": device,  # 设备
        "vocab_size": tokenizer.get_vocab_size(),  # 词表大小
        "num_experts": 8,  # 8个专家
        "top_k": 2,  # 每个token选择2个专家
        "load_balancing_weight": 0.01  # 负载均衡损失权重
    }
    model = GPTMoEModel(**model_param)
    model.load_state_dict(torch.load(model_path))
    model.to(device)

    while True:
        text = input("请输入:")
        if not text:
            continue
        if text == "q":
            break
        res = generate(model, tokenizer, text, max_length, device)
        print("AI: ", res)


if __name__ == '__main__':
    main()

预测效果:

五、总结

文本仅对MOE架构做了下的实验,其中还有很多可以优化的地方,例如可以使用RoPE旋转位置编码、加入 RMSNormal 、尝试更先进的路由策略、加入 dropout 等等,后续你可以继续尝试进行改造和优化。

相关推荐
ygyqinghuan2 小时前
Pytorch 数据处理
人工智能·pytorch·python
nju_spy4 小时前
南京大学 LLM开发基础(二)大语言模型解析 -- 基于HF LlaMA实现的讲解
人工智能·pytorch·深度学习·大模型·多头注意力·rmsnorm·位置掩码
Y200309165 小时前
PyTorch 实现 CIFAR10 图像分类知识点总结
人工智能·pytorch·分类
姜—姜5 小时前
使用 PyTorch 框架对 CIFAR - 10 数据集进行CNN分类
pytorch·分类·cnn
凳子(刘博浩)5 小时前
使用 PyTorch 实现 CIFAR-10 图像分类:从数据加载到模型训练全流程
人工智能·pytorch·分类
史锦彪8 小时前
PyTorch 实现 CIFAR-10 图像分类:从基础 CNN 到全局平均池化的探索
pytorch·分类·cnn
41号学员8 小时前
构建神经网络的两大核心工具
人工智能·pytorch·深度学习
Wah-Aug10 小时前
PyTorch 模型评估与全局平均池化的应用实践
人工智能·pytorch·python
诸葛箫声10 小时前
基于PyTorch的CIFAR-10图像分类项目总结(2)
人工智能·pytorch·分类