Reading the MoeDistributeDispatch Operator Code

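The image froze, a shrill alarm sounded, and the Risk Eliminator told everyone that the early-warning system had gone off.

"Why?" the chief engineer asked, puzzled.

"This primitive man has gazed at the starry sky for longer than the warning threshold and has shown sufficient curiosity about the universe. So far, ten such threshold-exceeding events have been observed at different locations, which meets the alarm condition."

"If I remember correctly, you said earlier that the warning system only goes off when a civilization capable of producing creation-level energies and processes appears."

"Isn't that exactly such a civilization you are looking at?"

------ 《朝闻道》 (Chao Wen Dao), Liu Cixin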

Preface

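Why write this article? Not out of irrepressible curiosity. Partly because I had time on my hands, and partly for practical reasons: to see whether I could do some work on computation-communication fusion. I read the MoeDistributeDispatch operator on and off for three weeks. Reading the code alone left me lost. Later I simulated the relevant interfaces on the CPU (DataCopy, Duplicate, ReduceSum, Add, Sub, Mins) and printed intermediate values to help me understand the code.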

MoeDistributeDispatch

Documentation

https://gitcode.com/cann/ops-transformer/tree/master/mc2/moe_distribute_dispatch

Main flow

template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::Process()
{
    if ASCEND_IS_AIV { // everything runs on AIV cores
        AlltoAllDispatch();
        SetStatus();
        WaitDispatch();
        LocalWindowCopy();
        if constexpr (IsNeedAllgather) {
            AllGatherSetStatusAndWait();
            AllgatherProcessOut();
        }
        UpdateTokenNumsOut();
    }
}

Test data

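To make things concrete, the test configures moeExpertNum_ = 32, uses two cards (epWorldSize_ = 2), and runs 4 AIV cores on each card (aivNum_ = 4).

The weights of experts 0 to 15 are stored on rank0; those of experts 16 to 31 are stored on rank1.

moeExpertNumPerRank_ = moeExpertNum_ / epWorldSize_ = 16.

The expertIds input on each card has shape axisBS_ × axisK_, with axisBS_ = 4 and axisK_ = 12.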

expertIds input on rank0:

--------expertIds------------

11 17 29 12 24 23 1 0 16 30 18 8

2 14 11 3 30 21 12 0 10 9 6 31

22 27 30 21 1 24 17 11 3 19 2 29

16 13 7 27 6 29 5 22 24 19 23 2

expertIds input on rank1:

--------expertIds------------

24 6 29 5 20 0 16 27 9 31 11 17

27 28 15 1 25 26 11 5 3 2 4 10

28 22 6 13 3 30 18 15 27 31 7 5

31 2 21 23 20 3 6 8 7 12 5 15

AlltoAllDispatch

SendToMoeExpert sends the hidden states (xGMTensor_, or its quantized values) to the HBM window memory (dstWinGMTensor) that belongs to dstExpertId.

template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::SendToMoeExpert()
{
    for (int32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) {
        uint32_t dstExpertId = expertIdsTensor_(tokenIndex);
        int32_t curExpertCnt = 0;
        if (tokenIndex > 0) {
            CalTokenSendExpertCnt(dstExpertId, tokenIndex, curExpertCnt);
        }
        expertCountTensor_(tokenIndex - startTokenId) = curExpertCnt;
        uint32_t tempRankId = dstExpertId / moeExpertNumPerRank_ + sharedExpertRankNum_;
        GM_ADDR rankGM = (__gm__ uint8_t*)(GetWindAddrByRankId(COMM_EP_IDX, tempRankId) +
                        (expertPerSizeOnWin_ * (epRankId_ * moeExpertNumPerRank_ + dstExpertId % moeExpertNumPerRank_))
                        + hCommuSize_ * curExpertCnt); // compute the address offset
        dstWinGMTensor.SetGlobalBuffer((__gm__ ExpandXOutType*)rankGM);
    }
    GlobalTensor<int32_t> expandIdxGMTensor;
    expandIdxGMTensor.SetGlobalBuffer((__gm__ int32_t*)(expandIdxOutGM_ + startTokenId * sizeof(int32_t)));
    DataCopyExtParams expertIdsCntParams = {1U, static_cast<uint32_t>(sendTokenNum * sizeof(uint32_t)), 0U, 0U, 0U};
    DataCopyPad(expandIdxGMTensor, expertCountTensor_, expertIdsCntParams);
}

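On each rank, the receive area needs expertPerSizeOnWin_ * moeExpertNumPerRank_ * epWorldSize_ bytes.

Each expert is given expertPerSizeOnWin_ = axisMaxBS_ * axisH_ * sizeof(XType) bytes. Within one row of expertIds the ids never repeat, so one rank sends at most axisMaxBS_ tokens to any given expert.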

On rank0, the address range [0, expertPerSizeOnWin_ * moeExpertNumPerRank_) holds the hidden states sent by rank0 for expertIds in [0,16).

On rank0, the range [expertPerSizeOnWin_ * moeExpertNumPerRank_, expertPerSizeOnWin_ * moeExpertNumPerRank_ * epWorldSize_) holds the hidden states sent by rank1 for expertIds in [0,16). This requires RDMA.

On rank1, the range [0, expertPerSizeOnWin_ * moeExpertNumPerRank_) holds the hidden states sent by rank0 for expertIds in [16,32). This requires RDMA.

On rank1, the range [expertPerSizeOnWin_ * moeExpertNumPerRank_, expertPerSizeOnWin_ * moeExpertNumPerRank_ * epWorldSize_) holds the hidden states sent by rank1 for expertIds in [16,32).
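A host-side sketch of the offset formula in SendToMoeExpert, assuming sharedExpertRankNum_ = 0 as in this test. The sizes expertPerSizeOnWin_ = 65536 and hCommuSize_ = 1152 are not stated in the source; they are inferred from the printed rankGMOffset values below (for example, expert 11 with curExpertCnt 1 sits at 11 * 65536 + 1152 = 722048). calcWinOffset is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>

// Result of the hypothetical calcWinOffset helper.
struct WinOffset {
    uint32_t dstRank;  // rank hosting dstExpertId (tempRankId in the kernel)
    uint64_t offset;   // byte offset from GetWindAddrByRankId(COMM_EP_IDX, dstRank)
};

// Mirrors the rankGM arithmetic in SendToMoeExpert, assuming sharedExpertRankNum_ == 0.
WinOffset calcWinOffset(uint32_t dstExpertId, uint32_t epRankId, int32_t curExpertCnt,
                        uint32_t moeExpertNumPerRank, uint64_t expertPerSizeOnWin,
                        uint64_t hCommuSize) {
    assert(moeExpertNumPerRank > 0 && curExpertCnt >= 0);
    WinOffset r{};
    r.dstRank = dstExpertId / moeExpertNumPerRank;
    // One region per (source rank, local expert); the source rank selects the outer slice.
    uint64_t region = static_cast<uint64_t>(epRankId) * moeExpertNumPerRank
                    + dstExpertId % moeExpertNumPerRank;
    // Inside the region, the token goes to slot curExpertCnt.
    r.offset = expertPerSizeOnWin * region + hCommuSize * static_cast<uint64_t>(curExpertCnt);
    return r;
}
```

For example, calcWinOffset(6, 1, 0, 16, 65536, 1152) gives dstRank 0 and offset 22 * 65536 = 1441792, which is the rank1 line for tokenIndex 1.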

Test program output

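dstRankId: the rank the hidden states are sent to.

tokenIndex: index of the token handled by the core (an index into the flattened expertIds, from 0 to axisBS_ * axisK_ - 1).

dstExpertId: expertIdsTensor_(tokenIndex).

curExpertCnt: see CalTokenSendExpertCnt; the number of earlier entries (index < tokenIndex) equal to dstExpertId.

rankGMOffset: offset of rankGM from the window base address of dstRankId.

Output of the four cores on rank0: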

#######SendToMoeExpert##########rank: 0, aiv_id: 0

dstRankId: 0, tokenIndex: 0, dstExpertId 11, curExpertCnt: 0, rankGMOffset: 720896

dstRankId: 1, tokenIndex: 1, dstExpertId 17, curExpertCnt: 0, rankGMOffset: 65536

dstRankId: 1, tokenIndex: 2, dstExpertId 29, curExpertCnt: 0, rankGMOffset: 851968

dstRankId: 0, tokenIndex: 3, dstExpertId 12, curExpertCnt: 0, rankGMOffset: 786432

dstRankId: 1, tokenIndex: 4, dstExpertId 24, curExpertCnt: 0, rankGMOffset: 524288

dstRankId: 1, tokenIndex: 5, dstExpertId 23, curExpertCnt: 0, rankGMOffset: 458752

dstRankId: 0, tokenIndex: 6, dstExpertId 1, curExpertCnt: 0, rankGMOffset: 65536

dstRankId: 0, tokenIndex: 7, dstExpertId 0, curExpertCnt: 0, rankGMOffset: 0

dstRankId: 1, tokenIndex: 8, dstExpertId 16, curExpertCnt: 0, rankGMOffset: 0

dstRankId: 1, tokenIndex: 9, dstExpertId 30, curExpertCnt: 0, rankGMOffset: 917504

dstRankId: 1, tokenIndex: 10, dstExpertId 18, curExpertCnt: 0, rankGMOffset: 131072

dstRankId: 0, tokenIndex: 11, dstExpertId 8, curExpertCnt: 0, rankGMOffset: 524288

#######SendToMoeExpert##########rank: 0, aiv_id: 1

dstRankId: 0, tokenIndex: 12, dstExpertId 2, curExpertCnt: 0, rankGMOffset: 131072

dstRankId: 0, tokenIndex: 13, dstExpertId 14, curExpertCnt: 0, rankGMOffset: 917504

dstRankId: 0, tokenIndex: 14, dstExpertId 11, curExpertCnt: 1, rankGMOffset: 722048

dstRankId: 0, tokenIndex: 15, dstExpertId 3, curExpertCnt: 0, rankGMOffset: 196608

dstRankId: 1, tokenIndex: 16, dstExpertId 30, curExpertCnt: 1, rankGMOffset: 918656

dstRankId: 1, tokenIndex: 17, dstExpertId 21, curExpertCnt: 0, rankGMOffset: 327680

dstRankId: 0, tokenIndex: 18, dstExpertId 12, curExpertCnt: 1, rankGMOffset: 787584

dstRankId: 0, tokenIndex: 19, dstExpertId 0, curExpertCnt: 1, rankGMOffset: 1152

dstRankId: 0, tokenIndex: 20, dstExpertId 10, curExpertCnt: 0, rankGMOffset: 655360

dstRankId: 0, tokenIndex: 21, dstExpertId 9, curExpertCnt: 0, rankGMOffset: 589824

dstRankId: 0, tokenIndex: 22, dstExpertId 6, curExpertCnt: 0, rankGMOffset: 393216

dstRankId: 1, tokenIndex: 23, dstExpertId 31, curExpertCnt: 0, rankGMOffset: 983040

#######SendToMoeExpert##########rank: 0, aiv_id: 2

dstRankId: 1, tokenIndex: 24, dstExpertId 22, curExpertCnt: 0, rankGMOffset: 393216

dstRankId: 1, tokenIndex: 25, dstExpertId 27, curExpertCnt: 0, rankGMOffset: 720896

dstRankId: 1, tokenIndex: 26, dstExpertId 30, curExpertCnt: 2, rankGMOffset: 919808

dstRankId: 1, tokenIndex: 27, dstExpertId 21, curExpertCnt: 1, rankGMOffset: 328832

dstRankId: 0, tokenIndex: 28, dstExpertId 1, curExpertCnt: 1, rankGMOffset: 66688

dstRankId: 1, tokenIndex: 29, dstExpertId 24, curExpertCnt: 1, rankGMOffset: 525440

dstRankId: 1, tokenIndex: 30, dstExpertId 17, curExpertCnt: 1, rankGMOffset: 66688

dstRankId: 0, tokenIndex: 31, dstExpertId 11, curExpertCnt: 2, rankGMOffset: 723200

dstRankId: 0, tokenIndex: 32, dstExpertId 3, curExpertCnt: 1, rankGMOffset: 197760

dstRankId: 1, tokenIndex: 33, dstExpertId 19, curExpertCnt: 0, rankGMOffset: 196608

dstRankId: 0, tokenIndex: 34, dstExpertId 2, curExpertCnt: 1, rankGMOffset: 132224

dstRankId: 1, tokenIndex: 35, dstExpertId 29, curExpertCnt: 1, rankGMOffset: 853120

#######SendToMoeExpert##########rank: 0, aiv_id: 3

dstRankId: 1, tokenIndex: 36, dstExpertId 16, curExpertCnt: 1, rankGMOffset: 1152

dstRankId: 0, tokenIndex: 37, dstExpertId 13, curExpertCnt: 0, rankGMOffset: 851968

dstRankId: 0, tokenIndex: 38, dstExpertId 7, curExpertCnt: 0, rankGMOffset: 458752

dstRankId: 1, tokenIndex: 39, dstExpertId 27, curExpertCnt: 1, rankGMOffset: 722048

dstRankId: 0, tokenIndex: 40, dstExpertId 6, curExpertCnt: 1, rankGMOffset: 394368

dstRankId: 1, tokenIndex: 41, dstExpertId 29, curExpertCnt: 2, rankGMOffset: 854272

dstRankId: 0, tokenIndex: 42, dstExpertId 5, curExpertCnt: 0, rankGMOffset: 327680

dstRankId: 1, tokenIndex: 43, dstExpertId 22, curExpertCnt: 1, rankGMOffset: 394368

dstRankId: 1, tokenIndex: 44, dstExpertId 24, curExpertCnt: 2, rankGMOffset: 526592

dstRankId: 1, tokenIndex: 45, dstExpertId 19, curExpertCnt: 1, rankGMOffset: 197760

dstRankId: 1, tokenIndex: 46, dstExpertId 23, curExpertCnt: 1, rankGMOffset: 459904

dstRankId: 0, tokenIndex: 47, dstExpertId 2, curExpertCnt: 2, rankGMOffset: 133376

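After the four cores finish SendToMoeExpert, expandIdxGMTensor holds the following (each entry is the curExpertCnt of the corresponding expertIds entry):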

--------expandIdxGMTensor------------

0 0 0 0 0 0 0 0 0 0 0 0

0 0 1 0 1 0 1 1 0 0 0 0

0 0 2 1 1 1 1 2 1 0 1 1

1 0 0 1 1 2 0 1 2 1 1 2

Output of the four cores on rank1:

#######SendToMoeExpert##########rank: 1, aiv_id: 0

dstRankId: 1, tokenIndex: 0, dstExpertId 24, curExpertCnt: 0, rankGMOffset: 1572864

dstRankId: 0, tokenIndex: 1, dstExpertId 6, curExpertCnt: 0, rankGMOffset: 1441792

dstRankId: 1, tokenIndex: 2, dstExpertId 29, curExpertCnt: 0, rankGMOffset: 1900544

dstRankId: 0, tokenIndex: 3, dstExpertId 5, curExpertCnt: 0, rankGMOffset: 1376256

dstRankId: 1, tokenIndex: 4, dstExpertId 20, curExpertCnt: 0, rankGMOffset: 1310720

dstRankId: 0, tokenIndex: 5, dstExpertId 0, curExpertCnt: 0, rankGMOffset: 1048576

dstRankId: 1, tokenIndex: 6, dstExpertId 16, curExpertCnt: 0, rankGMOffset: 1048576

dstRankId: 1, tokenIndex: 7, dstExpertId 27, curExpertCnt: 0, rankGMOffset: 1769472

dstRankId: 0, tokenIndex: 8, dstExpertId 9, curExpertCnt: 0, rankGMOffset: 1638400

dstRankId: 1, tokenIndex: 9, dstExpertId 31, curExpertCnt: 0, rankGMOffset: 2031616

dstRankId: 0, tokenIndex: 10, dstExpertId 11, curExpertCnt: 0, rankGMOffset: 1769472

dstRankId: 1, tokenIndex: 11, dstExpertId 17, curExpertCnt: 0, rankGMOffset: 1114112

#######SendToMoeExpert##########rank: 1, aiv_id: 1

dstRankId: 1, tokenIndex: 12, dstExpertId 27, curExpertCnt: 1, rankGMOffset: 1770624

dstRankId: 1, tokenIndex: 13, dstExpertId 28, curExpertCnt: 0, rankGMOffset: 1835008

dstRankId: 0, tokenIndex: 14, dstExpertId 15, curExpertCnt: 0, rankGMOffset: 2031616

dstRankId: 0, tokenIndex: 15, dstExpertId 1, curExpertCnt: 0, rankGMOffset: 1114112

dstRankId: 1, tokenIndex: 16, dstExpertId 25, curExpertCnt: 0, rankGMOffset: 1638400

dstRankId: 1, tokenIndex: 17, dstExpertId 26, curExpertCnt: 0, rankGMOffset: 1703936

dstRankId: 0, tokenIndex: 18, dstExpertId 11, curExpertCnt: 1, rankGMOffset: 1770624

dstRankId: 0, tokenIndex: 19, dstExpertId 5, curExpertCnt: 1, rankGMOffset: 1377408

dstRankId: 0, tokenIndex: 20, dstExpertId 3, curExpertCnt: 0, rankGMOffset: 1245184

dstRankId: 0, tokenIndex: 21, dstExpertId 2, curExpertCnt: 0, rankGMOffset: 1179648

dstRankId: 0, tokenIndex: 22, dstExpertId 4, curExpertCnt: 0, rankGMOffset: 1310720

dstRankId: 0, tokenIndex: 23, dstExpertId 10, curExpertCnt: 0, rankGMOffset: 1703936

#######SendToMoeExpert##########rank: 1, aiv_id: 2

dstRankId: 1, tokenIndex: 24, dstExpertId 28, curExpertCnt: 1, rankGMOffset: 1836160

dstRankId: 1, tokenIndex: 25, dstExpertId 22, curExpertCnt: 0, rankGMOffset: 1441792

dstRankId: 0, tokenIndex: 26, dstExpertId 6, curExpertCnt: 1, rankGMOffset: 1442944

dstRankId: 0, tokenIndex: 27, dstExpertId 13, curExpertCnt: 0, rankGMOffset: 1900544

dstRankId: 0, tokenIndex: 28, dstExpertId 3, curExpertCnt: 1, rankGMOffset: 1246336

dstRankId: 1, tokenIndex: 29, dstExpertId 30, curExpertCnt: 0, rankGMOffset: 1966080

dstRankId: 1, tokenIndex: 30, dstExpertId 18, curExpertCnt: 0, rankGMOffset: 1179648

dstRankId: 0, tokenIndex: 31, dstExpertId 15, curExpertCnt: 1, rankGMOffset: 2032768

dstRankId: 1, tokenIndex: 32, dstExpertId 27, curExpertCnt: 2, rankGMOffset: 1771776

dstRankId: 1, tokenIndex: 33, dstExpertId 31, curExpertCnt: 1, rankGMOffset: 2032768

dstRankId: 0, tokenIndex: 34, dstExpertId 7, curExpertCnt: 0, rankGMOffset: 1507328

dstRankId: 0, tokenIndex: 35, dstExpertId 5, curExpertCnt: 2, rankGMOffset: 1378560

#######SendToMoeExpert##########rank: 1, aiv_id: 3

dstRankId: 1, tokenIndex: 36, dstExpertId 31, curExpertCnt: 2, rankGMOffset: 2033920

dstRankId: 0, tokenIndex: 37, dstExpertId 2, curExpertCnt: 1, rankGMOffset: 1180800

dstRankId: 1, tokenIndex: 38, dstExpertId 21, curExpertCnt: 0, rankGMOffset: 1376256

dstRankId: 1, tokenIndex: 39, dstExpertId 23, curExpertCnt: 0, rankGMOffset: 1507328

dstRankId: 1, tokenIndex: 40, dstExpertId 20, curExpertCnt: 1, rankGMOffset: 1311872

dstRankId: 0, tokenIndex: 41, dstExpertId 3, curExpertCnt: 2, rankGMOffset: 1247488

dstRankId: 0, tokenIndex: 42, dstExpertId 6, curExpertCnt: 2, rankGMOffset: 1444096

dstRankId: 0, tokenIndex: 43, dstExpertId 8, curExpertCnt: 0, rankGMOffset: 1572864

dstRankId: 0, tokenIndex: 44, dstExpertId 7, curExpertCnt: 1, rankGMOffset: 1508480

dstRankId: 0, tokenIndex: 45, dstExpertId 12, curExpertCnt: 0, rankGMOffset: 1835008

dstRankId: 0, tokenIndex: 46, dstExpertId 5, curExpertCnt: 3, rankGMOffset: 1379712

dstRankId: 0, tokenIndex: 47, dstExpertId 15, curExpertCnt: 2, rankGMOffset: 2033920

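After the four cores finish SendToMoeExpert, expandIdxGMTensor holds the following (each entry is the curExpertCnt of the corresponding expertIds entry):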

--------expandIdxGMTensor------------

0 0 0 0 0 0 0 0 0 0 0 0

1 0 0 0 0 0 1 1 0 0 0 0

1 0 1 0 1 0 0 1 2 1 0 2

2 1 0 0 1 2 2 0 1 0 3 2

SetStatus

Copying the hidden states in SendToMoeExpert involves RDMA. SetStatus then sends a status entry to each destination to signal that the data has been sent.

template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::SetStatus()
{
    totalExpertNum_ = sharedExpertRankNum_ + moeExpertNum_;
    SplitToCore(totalExpertNum_, aivNum_, startExpertId_, endExpertId_, sendExpertNum_);
    for (uint32_t curExpertId = startExpertId_; curExpertId < endExpertId_; ++curExpertId) {
        if (curExpertId < sharedExpertRankNum_) {
            continue;
        }
        int32_t curExpertCnt = 0;
        int32_t curMoeExpertId = curExpertId - sharedExpertRankNum_;
        CalTokenSendExpertCnt(curMoeExpertId, expertIdsCnt_, curExpertCnt);
        int32_t cntPosIndex = curExpertId * 8 + 1; // 8 means each expert occupies 32 bytes (8 int32)
        statusTensor_(cntPosIndex) = curExpertCnt;
    }
    PipeBarrier<PIPE_ALL>();
    SyncAll<true>();

    if (startExpertId_ >= totalExpertNum_) {
        return;
    }
    GlobalTensor<int32_t> rankGMTensor;
    uint32_t offset = stateOffset_ * epRankId_;
    for (uint32_t rankIndex = startExpertId_; rankIndex < endExpertId_; ++rankIndex) {
        uint32_t dstRankId = rankIndex;
        if (moeExpertNumPerRank_ > 1 && (rankIndex >= sharedExpertRankNum_)) {
            dstRankId = ((rankIndex - sharedExpertRankNum_) / moeExpertNumPerRank_ + sharedExpertRankNum_);
            offset = (epRankId_ + (rankIndex - sharedExpertRankNum_) % moeExpertNumPerRank_ * epWorldSize_) * stateOffset_;
        }
        GM_ADDR rankGM = (__gm__ uint8_t*)(GetWindStateAddrByRankId(COMM_EP_IDX, dstRankId) + offset); // 计算地址偏移
        rankGMTensor.SetGlobalBuffer((__gm__ int32_t*)rankGM);
#if defined(ASCENDC_OOM) && ASCENDC_OOM == 1
        OOMCheckAddrRange<int32_t>(
            (__gm__ int32_t*)(GetWindStateAddrByRankId(COMM_EP_IDX, dstRankId)), STATE_SIZE);
#endif
        DataCopy<int32_t>(rankGMTensor, statusTensor_[rankIndex * 8], 8UL); // 8 is the data size; copy in 32-byte-aligned units
    }
    SyncFunc<AscendC::HardEvent::MTE3_MTE2>();
}

In statusTensor_, each expert's status entry occupies 8 * sizeof(int32_t) = 32 bytes. The first int is the flag (0x3F800000, the bit pattern of float 1.0); the second int is curExpertCnt.
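A small host-side check of why the flag is 0x3F800000: it is the IEEE-754 single-precision encoding of 1.0, so the receiver can treat the status data as float and simply add the flags. fillStatusEntry is a hypothetical helper that builds one 8-int entry the way the text describes.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

constexpr uint32_t kFlagBits = 0x3F800000u;  // IEEE-754 bit pattern of 1.0f

// Reinterpret 32 bits as a float without undefined behaviour.
float flagAsFloat(uint32_t bits) {
    float f;
    std::memcpy(&f, &bits, sizeof f);
    return f;
}

// Hypothetical: one 32-byte status entry, int 0 = flag, int 1 = curExpertCnt, rest padding.
void fillStatusEntry(int32_t entry[8], int32_t curExpertCnt) {
    for (int i = 0; i < 8; ++i) entry[i] = 0;
    entry[0] = static_cast<int32_t>(kFlagBits);
    entry[1] = curExpertCnt;
}
```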

In statusSpaceGm_, each status slot (one per expert per source rank) occupies stateOffset_ bytes.

constexpr uint32_t STATE_OFFSET = 512; // status space offset
stateOffset_ = (recvWinBlockNum_ > 512) ? (STATE_OFFSET / 2) : STATE_OFFSET;

Test program output

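ExpertId: the expert id.

dstRank: the rank holding the weights of that expert.

offset: the address offset of rankGM, relative to the status window of dstRank.

curExpertCnt: see CalTokenSendExpertCnt.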

Output of the four cores on rank0:

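On card 0, given the input expertIds, expert 0 appears twice (curExpertCnt = 2) and expert 2 appears three times.

Let addr denote the address returned by GetWindStateAddrByRankId. On card 0, offsets [0, 512) from addr hold the flag and curExpertCnt for expert 0 sent by rank0; offsets [512, 1024) hold the flag and curExpertCnt for expert 0 sent by rank1.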
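A sketch of the SetStatus target address and of stateOffset_, assuming sharedExpertRankNum_ = 0 and moeExpertNumPerRank_ > 1 as in this test. The slot index epRankId_ + localExpert * epWorldSize_ puts all source ranks of one expert next to each other, which is why rank1's entry for expert 0 lands at 512. calcStateTarget and calcStateOffset are hypothetical names.

```cpp
#include <cassert>
#include <cstdint>

struct StateTarget {
    uint32_t dstRank;  // rank hosting the expert
    uint32_t offset;   // bytes from GetWindStateAddrByRankId(COMM_EP_IDX, dstRank)
};

// Same rule as stateOffset_ in the kernel: 512 bytes per slot, 256 if there are many blocks.
uint32_t calcStateOffset(uint32_t recvWinBlockNum) {
    constexpr uint32_t STATE_OFFSET = 512;
    return (recvWinBlockNum > 512) ? (STATE_OFFSET / 2) : STATE_OFFSET;
}

// Mirrors the dstRankId / offset computation in SetStatus (no shared-expert ranks).
StateTarget calcStateTarget(uint32_t expertId, uint32_t epRankId, uint32_t epWorldSize,
                            uint32_t moeExpertNumPerRank, uint32_t stateOffset) {
    assert(moeExpertNumPerRank > 1);
    StateTarget t{};
    t.dstRank = expertId / moeExpertNumPerRank;
    uint32_t localExpert = expertId % moeExpertNumPerRank;
    // Slot index = localExpert * epWorldSize + epRankId.
    t.offset = (epRankId + localExpert * epWorldSize) * stateOffset;
    return t;
}
```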

#######SetStatus##########rank: 0, aiv_id: 0

ExpertId 0, dstRank: 0, offset: 0, curExpertCnt 2

ExpertId 1, dstRank: 0, offset: 1024, curExpertCnt 2

ExpertId 2, dstRank: 0, offset: 2048, curExpertCnt 3

ExpertId 3, dstRank: 0, offset: 3072, curExpertCnt 2

ExpertId 4, dstRank: 0, offset: 4096, curExpertCnt 0

ExpertId 5, dstRank: 0, offset: 5120, curExpertCnt 1

ExpertId 6, dstRank: 0, offset: 6144, curExpertCnt 2

ExpertId 7, dstRank: 0, offset: 7168, curExpertCnt 1

#######SetStatus##########rank: 0, aiv_id: 1

ExpertId 8, dstRank: 0, offset: 8192, curExpertCnt 1

ExpertId 9, dstRank: 0, offset: 9216, curExpertCnt 1

ExpertId 10, dstRank: 0, offset: 10240, curExpertCnt 1

ExpertId 11, dstRank: 0, offset: 11264, curExpertCnt 3

ExpertId 12, dstRank: 0, offset: 12288, curExpertCnt 2

ExpertId 13, dstRank: 0, offset: 13312, curExpertCnt 1

ExpertId 14, dstRank: 0, offset: 14336, curExpertCnt 1

ExpertId 15, dstRank: 0, offset: 15360, curExpertCnt 0

#######SetStatus##########rank: 0, aiv_id: 2

ExpertId 16, dstRank: 1, offset: 0, curExpertCnt 2

ExpertId 17, dstRank: 1, offset: 1024, curExpertCnt 2

ExpertId 18, dstRank: 1, offset: 2048, curExpertCnt 1

ExpertId 19, dstRank: 1, offset: 3072, curExpertCnt 2

ExpertId 20, dstRank: 1, offset: 4096, curExpertCnt 0

ExpertId 21, dstRank: 1, offset: 5120, curExpertCnt 2

ExpertId 22, dstRank: 1, offset: 6144, curExpertCnt 2

ExpertId 23, dstRank: 1, offset: 7168, curExpertCnt 2

#######SetStatus##########rank: 0, aiv_id: 3

ExpertId 24, dstRank: 1, offset: 8192, curExpertCnt 3

ExpertId 25, dstRank: 1, offset: 9216, curExpertCnt 0

ExpertId 26, dstRank: 1, offset: 10240, curExpertCnt 0

ExpertId 27, dstRank: 1, offset: 11264, curExpertCnt 2

ExpertId 28, dstRank: 1, offset: 12288, curExpertCnt 0

ExpertId 29, dstRank: 1, offset: 13312, curExpertCnt 3

ExpertId 30, dstRank: 1, offset: 14336, curExpertCnt 3

ExpertId 31, dstRank: 1, offset: 15360, curExpertCnt 1

Output of the four cores on rank1:

#######SetStatus##########rank: 1, aiv_id: 0

ExpertId 0, dstRank: 0, offset: 512, curExpertCnt 1

ExpertId 1, dstRank: 0, offset: 1536, curExpertCnt 1

ExpertId 2, dstRank: 0, offset: 2560, curExpertCnt 2

ExpertId 3, dstRank: 0, offset: 3584, curExpertCnt 3

ExpertId 4, dstRank: 0, offset: 4608, curExpertCnt 1

ExpertId 5, dstRank: 0, offset: 5632, curExpertCnt 4

ExpertId 6, dstRank: 0, offset: 6656, curExpertCnt 3

ExpertId 7, dstRank: 0, offset: 7680, curExpertCnt 2

#######SetStatus##########rank: 1, aiv_id: 1

ExpertId 8, dstRank: 0, offset: 8704, curExpertCnt 1

ExpertId 9, dstRank: 0, offset: 9728, curExpertCnt 1

ExpertId 10, dstRank: 0, offset: 10752, curExpertCnt 1

ExpertId 11, dstRank: 0, offset: 11776, curExpertCnt 2

ExpertId 12, dstRank: 0, offset: 12800, curExpertCnt 1

ExpertId 13, dstRank: 0, offset: 13824, curExpertCnt 1

ExpertId 14, dstRank: 0, offset: 14848, curExpertCnt 0

ExpertId 15, dstRank: 0, offset: 15872, curExpertCnt 3

#######SetStatus##########rank: 1, aiv_id: 2

ExpertId 16, dstRank: 1, offset: 512, curExpertCnt 1

ExpertId 17, dstRank: 1, offset: 1536, curExpertCnt 1

ExpertId 18, dstRank: 1, offset: 2560, curExpertCnt 1

ExpertId 19, dstRank: 1, offset: 3584, curExpertCnt 0

ExpertId 20, dstRank: 1, offset: 4608, curExpertCnt 2

ExpertId 21, dstRank: 1, offset: 5632, curExpertCnt 1

ExpertId 22, dstRank: 1, offset: 6656, curExpertCnt 1

ExpertId 23, dstRank: 1, offset: 7680, curExpertCnt 1

#######SetStatus##########rank: 1, aiv_id: 3

ExpertId 24, dstRank: 1, offset: 8704, curExpertCnt 1

ExpertId 25, dstRank: 1, offset: 9728, curExpertCnt 1

ExpertId 26, dstRank: 1, offset: 10752, curExpertCnt 1

ExpertId 27, dstRank: 1, offset: 11776, curExpertCnt 3

ExpertId 28, dstRank: 1, offset: 12800, curExpertCnt 2

ExpertId 29, dstRank: 1, offset: 13824, curExpertCnt 1

ExpertId 30, dstRank: 1, offset: 14848, curExpertCnt 1

ExpertId 31, dstRank: 1, offset: 15872, curExpertCnt 3

WaitDispatch

template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::WaitDispatch()
{
    uint32_t startStatusIndex = 0;
    uint32_t endStatusIndex = 0;
    uint32_t rscvStatusNum = isShareExpertRank_ ? epWorldSize_ : recvWinBlockNum_;
    SplitToCore(rscvStatusNum, aivNum_, startStatusIndex, endStatusIndex, recStatusNumPerCore_);
    InitBufferWait();
    if (startStatusIndex >= rscvStatusNum) {
        SyncAll<true>();
        return;
    }

    LocalTensor<float> gatherMaskOutTensor = gatherMaskOutBuf_.Get<float>();
    LocalTensor<uint32_t> gatherTmpTensor = scalarBuf_.GetWithOffset<uint32_t>(UB_ALIGN / sizeof(uint32_t), 0);
    gatherTmpTensor.SetValue(0, 1);
    LocalTensor<float> statusSumOutTensor = scalarBuf_.GetWithOffset<float>(UB_ALIGN / sizeof(float), UB_ALIGN);
    statusFp32Tensor_ = statusTensor_.ReinterpretCast<float>();
    uint32_t mask = 1; // parameters for gatherMask + sum
    uint64_t rsvdCnt = 0;
    uint32_t recStatusNumPerCoreInner = Ceil(recStatusNumPerCore_ * sizeof(float), UB_ALIGN) * UB_ALIGN / sizeof(float);
    SumParams sumParams{1, recStatusNumPerCoreInner, recStatusNumPerCore_};
    float sumOfFlag = static_cast<float>(-1.0);
    float minTarget = (sumTarget_ * recStatusNumPerCore_) - (float)0.5;
    float maxTarget = (sumTarget_ * recStatusNumPerCore_) + (float)0.5;
    DataCopyParams intriParams{static_cast<uint16_t>(recStatusNumPerCore_), 1,
                               static_cast<uint16_t>((recvWinBlockNum_ > 512) ? 7 : 15), 0}; // srcStride is 15 blocks
    SyncFunc<AscendC::HardEvent::S_V>();
    while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) {
        DataCopy(statusFp32Tensor_, windowInstatusFp32Tensor_[startStatusIndex * stateOffset_ / sizeof(float)], intriParams);
        SyncFunc<AscendC::HardEvent::MTE2_V>();
        GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, mask,
                   {1, (uint16_t)recStatusNumPerCore_, 1, 0}, rsvdCnt);
        PipeBarrier<PIPE_V>();
        Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams);
        SyncFunc<AscendC::HardEvent::V_S>();
        sumOfFlag = statusSumOutTensor.GetValue(0);
    }
    SyncAll<true>();
}

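Here recvWinBlockNum_ = epWorldSize_ * moeExpertNumPerRank_ (2 * 16 = 32 in this test).

WaitDispatch decides whether the transfer has completed by counting flags; each AIV core waits for recStatusNumPerCore_ flags. As I read the code, since every flag is float 1.0, GatherMask (pattern 1) keeps the first element of each 32-byte status block, Sum adds them, and the loop ends once the sum lies within ±0.5 of sumTarget_ * recStatusNumPerCore_.

After receiving its recStatusNumPerCore_ flags, a core leaves the while loop and calls SyncAll to synchronize the AIV cores on the same rank. When SyncAll returns, the whole rank has received the hidden states sent by the other ranks.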
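A CPU model of one polling step of WaitDispatch, under two assumptions: the status blocks have already been copied contiguously (8 int32 per block), and sumTarget_ is 1.0, which is what the printed minTarget 7.5 / maxTarget 8.5 for 8 flags suggests (the source does not show sumTarget_ itself). flagsArrived is a hypothetical name.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical model of one WaitDispatch polling iteration.
// statusInts: recStatusNum blocks of 8 int32, i.e. what the strided DataCopy leaves locally.
bool flagsArrived(const std::vector<int32_t>& statusInts, uint32_t recStatusNum, float sumTarget) {
    assert(statusInts.size() >= static_cast<size_t>(recStatusNum) * 8);
    float sumOfFlag = 0.0f;
    for (uint32_t i = 0; i < recStatusNum; ++i) {
        float flag;
        std::memcpy(&flag, &statusInts[i * 8], sizeof flag);  // GatherMask: element 0 of each block
        sumOfFlag += flag;                                    // Sum over the gathered flags
    }
    const float minTarget = sumTarget * recStatusNum - 0.5f;
    const float maxTarget = sumTarget * recStatusNum + 0.5f;
    // The kernel keeps looping while the sum is outside [minTarget, maxTarget].
    return sumOfFlag >= minTarget && sumOfFlag <= maxTarget;
}
```

The ±0.5 window makes the float comparison tolerant of rounding while still distinguishing 7 arrived flags from 8.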

Test program output

Output of the four cores on rank0:

#######WaitDispatch##########rank: 0, aiv_id: 0

rscvStatusNum: 32, recStatusNumPerCore: 8, startStatusIndex: 0, endStatusIndex: 8

minTarget: 7.500000, maxTarget: 8.500000, sumOfFlag: 8.000000

For this test case the other three cores print the same information, so it is omitted.

Output of the four cores on rank1:

#######WaitDispatch##########rank: 1, aiv_id: 0

rscvStatusNum: 32, recStatusNumPerCore: 8, startStatusIndex: 0, endStatusIndex: 8

minTarget: 7.500000, maxTarget: 8.500000, sumOfFlag: 8.000000

For this test case the other three cores print the same information, so it is omitted.

LocalWindowCopy

template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::GetCumSum(LocalTensor<int32_t>& outLocal, int32_t totalCount)
{
    DataCopyParams intriParams{static_cast<uint16_t>(recvWinBlockNum_), 1,
                               static_cast<uint16_t>((recvWinBlockNum_ > 512) ? 7 : 15), 0}; // srcStride is 15 blocks
    DataCopy(statusTensor_, windowInstatusTensor_, intriParams);

    if (isShareExpertRank_) {
        SyncFunc<AscendC::HardEvent::MTE2_S>();
        for (uint32_t curStatusExpId = 0; curStatusExpId < sharedExpertRankNum_; ++curStatusExpId) {
            int32_t curExpertCnt = (curStatusExpId + 1 + epRankId_) * axisBS_ / sharedExpertRankNum_
                                    - (curStatusExpId + epRankId_) * axisBS_ / sharedExpertRankNum_;
            statusTensor_((curStatusExpId) * 8 + 1) = curExpertCnt;
        }
        SyncFunc<AscendC::HardEvent::S_V>();
    } else {
        SyncFunc<AscendC::HardEvent::MTE2_V>();
    }

    outLocal = gatherMaskOutBuf_.Get<int32_t>(); // reuse memory
    LocalTensor<float> getTotalLocal = getTotalBuf_.Get<float>();
    // gather the masked elements together

    LocalTensor<uint32_t> gatherTmpTensor = gatherTmpBuf_.Get<uint32_t>();
    Duplicate(gatherTmpTensor, (uint32_t)33686018, recvWinBlockNum_ / 4); // 0000 0010 0000 0010 0000 0010 0000 0010
    PipeBarrier<PIPE_V>();
    uint32_t mask = recvWinBlockNum_ * 8; // 512 / 32
    uint64_t rsvdCnt = 0;
    GatherMask(outLocal, statusTensor_, gatherTmpTensor, true, mask, {1, 1, 0, 0}, rsvdCnt);
    LocalTensor<float> tmpFp32 = outLocal.ReinterpretCast<float>();
    PipeBarrier<PIPE_V>();
    ReduceSum<float>(getTotalLocal, tmpFp32, workLocalTensor_, epWorldSize_);
    totalCnt_ = getTotalLocal.ReinterpretCast<int32_t>().GetValue(0);
    PipeBarrier<PIPE_V>();
    ReduceSum<float>(tmpFp32, tmpFp32, workLocalTensor_, totalCount);
    PipeBarrier<PIPE_V>();
}

LocalWindowCopy calls GetCumSum to obtain the token counts each expert has to process.
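The mask value 33686018 = 0x02020202 sets bit 1 of every byte. Under my reading of GatherMask with a user pattern (one bit per source element, least significant bit first, 1 = keep), this selects element 1 of every group of 8 int32, i.e. the curExpertCnt of each 32-byte status block, and recvWinBlockNum_ / 4 pattern words cover the recvWinBlockNum_ * 8 elements. The simplified model below ignores the repeat and stride parameters; gatherByPattern is a hypothetical name, not the AscendC API.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Simplified, hypothetical model of GatherMask with a user bit pattern:
// element i is kept when bit (i % 32) of pattern[i / 32] is 1.
std::vector<int32_t> gatherByPattern(const std::vector<int32_t>& src,
                                     const std::vector<uint32_t>& pattern, uint32_t count) {
    assert(count <= src.size() && pattern.size() * 32 >= count);
    std::vector<int32_t> out;
    for (uint32_t i = 0; i < count; ++i) {
        if ((pattern[i / 32] >> (i % 32)) & 1u) {
            out.push_back(src[i]);
        }
    }
    return out;
}
```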

The curExpertCnt values produced on rank0 and rank1 for experts 0 to 15 (the experts hosted on rank0):

rank \ expert id  0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15
rank0             2  2  3  2  0  1  2  1  1  1  1  3  2  1  1  0
rank1             1  1  2  3  1  4  3  2  1  1  1  2  1  1  0  3
total             3  3  5  5  1  5  5  3  2  2  2  5  3  2  1  3

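The outLocal values printed after GatherMask in GetCumSum (all cores of one rank print the same values). Reading the table above column by column, top to bottom (for each expert, the rank0 count followed by the rank1 count), gives exactly this sequence.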

--------outLocal------------

2 1 2 1 3 2 2 3

0 1 1 4 2 3 1 2

1 1 1 1 1 1 3 2

2 1 1 1 1 0 0 3

(The outLocal values above were printed on rank0.)

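In the logs below, each core first prints outLocal after the ReduceSum over the first totalCount = startExpertId_ + 1 elements, which leaves that prefix sum in element 0. It then prints outCountLocal cumsum, the output after LocalWindowCopy runs the following code: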

for (uint32_t index = startExpertId_; index < endExpertId_; index++) {
    uint32_t i = index - startExpertId_;
    if (i > 0) {
        outCountLocal.SetValue(i, outCountLocal.GetValue(i - 1) + outCountLocal.GetValue(index));
    }
}
// from my simulation code
uint32_t  count = endExpertId_ - startExpertId_;
LOG_INFO("startExpertId: %d, endExpertId %d", startExpertId_, endExpertId_);
print_tensor(outCountLocal, "outCountLocal cumsum", count);

#######GetCumSum##########rank: 0, aiv_id: 0

--------outLocal ReduceSum, totalCount = startExpertId_ + 1: 1------------

2 1 2 1 3 2 2 3

0 1 1 4 2 3 1 2

1 1 1 1 1 1 3 2

2 1 1 1 1 0 0 3

#######LocalWindowCopy##########rank: 0, aiv_id: 0

startExpertId: 0, endExpertId 8

--------outCountLocal cumsum------------

2 3 5 6 9 11 13 16

#######GetCumSum##########rank: 0, aiv_id: 1

--------outLocal ReduceSum, totalCount = startExpertId_ + 1: 9------------

16 1 2 1 3 2 2 3

0 1 1 4 2 3 1 2

1 1 1 1 1 1 3 2

2 1 1 1 1 0 0 3

#######LocalWindowCopy##########rank: 0, aiv_id: 1

startExpertId: 8, endExpertId 16

--------outCountLocal cumsum------------

16 17 18 22 24 27 28 30

#######GetCumSum##########rank: 0, aiv_id: 2

--------outLocal ReduceSum, totalCount = startExpertId_ + 1: 17------------

31 1 2 1 3 2 2 3

0 1 1 4 2 3 1 2

1 1 1 1 1 1 3 2

2 1 1 1 1 0 0 3

#######LocalWindowCopy##########rank: 0, aiv_id: 2

startExpertId: 16, endExpertId 24

#######GetCumSum##########rank: 0, aiv_id: 3

--------outLocal ReduceSum, totalCount = startExpertId_ + 1: 25------------

43 1 2 1 3 2 2 3

0 1 1 4 2 3 1 2

1 1 1 1 1 1 3 2

2 1 1 1 1 0 0 3

#######LocalWindowCopy##########rank: 0, aiv_id: 3

startExpertId: 24, endExpertId 32

--------outCountLocal cumsum------------

43 44 45 46 47 47 47 50
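A host-side model that reproduces the per-core logs above. GetCumSum calls ReduceSum<float> on counts that are really int32 values; my interpretation is that this works because small non-negative int32 values reinterpreted as float are subnormal numbers, whose sums are exact, so the bit pattern of the float sum equals the integer sum. The model assumes the host does not flush denormals to zero; both function names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Sum int32 values through their float bit patterns, like ReduceSum<float>
// on a ReinterpretCast<float>() view of int32 data.
int32_t sumViaFloatBits(const int32_t* data, uint32_t n) {
    float acc = 0.0f;
    for (uint32_t i = 0; i < n; ++i) {
        float f;
        std::memcpy(&f, &data[i], sizeof f);  // small ints become subnormal floats
        acc += f;                             // exact: the sum is still representable
    }
    int32_t bits;
    std::memcpy(&bits, &acc, sizeof bits);
    return bits;
}

// Hypothetical model of one core's part: prefix sums of outLocal over [start, end).
std::vector<int32_t> coreCumSum(const std::vector<int32_t>& outLocal, uint32_t start, uint32_t end) {
    assert(start < end && end <= outLocal.size());
    std::vector<int32_t> res(end - start);
    res[0] = sumViaFloatBits(outLocal.data(), start + 1);  // ReduceSum, totalCount = start + 1
    for (uint32_t i = 1; i < end - start; ++i)
        res[i] = res[i - 1] + outLocal[start + i];          // the SetValue loop above
    return res;
}
```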

Finally, LocalWindowCopy copies outCountLocal to sendCountsOutGM_, each core writing its own slice starting at startExpertId_.

template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::LocalWindowCopy()
{
    // ... (earlier part of LocalWindowCopy omitted)
    DataCopyExtParams dataCopyOutParams = {1U, static_cast<uint32_t>(sendExpertNum_ * sizeof(int32_t)), 0U, 0U, 0U};
    GlobalTensor<int32_t> sendCountsGlobal;
    sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendCountsOutGM_));
    DataCopyPad(sendCountsGlobal[startExpertId_], outCountLocal, dataCopyOutParams);
    PipeBarrier<PIPE_MTE3>();
}

UpdateTokenNumsOut

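Only the last core performs the update in UpdateTokenNumsOut, so all cores first call SyncAll to make sure every slice of sendCountsOutGM_ has been written. Only the UpdateMultiMoeTokenNumsOut branch is examined here.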

// Update the tokenNumsOut tensor
template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::UpdateTokenNumsOut()
{
    // The last core does the update; for MoE experts, only the last core has the complete sendCountsGlobal
    if (!isShareExpertRank_ && moeExpertNumPerRank_ > 1) {
        SyncAll<true>();
        if (aivId_ != lastCore_) return;

        SyncFunc<AscendC::HardEvent::MTE3_S>();
        UpdateMultiMoeTokenNumsOut();
    }
}
// Update the tokenNumsOut tensor on a card hosting multiple experts
template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::UpdateMultiMoeTokenNumsOut()
{
    uint32_t tokenSums = 0;
    GlobalTensor<int32_t> sendCountsGlobal;
    sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(sendCountsOutGM_));
    for (uint32_t localMoeIndex = 0; localMoeIndex < moeExpertNumPerRank_; ++localMoeIndex) {
        if (localMoeIndex == 0) {
            DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
                sendCountsGlobal[epWorldSize_ - 1]);
            uint32_t firstMoeCnt = sendCountsGlobal.GetValue(epWorldSize_ - 1);
            tokenSums = firstMoeCnt + gatherCount_;
            expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenSums);
            DataCacheCleanAndInvalid<int64_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
                expertTokenNumsOutGMTensor_[localMoeIndex]);
        } else {
            uint32_t preIndex = epWorldSize_ * (localMoeIndex - 1) + epWorldSize_- 1;
            uint32_t curIndex = epWorldSize_ * localMoeIndex + epWorldSize_ - 1;
            DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
                sendCountsGlobal[preIndex]);
            DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
                sendCountsGlobal[curIndex]);
            uint32_t preMoeIndexCnt = sendCountsGlobal.GetValue(preIndex);
            uint32_t curMoeIndexCnt = sendCountsGlobal.GetValue(curIndex);
            tokenSums = ((expertTokenNumsType_ == 0) ? tokenSums : 0) + (curMoeIndexCnt - preMoeIndexCnt) + gatherCount_;
            expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenSums);
            DataCacheCleanAndInvalid<int64_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
                expertTokenNumsOutGMTensor_[localMoeIndex]);
        }
    }
}

Test program output

#######UpdateMultiMoeTokenNumsOut##########rank: 0, aiv_id: 3

--------sendCountsGlobal------------

2 3 5 6 9 11 13 16

16 17 18 22 24 27 28 30

31 32 33 34 35 36 39 41

43 44 45 46 47 47 47 50

--------expertTokenNumsOutGMTensor------------

3 3 5 5 1 5 5 3

2 2 2 5 3 2 1 3

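These values match the totals row of the table above. The index of expertTokenNumsOutGMTensor is the local expert id and the value is the number of tokens that expert has to process. Judging by the code, the test ran with expertTokenNumsType_ != 0 (per-expert counts); with expertTokenNumsType_ == 0, tokenSums would accumulate and the output would be cumulative.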

Helper function analysis

CalTokenSendExpertCnt

template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::CalTokenSendExpertCnt(uint32_t dstExpertId, int32_t calCnt,
    int32_t &curExpertCnt)
{
    if (isQuant_) { // buffers are reused in quantized mode
       dstExpIdTensor_ = receiveDataCastFloatBuf_.Get<int32_t>();
       subExpIdTensor_ = smoothScalesBuf_.Get<int32_t>();
    }
    Duplicate<int32_t>(dstExpIdTensor_, dstExpertId, calCnt);
    PipeBarrier<PIPE_V>();
    Sub(subExpIdTensor_, expertIdsTensor_, dstExpIdTensor_, calCnt);
    PipeBarrier<PIPE_V>();
    LocalTensor<float> tmpFp32 = subExpIdTensor_.ReinterpretCast<float>();
    LocalTensor<float> tmpoutFp32 = dstExpIdTensor_.ReinterpretCast<float>();
    Abs(tmpoutFp32, tmpFp32, calCnt);
    PipeBarrier<PIPE_V>();
    Mins(subExpIdTensor_, dstExpIdTensor_, 1, calCnt);
    PipeBarrier<PIPE_V>();
    ReduceSum<float>(tmpoutFp32, tmpFp32, workLocalTensor_, calCnt);
    SyncFunc<AscendC::HardEvent::V_S>();
    int32_t curOtherExpertCnt = dstExpIdTensor_(0);
    if (calCnt > curOtherExpertCnt) {
        curExpertCnt = calCnt - curOtherExpertCnt; 
    }
}

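The function counts how many of the first calCnt entries of expertIdsTensor_ equal dstExpertId.

Duplicate: fill dstExpIdTensor_ with calCnt copies of dstExpertId.

Sub: compute expertIdsTensor_ - dstExpIdTensor_ and store the result in subExpIdTensor_.

Abs: take the absolute value of subExpIdTensor_, with its contents treated as float, and store the result in the space of dstExpIdTensor_. As I read it, the float value itself is irrelevant: clearing the sign bit keeps a zero difference at 0 and turns any nonzero int32 difference into a positive int32, which is all the following Mins needs. The CPU mock of Abs from my simulation: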

template <typename T>
__aicore__ inline void Abs(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, const int32_t& calCount)
{
    for(int32_t i = 0; i < calCount; i++) {
        if (srcLocal.GetValue(i) >=0) {
            dstLocal.SetValue(i, srcLocal.GetValue(i));
        } else {
            if (std::is_same<T, float>::value) {
                dstLocal.SetValue(i, std::fabs(srcLocal(i)));
            } else {
                dstLocal.SetValue(i, std::abs(srcLocal(i)));
            }
        }
    }
}

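Mins: subExpIdTensor_(i) = min(dstExpIdTensor_(i), 1), so an entry is 0 where expertIdsTensor_(i) == dstExpertId and 1 otherwise. ReduceSum<float> then adds these 0/1 values through the float view into dstExpIdTensor_(0), whose bit pattern gives curOtherExpertCnt, the number of entries that differ; finally curExpertCnt = calCnt - curOtherExpertCnt. The CPU mock of Mins: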

//https://developer.huawei.com/consumer/cn/doc/hiai-References/cannkit-scalar-binocular-mins-0000002123235326
template <typename T, bool isSetMask = true>
__aicore__ inline void Mins(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, const T& scalarValue, const int32_t& calCount)
{
    for(int32_t i = 0; i < calCount; i++) {
        if (srcLocal.GetValue(i) >=scalarValue) {
            dstLocal.SetValue(i, scalarValue);
        } else {
            dstLocal.SetValue(i, srcLocal(i));
        }
    }
}

Take a randomly generated expertIds as an example:

--------expertIds------------

17 18 30 29 6 12 10 19 26 13 15 22

12 26 10 30 11 25 29 22 15 5 18 28

29 14 12 15 2 6 9 13 17 21 18 27

18 12 11 2 22 3 16 30 10 13 28 26

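The inputs passed to CalTokenSendExpertCnt (from my simulation harness, whose version of the function takes an extra fourth argument):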

    int calCount = 37;
    uint32_t dstExpertId = expertIds[calCount];
    int curExpertCnt = 0;
    dispatch.CalTokenSendExpertCnt(dstExpertId, calCount, curExpertCnt, true);
    LOG_INFO("index: %d, dstExpertId %d, curExpertCnt: %d", calCount, dstExpertId, curExpertCnt);
//index: 37, dstExpertId 12, curExpertCnt: 3

subExpIdTensor_ after Duplicate, Sub, Abs and Mins (the 0s sit at the three positions holding 12):

--------subExpIdTensor_------------

1 1 1 1 1 0 1 1 1 1 1 1

0 1 1 1 1 1 1 1 1 1 1 1

1 1 0 1 1 1 1 1 1 1 1 1

1
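A plain-array model of the CalTokenSendExpertCnt pipeline, useful for checking the example above: with calCnt = 37 and dstExpertId = 12, 34 entries differ from 12, ReduceSum yields 34, and curExpertCnt = 37 - 34 = 3. Abs on the float view is modelled as clearing the sign bit, and ReduceSum<float> as a float sum of the reinterpreted bits; both are my assumptions about the hardware behaviour, and calTokenSendExpertCntCpu is a hypothetical name.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical CPU model of the CalTokenSendExpertCnt vector pipeline on a plain array.
int32_t calTokenSendExpertCntCpu(const std::vector<int32_t>& expertIds, int32_t dstExpertId,
                                 int32_t calCnt) {
    assert(calCnt >= 0 && static_cast<size_t>(calCnt) <= expertIds.size());
    float acc = 0.0f;  // ReduceSum<float> accumulator
    for (int32_t i = 0; i < calCnt; ++i) {
        int32_t diff = expertIds[i] - dstExpertId;                      // Duplicate + Sub
        uint32_t absBits = static_cast<uint32_t>(diff) & 0x7FFFFFFFu;   // float Abs: clear sign bit
        int32_t clipped = std::min<int32_t>(static_cast<int32_t>(absBits), 1);  // Mins(..., 1)
        float asFloat;
        std::memcpy(&asFloat, &clipped, sizeof asFloat);                // 0 or the smallest subnormal
        acc += asFloat;
    }
    int32_t curOtherExpertCnt;
    std::memcpy(&curOtherExpertCnt, &acc, sizeof curOtherExpertCnt);    // bits of sum = number of 1s
    return calCnt > curOtherExpertCnt ? calCnt - curOtherExpertCnt : 0;
}
```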

Test program source

To be uploaded.
