The image froze, and a shrill alarm sounded. The Hazard Remover (排险者) told the crowd that the early-warning system had been triggered.
"Why?" the chief engineer asked, puzzled.
"The time this primitive man spent gazing at the stars exceeded the warning threshold; he has shown sufficient curiosity toward the universe. So far, ten such over-threshold events have been observed at different locations, which satisfies the alarm condition."
"If I remember correctly, you said earlier that the warning system only fires when a civilization capable of producing creation-level energy processes appears."
"Is what you are seeing not exactly such a civilization?"
------ Liu Cixin, 《朝闻道》
Preface
Why write an article like this? Not out of some irrepressible curiosity. Partly idle time, partly pragmatism: I wanted to see whether I could do some work on fused communication and computation. I had been reading the MoeDistributeDispatch operator on and off for three weeks, and reading the code alone left me lost. Later I simulated several of its interfaces on the CPU (DataCopy, Duplicate, ReduceSum, Add, Sub, Mins) and printed intermediate values, which helped me understand the code.
MoeDistributeDispatch
Documentation
https://gitcode.com/cann/ops-transformer/tree/master/mc2/moe_distribute_dispatch
Main flow
```c++
template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::Process()
{
    if ASCEND_IS_AIV { // handled entirely on AIV cores
        AlltoAllDispatch();
        SetStatus();
        WaitDispatch();
        LocalWindowCopy();
        if constexpr (IsNeedAllgather) {
            AllGatherSetStatusAndWait();
            AllgatherProcessOut();
        }
        UpdateTokenNumsOut();
    }
}
```
Test data
For ease of understanding, the test code sets moeExpertNum_ = 32 and assumes two cards (epWorldSize_ = 2), with aivNum_ = 4 AIV cores per card.
The weights of experts 0 to 15 are stored on rank 0; the weights of experts 16 to 31 are stored on rank 1.
moeExpertNumPerRank_ = moeExpertNum_ / epWorldSize_.
The expertIds input on each card has dimensions axisBS_ = 4 and axisK_ = 12.
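As a sanity check, the derived quantities above can be reproduced in a few lines. This is a CPU-side Python sketch; the names mirror the kernel members but none of this is the real AscendC API:

```python
# Test configuration stated in the text.
moe_expert_num = 32   # moeExpertNum_
ep_world_size = 2     # epWorldSize_
axis_bs, axis_k = 4, 12

# Experts per rank and tokens per card.
moe_expert_num_per_rank = moe_expert_num // ep_world_size  # moeExpertNumPerRank_
tokens_per_card = axis_bs * axis_k  # entries in each card's expertIds

print(moe_expert_num_per_rank)  # 16: experts [0, 16) on rank 0, [16, 32) on rank 1
print(tokens_per_card)          # 48
```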
The expertIds input on rank 0:
--------expertIds------------
11 17 29 12 24 23 1 0 16 30 18 8
2 14 11 3 30 21 12 0 10 9 6 31
22 27 30 21 1 24 17 11 3 19 2 29
16 13 7 27 6 29 5 22 24 19 23 2
The expertIds input on rank 1:
--------expertIds------------
24 6 29 5 20 0 16 27 9 31 11 17
27 28 15 1 25 26 11 5 3 2 4 10
28 22 6 13 3 30 18 15 27 31 7 5
31 2 21 23 20 3 6 8 7 12 5 15
AlltoAllDispatch
The SendToMoeExpert function sends the hidden states (xGMTensor_, or its quantized value) to the HBM region (dstWinGMTensor) on the rank that owns dstExpertId.
```c++
template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::SendToMoeExpert()
{
    for (int32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) {
        uint32_t dstExpertId = expertIdsTensor_(tokenIndex);
        int32_t curExpertCnt = 0;
        if (tokenIndex > 0) {
            CalTokenSendExpertCnt(dstExpertId, tokenIndex, curExpertCnt);
        }
        expertCountTensor_(tokenIndex - startTokenId) = curExpertCnt;
        uint32_t tempRankId = dstExpertId / moeExpertNumPerRank_ + sharedExpertRankNum_;
        GM_ADDR rankGM = (__gm__ uint8_t*)(GetWindAddrByRankId(COMM_EP_IDX, tempRankId) +
            (expertPerSizeOnWin_ * (epRankId_ * moeExpertNumPerRank_ + dstExpertId % moeExpertNumPerRank_))
            + hCommuSize_ * curExpertCnt); // compute the destination address offset
        dstWinGMTensor.SetGlobalBuffer((__gm__ ExpandXOutType*)rankGM);
        // ... (copy of the token's hidden states to dstWinGMTensor elided in this excerpt)
    }
    GlobalTensor<int32_t> expandIdxGMTensor;
    expandIdxGMTensor.SetGlobalBuffer((__gm__ int32_t*)(expandIdxOutGM_ + startTokenId * sizeof(int32_t)));
    DataCopyExtParams expertIdsCntParams = {1U, static_cast<uint32_t>(sendTokenNum * sizeof(uint32_t)), 0U, 0U, 0U};
    DataCopyPad(expandIdxGMTensor, expertCountTensor_, expertIdsCntParams);
}
```
On each rank, the receive buffer needed is expertPerSizeOnWin_ * moeExpertNumPerRank_ * epWorldSize_ bytes.
Each expert is allotted expertPerSizeOnWin_ = axisMaxBS_ * axisH_ * sizeof(XType) bytes. Within one row of expertIds the ids never repeat, so the count for any single expert id is at most axisMaxBS_.
On rank 0, the address range [0, expertPerSizeOnWin_ * moeExpertNumPerRank_) stores hidden states for expertIds in [0, 16) coming from rank 0.
On rank 0, the range [expertPerSizeOnWin_ * moeExpertNumPerRank_, expertPerSizeOnWin_ * moeExpertNumPerRank_ * epWorldSize_) stores hidden states for expertIds in [0, 16) coming from rank 1; this requires RDMA communication.
On rank 1, the range [0, expertPerSizeOnWin_ * moeExpertNumPerRank_) stores hidden states for expertIds in [16, 32) coming from rank 0; this requires RDMA communication.
On rank 1, the range [expertPerSizeOnWin_ * moeExpertNumPerRank_, expertPerSizeOnWin_ * moeExpertNumPerRank_ * epWorldSize_) stores hidden states for expertIds in [16, 32) coming from rank 1.
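The address arithmetic can be replayed on the CPU. For this test run, expertPerSizeOnWin_ works out to 65536 bytes and hCommuSize_ to 1152 bytes (both inferred from the printed offsets below, not read from the kernel), and sharedExpertRankNum_ is 0. A Python sketch under those assumptions:

```python
# Values inferred from the test logs; assumptions, not the real kernel config.
EXPERT_PER_SIZE_ON_WIN = 65536   # expertPerSizeOnWin_
H_COMMU_SIZE = 1152              # hCommuSize_
MOE_EXPERT_NUM_PER_RANK = 16     # moeExpertNumPerRank_
SHARED_EXPERT_RANK_NUM = 0       # sharedExpertRankNum_

def rank_gm_offset(ep_rank_id, dst_expert_id, cur_expert_cnt):
    """Mirror of the offset expression in SendToMoeExpert (sender's rank id)."""
    slot = ep_rank_id * MOE_EXPERT_NUM_PER_RANK + dst_expert_id % MOE_EXPERT_NUM_PER_RANK
    return EXPERT_PER_SIZE_ON_WIN * slot + H_COMMU_SIZE * cur_expert_cnt

def dst_rank(dst_expert_id):
    """tempRankId: which rank owns this expert's weights."""
    return dst_expert_id // MOE_EXPERT_NUM_PER_RANK + SHARED_EXPERT_RANK_NUM

print(rank_gm_offset(0, 11, 0))  # 720896, rank 0 tokenIndex 0 in the log below
print(rank_gm_offset(0, 0, 1))   # 1152, rank 0 tokenIndex 19
print(rank_gm_offset(1, 5, 3))   # 1379712, rank 1 tokenIndex 46
```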
Test program output
dstRankId: the destination rank of the hidden states.
tokenIndex: the index of the token this core is responsible for.
dstExpertId: expertIdsTensor_(tokenIndex).
curExpertCnt: see the function CalTokenSendExpertCnt.
Output of the four cores on rank 0:
#######SendToMoeExpert##########rank: 0, aiv_id: 0
dstRankId: 0, tokenIndex: 0, dstExpertId 11, curExpertCnt: 0, rankGMOffset: 720896
dstRankId: 1, tokenIndex: 1, dstExpertId 17, curExpertCnt: 0, rankGMOffset: 65536
dstRankId: 1, tokenIndex: 2, dstExpertId 29, curExpertCnt: 0, rankGMOffset: 851968
dstRankId: 0, tokenIndex: 3, dstExpertId 12, curExpertCnt: 0, rankGMOffset: 786432
dstRankId: 1, tokenIndex: 4, dstExpertId 24, curExpertCnt: 0, rankGMOffset: 524288
dstRankId: 1, tokenIndex: 5, dstExpertId 23, curExpertCnt: 0, rankGMOffset: 458752
dstRankId: 0, tokenIndex: 6, dstExpertId 1, curExpertCnt: 0, rankGMOffset: 65536
dstRankId: 0, tokenIndex: 7, dstExpertId 0, curExpertCnt: 0, rankGMOffset: 0
dstRankId: 1, tokenIndex: 8, dstExpertId 16, curExpertCnt: 0, rankGMOffset: 0
dstRankId: 1, tokenIndex: 9, dstExpertId 30, curExpertCnt: 0, rankGMOffset: 917504
dstRankId: 1, tokenIndex: 10, dstExpertId 18, curExpertCnt: 0, rankGMOffset: 131072
dstRankId: 0, tokenIndex: 11, dstExpertId 8, curExpertCnt: 0, rankGMOffset: 524288
#######SendToMoeExpert##########rank: 0, aiv_id: 1
dstRankId: 0, tokenIndex: 12, dstExpertId 2, curExpertCnt: 0, rankGMOffset: 131072
dstRankId: 0, tokenIndex: 13, dstExpertId 14, curExpertCnt: 0, rankGMOffset: 917504
dstRankId: 0, tokenIndex: 14, dstExpertId 11, curExpertCnt: 1, rankGMOffset: 722048
dstRankId: 0, tokenIndex: 15, dstExpertId 3, curExpertCnt: 0, rankGMOffset: 196608
dstRankId: 1, tokenIndex: 16, dstExpertId 30, curExpertCnt: 1, rankGMOffset: 918656
dstRankId: 1, tokenIndex: 17, dstExpertId 21, curExpertCnt: 0, rankGMOffset: 327680
dstRankId: 0, tokenIndex: 18, dstExpertId 12, curExpertCnt: 1, rankGMOffset: 787584
dstRankId: 0, tokenIndex: 19, dstExpertId 0, curExpertCnt: 1, rankGMOffset: 1152
dstRankId: 0, tokenIndex: 20, dstExpertId 10, curExpertCnt: 0, rankGMOffset: 655360
dstRankId: 0, tokenIndex: 21, dstExpertId 9, curExpertCnt: 0, rankGMOffset: 589824
dstRankId: 0, tokenIndex: 22, dstExpertId 6, curExpertCnt: 0, rankGMOffset: 393216
dstRankId: 1, tokenIndex: 23, dstExpertId 31, curExpertCnt: 0, rankGMOffset: 983040
#######SendToMoeExpert##########rank: 0, aiv_id: 2
dstRankId: 1, tokenIndex: 24, dstExpertId 22, curExpertCnt: 0, rankGMOffset: 393216
dstRankId: 1, tokenIndex: 25, dstExpertId 27, curExpertCnt: 0, rankGMOffset: 720896
dstRankId: 1, tokenIndex: 26, dstExpertId 30, curExpertCnt: 2, rankGMOffset: 919808
dstRankId: 1, tokenIndex: 27, dstExpertId 21, curExpertCnt: 1, rankGMOffset: 328832
dstRankId: 0, tokenIndex: 28, dstExpertId 1, curExpertCnt: 1, rankGMOffset: 66688
dstRankId: 1, tokenIndex: 29, dstExpertId 24, curExpertCnt: 1, rankGMOffset: 525440
dstRankId: 1, tokenIndex: 30, dstExpertId 17, curExpertCnt: 1, rankGMOffset: 66688
dstRankId: 0, tokenIndex: 31, dstExpertId 11, curExpertCnt: 2, rankGMOffset: 723200
dstRankId: 0, tokenIndex: 32, dstExpertId 3, curExpertCnt: 1, rankGMOffset: 197760
dstRankId: 1, tokenIndex: 33, dstExpertId 19, curExpertCnt: 0, rankGMOffset: 196608
dstRankId: 0, tokenIndex: 34, dstExpertId 2, curExpertCnt: 1, rankGMOffset: 132224
dstRankId: 1, tokenIndex: 35, dstExpertId 29, curExpertCnt: 1, rankGMOffset: 853120
#######SendToMoeExpert##########rank: 0, aiv_id: 3
dstRankId: 1, tokenIndex: 36, dstExpertId 16, curExpertCnt: 1, rankGMOffset: 1152
dstRankId: 0, tokenIndex: 37, dstExpertId 13, curExpertCnt: 0, rankGMOffset: 851968
dstRankId: 0, tokenIndex: 38, dstExpertId 7, curExpertCnt: 0, rankGMOffset: 458752
dstRankId: 1, tokenIndex: 39, dstExpertId 27, curExpertCnt: 1, rankGMOffset: 722048
dstRankId: 0, tokenIndex: 40, dstExpertId 6, curExpertCnt: 1, rankGMOffset: 394368
dstRankId: 1, tokenIndex: 41, dstExpertId 29, curExpertCnt: 2, rankGMOffset: 854272
dstRankId: 0, tokenIndex: 42, dstExpertId 5, curExpertCnt: 0, rankGMOffset: 327680
dstRankId: 1, tokenIndex: 43, dstExpertId 22, curExpertCnt: 1, rankGMOffset: 394368
dstRankId: 1, tokenIndex: 44, dstExpertId 24, curExpertCnt: 2, rankGMOffset: 526592
dstRankId: 1, tokenIndex: 45, dstExpertId 19, curExpertCnt: 1, rankGMOffset: 197760
dstRankId: 1, tokenIndex: 46, dstExpertId 23, curExpertCnt: 1, rankGMOffset: 459904
dstRankId: 0, tokenIndex: 47, dstExpertId 2, curExpertCnt: 2, rankGMOffset: 133376
After the four cores have run SendToMoeExpert, the contents of expandIdxGMTensor:
--------expandIdxGMTensor------------
0 0 0 0 0 0 0 0 0 0 0 0
0 0 1 0 1 0 1 1 0 0 0 0
0 0 2 1 1 1 1 2 1 0 1 1
1 0 0 1 1 2 0 1 2 1 1 2
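This table is simply the per-token occurrence index of each expert id over the flattened expertIds: expandIdx[t] is how many earlier tokens were routed to the same expert. A CPU-side Python re-derivation (my sketch, not the kernel code):

```python
# rank 0 expertIds from the text, flattened row-major (4 x 12).
expert_ids = [
    11, 17, 29, 12, 24, 23, 1, 0, 16, 30, 18, 8,
    2, 14, 11, 3, 30, 21, 12, 0, 10, 9, 6, 31,
    22, 27, 30, 21, 1, 24, 17, 11, 3, 19, 2, 29,
    16, 13, 7, 27, 6, 29, 5, 22, 24, 19, 23, 2,
]

# expandIdx[t] = number of earlier tokens routed to the same expert.
expand_idx = [expert_ids[:t].count(e) for t, e in enumerate(expert_ids)]
print(expand_idx)
```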
Output of the four cores on rank 1:
#######SendToMoeExpert##########rank: 1, aiv_id: 0
dstRankId: 1, tokenIndex: 0, dstExpertId 24, curExpertCnt: 0, rankGMOffset: 1572864
dstRankId: 0, tokenIndex: 1, dstExpertId 6, curExpertCnt: 0, rankGMOffset: 1441792
dstRankId: 1, tokenIndex: 2, dstExpertId 29, curExpertCnt: 0, rankGMOffset: 1900544
dstRankId: 0, tokenIndex: 3, dstExpertId 5, curExpertCnt: 0, rankGMOffset: 1376256
dstRankId: 1, tokenIndex: 4, dstExpertId 20, curExpertCnt: 0, rankGMOffset: 1310720
dstRankId: 0, tokenIndex: 5, dstExpertId 0, curExpertCnt: 0, rankGMOffset: 1048576
dstRankId: 1, tokenIndex: 6, dstExpertId 16, curExpertCnt: 0, rankGMOffset: 1048576
dstRankId: 1, tokenIndex: 7, dstExpertId 27, curExpertCnt: 0, rankGMOffset: 1769472
dstRankId: 0, tokenIndex: 8, dstExpertId 9, curExpertCnt: 0, rankGMOffset: 1638400
dstRankId: 1, tokenIndex: 9, dstExpertId 31, curExpertCnt: 0, rankGMOffset: 2031616
dstRankId: 0, tokenIndex: 10, dstExpertId 11, curExpertCnt: 0, rankGMOffset: 1769472
dstRankId: 1, tokenIndex: 11, dstExpertId 17, curExpertCnt: 0, rankGMOffset: 1114112
#######SendToMoeExpert##########rank: 1, aiv_id: 1
dstRankId: 1, tokenIndex: 12, dstExpertId 27, curExpertCnt: 1, rankGMOffset: 1770624
dstRankId: 1, tokenIndex: 13, dstExpertId 28, curExpertCnt: 0, rankGMOffset: 1835008
dstRankId: 0, tokenIndex: 14, dstExpertId 15, curExpertCnt: 0, rankGMOffset: 2031616
dstRankId: 0, tokenIndex: 15, dstExpertId 1, curExpertCnt: 0, rankGMOffset: 1114112
dstRankId: 1, tokenIndex: 16, dstExpertId 25, curExpertCnt: 0, rankGMOffset: 1638400
dstRankId: 1, tokenIndex: 17, dstExpertId 26, curExpertCnt: 0, rankGMOffset: 1703936
dstRankId: 0, tokenIndex: 18, dstExpertId 11, curExpertCnt: 1, rankGMOffset: 1770624
dstRankId: 0, tokenIndex: 19, dstExpertId 5, curExpertCnt: 1, rankGMOffset: 1377408
dstRankId: 0, tokenIndex: 20, dstExpertId 3, curExpertCnt: 0, rankGMOffset: 1245184
dstRankId: 0, tokenIndex: 21, dstExpertId 2, curExpertCnt: 0, rankGMOffset: 1179648
dstRankId: 0, tokenIndex: 22, dstExpertId 4, curExpertCnt: 0, rankGMOffset: 1310720
dstRankId: 0, tokenIndex: 23, dstExpertId 10, curExpertCnt: 0, rankGMOffset: 1703936
#######SendToMoeExpert##########rank: 1, aiv_id: 2
dstRankId: 1, tokenIndex: 24, dstExpertId 28, curExpertCnt: 1, rankGMOffset: 1836160
dstRankId: 1, tokenIndex: 25, dstExpertId 22, curExpertCnt: 0, rankGMOffset: 1441792
dstRankId: 0, tokenIndex: 26, dstExpertId 6, curExpertCnt: 1, rankGMOffset: 1442944
dstRankId: 0, tokenIndex: 27, dstExpertId 13, curExpertCnt: 0, rankGMOffset: 1900544
dstRankId: 0, tokenIndex: 28, dstExpertId 3, curExpertCnt: 1, rankGMOffset: 1246336
dstRankId: 1, tokenIndex: 29, dstExpertId 30, curExpertCnt: 0, rankGMOffset: 1966080
dstRankId: 1, tokenIndex: 30, dstExpertId 18, curExpertCnt: 0, rankGMOffset: 1179648
dstRankId: 0, tokenIndex: 31, dstExpertId 15, curExpertCnt: 1, rankGMOffset: 2032768
dstRankId: 1, tokenIndex: 32, dstExpertId 27, curExpertCnt: 2, rankGMOffset: 1771776
dstRankId: 1, tokenIndex: 33, dstExpertId 31, curExpertCnt: 1, rankGMOffset: 2032768
dstRankId: 0, tokenIndex: 34, dstExpertId 7, curExpertCnt: 0, rankGMOffset: 1507328
dstRankId: 0, tokenIndex: 35, dstExpertId 5, curExpertCnt: 2, rankGMOffset: 1378560
#######SendToMoeExpert##########rank: 1, aiv_id: 3
dstRankId: 1, tokenIndex: 36, dstExpertId 31, curExpertCnt: 2, rankGMOffset: 2033920
dstRankId: 0, tokenIndex: 37, dstExpertId 2, curExpertCnt: 1, rankGMOffset: 1180800
dstRankId: 1, tokenIndex: 38, dstExpertId 21, curExpertCnt: 0, rankGMOffset: 1376256
dstRankId: 1, tokenIndex: 39, dstExpertId 23, curExpertCnt: 0, rankGMOffset: 1507328
dstRankId: 1, tokenIndex: 40, dstExpertId 20, curExpertCnt: 1, rankGMOffset: 1311872
dstRankId: 0, tokenIndex: 41, dstExpertId 3, curExpertCnt: 2, rankGMOffset: 1247488
dstRankId: 0, tokenIndex: 42, dstExpertId 6, curExpertCnt: 2, rankGMOffset: 1444096
dstRankId: 0, tokenIndex: 43, dstExpertId 8, curExpertCnt: 0, rankGMOffset: 1572864
dstRankId: 0, tokenIndex: 44, dstExpertId 7, curExpertCnt: 1, rankGMOffset: 1508480
dstRankId: 0, tokenIndex: 45, dstExpertId 12, curExpertCnt: 0, rankGMOffset: 1835008
dstRankId: 0, tokenIndex: 46, dstExpertId 5, curExpertCnt: 3, rankGMOffset: 1379712
dstRankId: 0, tokenIndex: 47, dstExpertId 15, curExpertCnt: 2, rankGMOffset: 2033920
After the four cores have run SendToMoeExpert, the contents of expandIdxGMTensor:
--------expandIdxGMTensor------------
0 0 0 0 0 0 0 0 0 0 0 0
1 0 0 0 0 0 1 1 0 0 0 0
1 0 1 0 1 0 0 1 2 1 0 2
2 1 0 0 1 2 2 0 1 0 3 2
SetStatus
The hidden-state copies in SendToMoeExpert involve RDMA communication. SetStatus then publishes status flags to notify the receivers that the data transfer is complete.
```c++
template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::SetStatus()
{
    totalExpertNum_ = sharedExpertRankNum_ + moeExpertNum_;
    SplitToCore(totalExpertNum_, aivNum_, startExpertId_, endExpertId_, sendExpertNum_);
    for (uint32_t curExpertId = startExpertId_; curExpertId < endExpertId_; ++curExpertId) {
        if (curExpertId < sharedExpertRankNum_) {
            continue;
        }
        int32_t curExpertCnt = 0;
        int32_t curMoeExpertId = curExpertId - sharedExpertRankNum_;
        CalTokenSendExpertCnt(curMoeExpertId, expertIdsCnt_, curExpertCnt);
        int32_t cntPosIndex = curExpertId * 8 + 1; // *8: each expert occupies 32 bytes (8 int32)
        statusTensor_(cntPosIndex) = curExpertCnt;
    }
    PipeBarrier<PIPE_ALL>();
    SyncAll<true>();
    if (startExpertId_ >= totalExpertNum_) {
        return;
    }
    GlobalTensor<int32_t> rankGMTensor;
    uint32_t offset = stateOffset_ * epRankId_;
    for (uint32_t rankIndex = startExpertId_; rankIndex < endExpertId_; ++rankIndex) {
        uint32_t dstRankId = rankIndex;
        if (moeExpertNumPerRank_ > 1 && (rankIndex >= sharedExpertRankNum_)) {
            dstRankId = ((rankIndex - sharedExpertRankNum_) / moeExpertNumPerRank_ + sharedExpertRankNum_);
            offset = (epRankId_ + (rankIndex - sharedExpertRankNum_) % moeExpertNumPerRank_ * epWorldSize_) * stateOffset_;
        }
        GM_ADDR rankGM = (__gm__ uint8_t*)(GetWindStateAddrByRankId(COMM_EP_IDX, dstRankId) + offset); // compute the address offset
        rankGMTensor.SetGlobalBuffer((__gm__ int32_t*)rankGM);
#if defined(ASCENDC_OOM) && ASCENDC_OOM == 1
        OOMCheckAddrRange<int32_t>(
            (__gm__ int32_t*)(GetWindStateAddrByRankId(COMM_EP_IDX, dstRankId)), STATE_SIZE);
#endif
        DataCopy<int32_t>(rankGMTensor, statusTensor_[rankIndex * 8], 8UL); // 8 is the element count; the copy is 32-byte aligned
    }
    SyncFunc<AscendC::HardEvent::MTE3_MTE2>();
}
```
In statusTensor_, each expert's slot is 8 * sizeof(int32_t) long. The first int holds the flag (0x3F800000, the bit pattern of 1.0f); the second int holds curExpertCnt.
Each expert occupies stateOffset_ bytes in statusSpaceGm_:
```c++
constexpr uint32_t STATE_OFFSET = 512; // stride of the status space
stateOffset_ = (recvWinBlockNum_ > 512) ? (STATE_OFFSET / 2) : STATE_OFFSET;
```
Test program output
ExpertId: the expert id.
dstRank: the rank holding that expert's weights.
offset: the address offset within rankGM.
curExpertCnt: see the function CalTokenSendExpertCnt.
Output of the four cores on rank 0:
On card 0, given the input expertIds, the count (curExpertCnt) for ExpertId 0 is 2, and for ExpertId 2 it is 3.
Let addr denote the address returned by GetWindStateAddrByRankId. On card 0, the offset range [0, 512) of addr stores the flag and curExpertCnt for ExpertId 0 coming from rank 0; the range [512, 1024) stores the flag and curExpertCnt for ExpertId 0 coming from rank 1.
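With no shared-expert ranks in this test (sharedExpertRankNum_ = 0, moeExpertNumPerRank_ = 16, stateOffset_ = 512 since recvWinBlockNum_ = 32), the destination rank and status offset reduce to a two-line formula. A Python sketch of my reading of SetStatus, simplified to this configuration:

```python
EP_WORLD_SIZE = 2
MOE_EXPERT_NUM_PER_RANK = 16
STATE_OFFSET = 512  # stateOffset_ for recvWinBlockNum_ = 32 <= 512

def status_slot(ep_rank_id, expert_id):
    """Destination rank and byte offset of the (flag, count) record for one expert."""
    dst_rank_id = expert_id // MOE_EXPERT_NUM_PER_RANK
    local = expert_id % MOE_EXPERT_NUM_PER_RANK
    # Blocks are grouped by local expert, then by source rank.
    offset = (ep_rank_id + local * EP_WORLD_SIZE) * STATE_OFFSET
    return dst_rank_id, offset

print(status_slot(0, 17))  # (1, 1024): rank 0 writing expert 17, as in the log
print(status_slot(1, 31))  # (1, 15872): rank 1 writing expert 31
```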
#######SetStatus##########rank: 0, aiv_id: 0
ExpertId 0, dstRank: 0, offset: 0, curExpertCnt 2
ExpertId 1, dstRank: 0, offset: 1024, curExpertCnt 2
ExpertId 2, dstRank: 0, offset: 2048, curExpertCnt 3
ExpertId 3, dstRank: 0, offset: 3072, curExpertCnt 2
ExpertId 4, dstRank: 0, offset: 4096, curExpertCnt 0
ExpertId 5, dstRank: 0, offset: 5120, curExpertCnt 1
ExpertId 6, dstRank: 0, offset: 6144, curExpertCnt 2
ExpertId 7, dstRank: 0, offset: 7168, curExpertCnt 1
#######SetStatus##########rank: 0, aiv_id: 1
ExpertId 8, dstRank: 0, offset: 8192, curExpertCnt 1
ExpertId 9, dstRank: 0, offset: 9216, curExpertCnt 1
ExpertId 10, dstRank: 0, offset: 10240, curExpertCnt 1
ExpertId 11, dstRank: 0, offset: 11264, curExpertCnt 3
ExpertId 12, dstRank: 0, offset: 12288, curExpertCnt 2
ExpertId 13, dstRank: 0, offset: 13312, curExpertCnt 1
ExpertId 14, dstRank: 0, offset: 14336, curExpertCnt 1
ExpertId 15, dstRank: 0, offset: 15360, curExpertCnt 0
#######SetStatus##########rank: 0, aiv_id: 2
ExpertId 16, dstRank: 1, offset: 0, curExpertCnt 2
ExpertId 17, dstRank: 1, offset: 1024, curExpertCnt 2
ExpertId 18, dstRank: 1, offset: 2048, curExpertCnt 1
ExpertId 19, dstRank: 1, offset: 3072, curExpertCnt 2
ExpertId 20, dstRank: 1, offset: 4096, curExpertCnt 0
ExpertId 21, dstRank: 1, offset: 5120, curExpertCnt 2
ExpertId 22, dstRank: 1, offset: 6144, curExpertCnt 2
ExpertId 23, dstRank: 1, offset: 7168, curExpertCnt 2
#######SetStatus##########rank: 0, aiv_id: 3
ExpertId 24, dstRank: 1, offset: 8192, curExpertCnt 3
ExpertId 25, dstRank: 1, offset: 9216, curExpertCnt 0
ExpertId 26, dstRank: 1, offset: 10240, curExpertCnt 0
ExpertId 27, dstRank: 1, offset: 11264, curExpertCnt 2
ExpertId 28, dstRank: 1, offset: 12288, curExpertCnt 0
ExpertId 29, dstRank: 1, offset: 13312, curExpertCnt 3
ExpertId 30, dstRank: 1, offset: 14336, curExpertCnt 3
ExpertId 31, dstRank: 1, offset: 15360, curExpertCnt 1
Output of the four cores on rank 1:
#######SetStatus##########rank: 1, aiv_id: 0
ExpertId 0, dstRank: 0, offset: 512, curExpertCnt 1
ExpertId 1, dstRank: 0, offset: 1536, curExpertCnt 1
ExpertId 2, dstRank: 0, offset: 2560, curExpertCnt 2
ExpertId 3, dstRank: 0, offset: 3584, curExpertCnt 3
ExpertId 4, dstRank: 0, offset: 4608, curExpertCnt 1
ExpertId 5, dstRank: 0, offset: 5632, curExpertCnt 4
ExpertId 6, dstRank: 0, offset: 6656, curExpertCnt 3
ExpertId 7, dstRank: 0, offset: 7680, curExpertCnt 2
#######SetStatus##########rank: 1, aiv_id: 1
ExpertId 8, dstRank: 0, offset: 8704, curExpertCnt 1
ExpertId 9, dstRank: 0, offset: 9728, curExpertCnt 1
ExpertId 10, dstRank: 0, offset: 10752, curExpertCnt 1
ExpertId 11, dstRank: 0, offset: 11776, curExpertCnt 2
ExpertId 12, dstRank: 0, offset: 12800, curExpertCnt 1
ExpertId 13, dstRank: 0, offset: 13824, curExpertCnt 1
ExpertId 14, dstRank: 0, offset: 14848, curExpertCnt 0
ExpertId 15, dstRank: 0, offset: 15872, curExpertCnt 3
#######SetStatus##########rank: 1, aiv_id: 2
ExpertId 16, dstRank: 1, offset: 512, curExpertCnt 1
ExpertId 17, dstRank: 1, offset: 1536, curExpertCnt 1
ExpertId 18, dstRank: 1, offset: 2560, curExpertCnt 1
ExpertId 19, dstRank: 1, offset: 3584, curExpertCnt 0
ExpertId 20, dstRank: 1, offset: 4608, curExpertCnt 2
ExpertId 21, dstRank: 1, offset: 5632, curExpertCnt 1
ExpertId 22, dstRank: 1, offset: 6656, curExpertCnt 1
ExpertId 23, dstRank: 1, offset: 7680, curExpertCnt 1
#######SetStatus##########rank: 1, aiv_id: 3
ExpertId 24, dstRank: 1, offset: 8704, curExpertCnt 1
ExpertId 25, dstRank: 1, offset: 9728, curExpertCnt 1
ExpertId 26, dstRank: 1, offset: 10752, curExpertCnt 1
ExpertId 27, dstRank: 1, offset: 11776, curExpertCnt 3
ExpertId 28, dstRank: 1, offset: 12800, curExpertCnt 2
ExpertId 29, dstRank: 1, offset: 13824, curExpertCnt 1
ExpertId 30, dstRank: 1, offset: 14848, curExpertCnt 1
ExpertId 31, dstRank: 1, offset: 15872, curExpertCnt 3
WaitDispatch
```c++
template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::WaitDispatch()
{
    uint32_t startStatusIndex = 0;
    uint32_t endStatusIndex = 0;
    uint32_t rscvStatusNum = isShareExpertRank_ ? epWorldSize_ : recvWinBlockNum_;
    SplitToCore(rscvStatusNum, aivNum_, startStatusIndex, endStatusIndex, recStatusNumPerCore_);
    InitBufferWait();
    if (startStatusIndex >= rscvStatusNum) {
        SyncAll<true>();
        return;
    }
    LocalTensor<float> gatherMaskOutTensor = gatherMaskOutBuf_.Get<float>();
    LocalTensor<uint32_t> gatherTmpTensor = scalarBuf_.GetWithOffset<uint32_t>(UB_ALIGN / sizeof(uint32_t), 0);
    gatherTmpTensor.SetValue(0, 1);
    LocalTensor<float> statusSumOutTensor = scalarBuf_.GetWithOffset<float>(UB_ALIGN / sizeof(float), UB_ALIGN);
    statusFp32Tensor_ = statusTensor_.ReinterpretCast<float>();
    uint32_t mask = 1; // parameters for GatherMask + Sum
    uint64_t rsvdCnt = 0;
    uint32_t recStatusNumPerCoreInner = Ceil(recStatusNumPerCore_ * sizeof(float), UB_ALIGN) * UB_ALIGN / sizeof(float);
    SumParams sumParams{1, recStatusNumPerCoreInner, recStatusNumPerCore_};
    float sumOfFlag = static_cast<float>(-1.0);
    float minTarget = (sumTarget_ * recStatusNumPerCore_) - (float)0.5;
    float maxTarget = (sumTarget_ * recStatusNumPerCore_) + (float)0.5;
    DataCopyParams intriParams{static_cast<uint16_t>(recStatusNumPerCore_), 1,
        static_cast<uint16_t>((recvWinBlockNum_ > 512) ? 7 : 15), 0}; // srcStride is 15 blocks (7 when recvWinBlockNum_ > 512)
    SyncFunc<AscendC::HardEvent::S_V>();
    while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) {
        DataCopy(statusFp32Tensor_, windowInstatusFp32Tensor_[startStatusIndex * stateOffset_ / sizeof(float)], intriParams);
        SyncFunc<AscendC::HardEvent::MTE2_V>();
        GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, mask,
            {1, (uint16_t)recStatusNumPerCore_, 1, 0}, rsvdCnt);
        PipeBarrier<PIPE_V>();
        Sum(statusSumOutTensor, gatherMaskOutTensor, sumParams);
        SyncFunc<AscendC::HardEvent::V_S>();
        sumOfFlag = statusSumOutTensor.GetValue(0);
    }
    SyncAll<true>();
}
```
recvWinBlockNum_ = epWorldSize_ * moeExpertNumPerRank_.
WaitDispatch decides whether the data transfer has finished by counting flags; each AIV core waits for recStatusNumPerCore_ of them.
Once recStatusNumPerCore_ flags have arrived, the core leaves the while loop and calls SyncAll to synchronize the AIV cores on the same rank. When SyncAll returns, the whole rank has received the hidden states sent by the other ranks.
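The loop sums the flag words as float: 0x3F800000 is the bit pattern of 1.0f, so (assuming sumTarget_ is 1.0, consistent with the printed minTarget/maxTarget) eight arrived flags sum to 8.0, inside (7.5, 8.5). A small Python check of this encoding:

```python
import struct

FLAG = 0x3F800000  # flag word written by SetStatus

def as_float(bits):
    """Reinterpret an int32 bit pattern as an IEEE-754 float."""
    return struct.unpack('<f', struct.pack('<I', bits))[0]

flags = [FLAG] * 8  # recStatusNumPerCore_ = 8 arrived flags
sum_of_flag = sum(as_float(b) for b in flags)
min_target, max_target = 8 * 1.0 - 0.5, 8 * 1.0 + 0.5
print(sum_of_flag)                              # 8.0
print(min_target < sum_of_flag < max_target)    # True: polling terminates
```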
Test program output
Output of the four cores on rank 0:
#######WaitDispatch##########rank: 0, aiv_id: 0
rscvStatusNum: 32, recStatusNumPerCore: 8, startStatusIndex: 0, endStatusIndex: 8
minTarget: 7.500000, maxTarget: 8.500000, sumOfFlag: 8.000000
For this test case the other three cores print the same information; omitted.
Output of the four cores on rank 1:
#######WaitDispatch##########rank: 1, aiv_id: 0
rscvStatusNum: 32, recStatusNumPerCore: 8, startStatusIndex: 0, endStatusIndex: 8
minTarget: 7.500000, maxTarget: 8.500000, sumOfFlag: 8.000000
For this test case the other three cores print the same information; omitted.
LocalWindowCopy
```c++
template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::GetCumSum(LocalTensor<int32_t>& outLocal, int32_t totalCount)
{
    DataCopyParams intriParams{static_cast<uint16_t>(recvWinBlockNum_), 1,
        static_cast<uint16_t>((recvWinBlockNum_ > 512) ? 7 : 15), 0}; // srcStride is 15 blocks
    DataCopy(statusTensor_, windowInstatusTensor_, intriParams);
    if (isShareExpertRank_) {
        SyncFunc<AscendC::HardEvent::MTE2_S>();
        for (uint32_t curStatusExpId = 0; curStatusExpId < sharedExpertRankNum_; ++curStatusExpId) {
            int32_t curExpertCnt = (curStatusExpId + 1 + epRankId_) * axisBS_ / sharedExpertRankNum_
                - (curStatusExpId + epRankId_) * axisBS_ / sharedExpertRankNum_;
            statusTensor_((curStatusExpId) * 8 + 1) = curExpertCnt;
        }
        SyncFunc<AscendC::HardEvent::S_V>();
    } else {
        SyncFunc<AscendC::HardEvent::MTE2_V>();
    }
    outLocal = gatherMaskOutBuf_.Get<int32_t>(); // buffer reuse
    LocalTensor<float> getTotalLocal = getTotalBuf_.Get<float>();
    // gather the count fields together
    LocalTensor<uint32_t> gatherTmpTensor = gatherTmpBuf_.Get<uint32_t>();
    Duplicate(gatherTmpTensor, (uint32_t)33686018, recvWinBlockNum_ / 4); // 0x02020202: 0000 0010 0000 0010 0000 0010 0000 0010
    PipeBarrier<PIPE_V>();
    uint32_t mask = recvWinBlockNum_ * 8; // 512 / 32
    uint64_t rsvdCnt = 0;
    GatherMask(outLocal, statusTensor_, gatherTmpTensor, true, mask, {1, 1, 0, 0}, rsvdCnt);
    LocalTensor<float> tmpFp32 = outLocal.ReinterpretCast<float>();
    PipeBarrier<PIPE_V>();
    ReduceSum<float>(getTotalLocal, tmpFp32, workLocalTensor_, epWorldSize_);
    totalCnt_ = getTotalLocal.ReinterpretCast<int32_t>().GetValue(0);
    PipeBarrier<PIPE_V>();
    ReduceSum<float>(tmpFp32, tmpFp32, workLocalTensor_, totalCount);
    PipeBarrier<PIPE_V>();
}
```
LocalWindowCopy calls GetCumSum to obtain the per-expert token counts.
The curExpertCnt values produced on rank 0 and rank 1 for experts 0 to 15, collected into a table:
| rank\id | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| rank0 | 2 | 2 | 3 | 2 | 0 | 1 | 2 | 1 | 1 | 1 | 1 | 3 | 2 | 1 | 1 | 0 |
| rank1 | 1 | 1 | 2 | 3 | 1 | 4 | 3 | 2 | 1 | 1 | 1 | 2 | 1 | 1 | 0 | 3 |
| total | 3 | 3 | 5 | 5 | 1 | 5 | 5 | 3 | 2 | 2 | 2 | 5 | 3 | 2 | 1 | 3 |
The outLocal values printed after GatherMask inside GetCumSum (within one rank, every core's outLocal is identical). Read the table below row by row: for each local expert, the count from rank 0 comes first, then the count from rank 1.
--------outLocal------------
2 1 2 1 3 2 2 3
0 1 1 4 2 3 1 2
1 1 1 1 1 1 3 2
2 1 1 1 1 0 0 3
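outLocal is exactly the count table above with the two rank rows interleaved per expert: [e0@rank0, e0@rank1, e1@rank0, e1@rank1, ...]. A Python sketch of my reading of the status-window layout (the interleaving is inferred from the printed data):

```python
# Per-expert counts from the table above (experts 0..15).
rank0_cnt = [2, 2, 3, 2, 0, 1, 2, 1, 1, 1, 1, 3, 2, 1, 1, 0]
rank1_cnt = [1, 1, 2, 3, 1, 4, 3, 2, 1, 1, 1, 2, 1, 1, 0, 3]

# Status blocks are laid out as (local expert, source rank); GatherMask
# picks the count field of each 32-byte block, giving this interleaving.
out_local = []
for e in range(16):
    out_local += [rank0_cnt[e], rank1_cnt[e]]
print(out_local)  # first eight values: 2 1 2 1 3 2 2 3
```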
The prints below are from rank 0. The outCountLocal cumsum values are the output after LocalWindowCopy has executed the following code:
```c++
for (uint32_t index = startExpertId_; index < endExpertId_; index++) {
    uint32_t i = index - startExpertId_;
    if (i > 0) {
        outCountLocal.SetValue(i, outCountLocal.GetValue(i - 1) + outCountLocal.GetValue(index));
    }
}
// from my simulation code
uint32_t count = endExpertId_ - startExpertId_;
LOG_INFO("startExpertId: %d, endExpertId %d", startExpertId_, endExpertId_);
print_tensor(outCountLocal, "outCountLocal cumsum", count);
```
#######GetCumSum##########rank: 0, aiv_id: 0
--------outLocal ReduceSum, totalCount = startExpertId_ + 1: 1------------
2 1 2 1 3 2 2 3
0 1 1 4 2 3 1 2
1 1 1 1 1 1 3 2
2 1 1 1 1 0 0 3
#######LocalWindowCopy##########rank: 0, aiv_id: 0
startExpertId: 0, endExpertId 8
--------outCountLocal cumsum------------
2 3 5 6 9 11 13 16
#######GetCumSum##########rank: 0, aiv_id: 1
--------outLocal ReduceSum, totalCount = startExpertId_ + 1: 9------------
16 1 2 1 3 2 2 3
0 1 1 4 2 3 1 2
1 1 1 1 1 1 3 2
2 1 1 1 1 0 0 3
#######LocalWindowCopy##########rank: 0, aiv_id: 1
startExpertId: 8, endExpertId 16
--------outCountLocal cumsum------------
16 17 18 22 24 27 28 30
#######GetCumSum##########rank: 0, aiv_id: 2
--------outLocal ReduceSum, totalCount = startExpertId_ + 1: 17------------
31 1 2 1 3 2 2 3
0 1 1 4 2 3 1 2
1 1 1 1 1 1 3 2
2 1 1 1 1 0 0 3
#######LocalWindowCopy##########rank: 0, aiv_id: 2
startExpertId: 16, endExpertId 24
#######GetCumSum##########rank: 0, aiv_id: 3
--------outLocal ReduceSum, totalCount = startExpertId_ + 1: 25------------
43 1 2 1 3 2 2 3
0 1 1 4 2 3 1 2
1 1 1 1 1 1 3 2
2 1 1 1 1 0 0 3
#######LocalWindowCopy##########rank: 0, aiv_id: 3
startExpertId: 24, endExpertId 32
--------outCountLocal cumsum------------
43 44 45 46 47 47 47 50
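Per core, element 0 of outCountLocal is the ReduceSum prefix total up to startExpertId_ inclusive (16, 31, 43 in the logs above), and the loop then accumulates the remaining entries. A Python re-derivation (my sketch, not the kernel code):

```python
# outLocal after GatherMask (counts interleaved per expert, both source ranks).
out_local = [2, 1, 2, 1, 3, 2, 2, 3, 0, 1, 1, 4, 2, 3, 1, 2,
             1, 1, 1, 1, 1, 1, 3, 2, 2, 1, 1, 1, 1, 0, 0, 3]

def core_cumsum(start, end):
    """outCountLocal for the core handling status blocks [start, end)."""
    acc = sum(out_local[:start + 1])   # ReduceSum writes this into slot 0
    result = [acc]
    for idx in range(start + 1, end):  # the in-place cumsum loop
        acc += out_local[idx]
        result.append(acc)
    return result

print(core_cumsum(0, 8))    # aiv 0
print(core_cumsum(8, 16))   # aiv 1
print(core_cumsum(24, 32))  # aiv 3
```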
Finally, LocalWindowCopy copies outCountLocal to sendCountsOutGM_.
```c++
template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::LocalWindowCopy()
{
    // ...
    DataCopyExtParams dataCopyOutParams = {1U, static_cast<uint32_t>(sendExpertNum_ * sizeof(int32_t)), 0U, 0U, 0U};
    GlobalTensor<int32_t> sendCountsGlobal;
    sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(sendCountsOutGM_));
    DataCopyPad(sendCountsGlobal[startExpertId_], outCountLocal, dataCopyOutParams);
    PipeBarrier<PIPE_MTE3>();
}
```
UpdateTokenNumsOut
Only the last core executes the body of UpdateTokenNumsOut, so SyncAll must be called for synchronization. We only look at the UpdateMultiMoeTokenNumsOut branch here.
```c++
// update the tokenNumsOut tensor
template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::UpdateTokenNumsOut()
{
    // the last core does the update; on a MoE-expert card only the last core has all of sendCountsGlobal computed
    if (!isShareExpertRank_ && moeExpertNumPerRank_ > 1) {
        SyncAll<true>();
        if (aivId_ != lastCore_) return;
        SyncFunc<AscendC::HardEvent::MTE3_S>();
        UpdateMultiMoeTokenNumsOut();
    }
}

// update tokenNumsOut on a card hosting multiple MoE experts
template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::UpdateMultiMoeTokenNumsOut()
{
    uint32_t tokenSums = 0;
    GlobalTensor<int32_t> sendCountsGlobal;
    sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(sendCountsOutGM_));
    for (uint32_t localMoeIndex = 0; localMoeIndex < moeExpertNumPerRank_; ++localMoeIndex) {
        if (localMoeIndex == 0) {
            DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
                sendCountsGlobal[epWorldSize_ - 1]);
            uint32_t firstMoeCnt = sendCountsGlobal.GetValue(epWorldSize_ - 1);
            tokenSums = firstMoeCnt + gatherCount_;
            expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenSums);
            DataCacheCleanAndInvalid<int64_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
                expertTokenNumsOutGMTensor_[localMoeIndex]);
        } else {
            uint32_t preIndex = epWorldSize_ * (localMoeIndex - 1) + epWorldSize_ - 1;
            uint32_t curIndex = epWorldSize_ * localMoeIndex + epWorldSize_ - 1;
            DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
                sendCountsGlobal[preIndex]);
            DataCacheCleanAndInvalid<int32_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
                sendCountsGlobal[curIndex]);
            uint32_t preMoeIndexCnt = sendCountsGlobal.GetValue(preIndex);
            uint32_t curMoeIndexCnt = sendCountsGlobal.GetValue(curIndex);
            tokenSums = ((expertTokenNumsType_ == 0) ? tokenSums : 0) + (curMoeIndexCnt - preMoeIndexCnt) + gatherCount_;
            expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenSums);
            DataCacheCleanAndInvalid<int64_t, CacheLine::SINGLE_CACHE_LINE, DcciDst::CACHELINE_OUT>(
                expertTokenNumsOutGMTensor_[localMoeIndex]);
        }
    }
}
```
Test program output
#######UpdateMultiMoeTokenNumsOut##########rank: 0, aiv_id: 3
--------sendCountsGlobal------------
2 3 5 6 9 11 13 16
16 17 18 22 24 27 28 30
31 32 33 34 35 36 39 41
43 44 45 46 47 47 47 50
--------expertTokenNumsOutGMTensor------------
3 3 5 5 1 5 5 3
2 2 2 5 3 2 1 3
This matches the totals row in the table above. The index into expertTokenNumsOutGMTensor is the (local) expert id, and the value is the number of tokens that expert will process.
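The cumulative sendCountsGlobal is converted to per-expert totals by differencing the last entry of each expert's epWorldSize_-wide segment. The printed values are per-expert counts rather than a running sum, so expertTokenNumsType_ is nonzero in this run, and gatherCount_ is apparently 0 (both are assumptions inferred from the output). A Python sketch:

```python
EP_WORLD_SIZE = 2
# sendCountsGlobal from the rank 0 log above, flattened (32 values).
send_counts = [2, 3, 5, 6, 9, 11, 13, 16, 16, 17, 18, 22, 24, 27, 28, 30,
               31, 32, 33, 34, 35, 36, 39, 41, 43, 44, 45, 46, 47, 47, 47, 50]
gather_count = 0  # gatherCount_, assumed 0 in this test

token_nums = []
prev = 0
for j in range(len(send_counts) // EP_WORLD_SIZE):  # local expert index
    cur = send_counts[EP_WORLD_SIZE * j + EP_WORLD_SIZE - 1]  # curIndex entry
    token_nums.append(cur - prev + gather_count)
    prev = cur
print(token_nums)
```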
Analysis of helper functions
CalTokenSendExpertCnt
```c++
template <TemplateMC2TypeClass>
__aicore__ inline void MoeDistributeDispatch<TemplateMC2TypeFunc>::CalTokenSendExpertCnt(uint32_t dstExpertId, int32_t calCnt,
    int32_t &curExpertCnt)
{
    if (isQuant_) { // buffer reuse in quantized mode
        dstExpIdTensor_ = receiveDataCastFloatBuf_.Get<int32_t>();
        subExpIdTensor_ = smoothScalesBuf_.Get<int32_t>();
    }
    Duplicate<int32_t>(dstExpIdTensor_, dstExpertId, calCnt);
    PipeBarrier<PIPE_V>();
    Sub(subExpIdTensor_, expertIdsTensor_, dstExpIdTensor_, calCnt);
    PipeBarrier<PIPE_V>();
    LocalTensor<float> tmpFp32 = subExpIdTensor_.ReinterpretCast<float>();
    LocalTensor<float> tmpoutFp32 = dstExpIdTensor_.ReinterpretCast<float>();
    Abs(tmpoutFp32, tmpFp32, calCnt);
    PipeBarrier<PIPE_V>();
    Mins(subExpIdTensor_, dstExpIdTensor_, 1, calCnt);
    PipeBarrier<PIPE_V>();
    ReduceSum<float>(tmpoutFp32, tmpFp32, workLocalTensor_, calCnt);
    SyncFunc<AscendC::HardEvent::V_S>();
    int32_t curOtherExpertCnt = dstExpIdTensor_(0);
    if (calCnt > curOtherExpertCnt) {
        curExpertCnt = calCnt - curOtherExpertCnt;
    }
}
```
The function counts how many of the first calCnt entries of expertIdsTensor_ equal dstExpertId.
Duplicate: broadcast dstExpertId into dstExpIdTensor_, calCnt times.
Sub: compute expertIdsTensor_ - dstExpIdTensor_ and store the difference in subExpIdTensor_.
Abs: take the absolute value of subExpIdTensor_ (the contents are treated as float) and store the result in dstExpIdTensor_'s buffer. After Mins the mask is 0 exactly where the id matches; ReduceSum then sums the 0/1 mask, and curExpertCnt = calCnt minus that sum.
```c++
template <typename T>
__aicore__ inline void Abs(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, const int32_t& calCount)
{
    for (int32_t i = 0; i < calCount; i++) {
        if (srcLocal.GetValue(i) >= 0) {
            dstLocal.SetValue(i, srcLocal.GetValue(i));
        } else {
            if (std::is_same<T, float>::value) {
                dstLocal.SetValue(i, std::fabs(srcLocal(i)));
            } else {
                dstLocal.SetValue(i, std::abs(srcLocal(i)));
            }
        }
    }
}
```
Mins:subExpIdTensor_(i) = min(dstExpIdTensor_(i), 1)
```c++
// https://developer.huawei.com/consumer/cn/doc/hiai-References/cannkit-scalar-binocular-mins-0000002123235326
template <typename T, bool isSetMask = true>
__aicore__ inline void Mins(const LocalTensor<T>& dstLocal, const LocalTensor<T>& srcLocal, const T& scalarValue, const int32_t& calCount)
{
    for (int32_t i = 0; i < calCount; i++) {
        if (srcLocal.GetValue(i) >= scalarValue) {
            dstLocal.SetValue(i, scalarValue);
        } else {
            dstLocal.SetValue(i, srcLocal(i));
        }
    }
}
```
Take a randomly generated expertIds as an example:
--------expertIds------------
17 18 30 29 6 12 10 19 26 13 15 22
12 26 10 30 11 25 29 22 15 5 18 28
29 14 12 15 2 6 9 13 17 21 18 27
18 12 11 2 22 3 16 30 10 13 28 26
The relevant input parameters to CalTokenSendExpertCnt:
```c++
int calCount = 37;
uint32_t dstExpertId = expertIds[calCount];
int curExpertCnt = 0;
dispatch.CalTokenSendExpertCnt(dstExpertId, calCount, curExpertCnt, true);
LOG_INFO("index: %d, dstExpertId %d, curExpertCnt: %d", calCount, dstExpertId, curExpertCnt);
// index: 37, dstExpertId 12, curExpertCnt: 3
```
The contents of subExpIdTensor_ after Duplicate, Sub, Abs, and Mins:
--------subExpIdTensor_------------
1 1 1 1 1 0 1 1 1 1 1 1
0 1 1 1 1 1 1 1 1 1 1 1
1 1 0 1 1 1 1 1 1 1 1 1
1
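The whole pipeline collapses to curExpertCnt = calCnt - Σ min(|ids[i] - dst|, 1), since the mask is 0 exactly at the matching positions. A Python sketch of this arithmetic in scalar form (my CPU re-derivation, ignoring the kernel's float-reinterpretation details):

```python
# The randomly generated expertIds from the example above, flattened.
expert_ids = [
    17, 18, 30, 29, 6, 12, 10, 19, 26, 13, 15, 22,
    12, 26, 10, 30, 11, 25, 29, 22, 15, 5, 18, 28,
    29, 14, 12, 15, 2, 6, 9, 13, 17, 21, 18, 27,
    18, 12, 11, 2, 22, 3, 16, 30, 10, 13, 28, 26,
]

def cal_token_send_expert_cnt(dst_expert_id, cal_cnt):
    """Duplicate / Sub / Abs / Mins / ReduceSum pipeline in scalar form."""
    mask = [min(abs(e - dst_expert_id), 1) for e in expert_ids[:cal_cnt]]
    return cal_cnt - sum(mask)  # zeros in the mask mark the matches

print(cal_token_send_expert_cnt(12, 37))  # 3, as in the log above
```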
Test program source code
to be uploaded.