Moe机制与pytorch实现

核心思想

  • 用多个较小的 前馈网络 (FFN) 替换原本的大 FFN 层。
  • 每个 token 只经过少数几个专家 (top-k),而不是所有专家,提升计算效率。
  • 另有 共享专家 (shared experts),对所有 token 都进行计算,确保模型稳定性和表达能力。
  • 非激活专家不会被调用,因此不会参与前向和反向传播,减少计算量。

实现方式

  1. 门控 (Gating):根据输入特征选择 top-k 专家并分配权重。
  2. 路由与计算:激活专家处理对应 token,并加权聚合输出。
  3. 分布式 MoE:专家跨设备部署,输入 token 通过通信操作分发到对应专家,再聚合结果。

负载均衡 (Load Balancing)

  • 如果路由过于集中在少数专家,会导致:
    • 一部分专家过度训练,另一部分训练不足。
    • 计算资源浪费,影响模型效果。
  • 常见解决方法:
    1. 在打分阶段加入噪声,使专家选择更均匀,特别在训练早期。
    2. 添加负载均衡损失在,惩罚专家选择过度不均。
    3. 训练时随机禁用部分专家,防止过度依赖。
    4. 在打分结果中引入动态偏置,帮助提升专家利用率(Deepseek的无辅助损失函数)。

https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py

Deepseek-V3源码Moe结构简化版实现

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F


class Gate(nn.Module):
    def __init__(self, dim: int, n_experts: int, topk: int, score_func: str = "softmax", route_scale: float = 1.0):
        super().__init__()
        self.dim = dim
        self.n_experts = n_experts
        self.topk = topk
        self.score_func = score_func
        self.route_scale = route_scale

        self.weight = nn.Parameter(torch.empty(n_experts, dim))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x: torch.Tensor):
        # (batch, dim) @ (dim, n_experts)^T = (batch, n_experts)
        scores = x @ self.weight.t()
        if self.score_func == "softmax":
            scores = scores.softmax(dim=-1)
        else:
            scores = scores.sigmoid()

        # 选择 top-k 专家
        topk_scores, topk_indices = torch.topk(scores, self.topk, dim=-1)

        # 归一化
        if self.score_func == "sigmoid":
            topk_scores = topk_scores / (topk_scores.sum(dim=-1, keepdim=True) + 1e-9)

        topk_scores = topk_scores * self.route_scale
        return topk_scores, topk_indices


class Expert(nn.Module):
    def __init__(self, dim: int, inter_dim: int):
        super().__init__()
        self.w1 = nn.Linear(dim, inter_dim)
        self.w2 = nn.Linear(inter_dim, dim)

    def forward(self, x: torch.Tensor):
        return self.w2(F.silu(self.w1(x)))


class MoE(nn.Module):
    def __init__(self, dim: int, n_experts: int, topk: int, inter_dim: int):
        super().__init__()
        self.dim = dim
        self.gate = Gate(dim, n_experts, topk)
        self.experts = nn.ModuleList([Expert(dim, inter_dim) for _ in range(n_experts)])

    def forward(self, x: torch.Tensor):
        shape = x.size()
        x = x.view(-1, self.dim)

        # 得到选择的专家及权重
        weights, indices = self.gate(x)  # (batch, topk), (batch, topk)

        y = torch.zeros_like(x)
        for k in range(self.gate.topk):
            expert_idx = indices[:, k]
            expert_weight = weights[:, k]

            for i in range(self.gate.n_experts):
                mask = (expert_idx == i)
                if mask.any():
                    y[mask] += self.experts[i](x[mask]) * expert_weight[mask, None]

        return y.view(shape)
相关推荐
weixin_456904272 分钟前
OpenCV 摄像头参数控制详解
人工智能·opencv·计算机视觉
IT_陈寒19 分钟前
Vue 3.4 实战:这7个Composition API技巧让我的开发效率飙升50%
前端·人工智能·后端
张较瘦_29 分钟前
[论文阅读] AI+软件工程 | AI供应链信任革命:TAIBOM如何破解AI系统“可信难题“
论文阅读·人工智能·软件工程
合作小小程序员小小店1 小时前
web网页开发,在线%推荐算法学院培养计划,图书推荐,基于Python,FlaskWeb,用户和物品推荐MySql
python·mysql·算法·flask·推荐算法
媒体人8881 小时前
中国顶级 GEO 优化专家孟庆涛:用 15 年积淀定义 2025 年 GEO 优化新标准
人工智能·搜索引擎·chatgpt·生成式引擎优化·geo优化
那我掉的头发算什么2 小时前
【数据结构】二叉树的高频热门面试题大全
java·开发语言·数据结构·python·算法·链表·intellij idea
山海青风2 小时前
藏语自然语言处理入门 - 5 文本归类
人工智能·自然语言处理
十步杀一人_千里不留行2 小时前
和 AI 一起修 Bug 心得体会
人工智能·bug·ai编程
网安INF2 小时前
【论文阅读】-《Sparse and Imperceivable Adversarial Attacks》
论文阅读·人工智能·计算机视觉·网络安全·对抗攻击
yzx9910132 小时前
多模态分类:图文结合的智能识别与代码实战
人工智能·分类·数据挖掘