经过一番摸索,终于将隐语的通信服务暴露出来,以后隐语的算子就可以流畅地使用我们自己的通信了。
一、联络结构
SPU <-----PSI<------yacl
二、每层进行修改
1.yacl层
向外暴露的类是 Context,所以我们在这里增加一个新的成员函数,用来传入通信回调函数。

cpp
// Forwards user-supplied transport callbacks to every peer channel.
//
// Both callbacks are type-erased std::function objects; each non-empty channel
// receives its own copy through SetChannel, where the lower-level
// communication lives. The slot whose index equals rank_ (presumably this
// party's own rank) and empty slots are skipped.
void AddChannel(std::function<int(const std::string &, std::string &)> sendCb, std::function<std::string(const std::string&)> recvCb) {
  for (size_t peer = 0; peer < channels_.size(); ++peer) {
    const bool is_self = (peer == rank_);
    if (is_self || !channels_[peer]) {
      continue;
    }
    channels_[peer]->SetChannel(sendCb, recvCb);
  }
}

cpp
// Installs custom transport callbacks on this channel.
//
// sendCb: called with (message key, payload); returns an int status.
// recvCb: called with a message key; returns the payload stored under it.
//
// The callbacks are taken by value and moved into the members, so a caller
// that passes rvalues pays for no extra std::function copy.
// NOTE(review): the members are written without any locking here; presumably
// this must be called before Send/Recv run on other threads -- confirm.
void SetChannel(std::function<int(const std::string &, std::string &)> sendCb, std::function<std::string(const std::string&)> recvCb){
  SPDLOG_INFO("Add SetChannel");
  sendCallback_ = std::move(sendCb);
  recvCallback_ = std::move(recvCb);
}
protected:
  // Outbound hook; empty until SetChannel is called.
  std::function<int(const std::string &, std::string &)> sendCallback_;
  // Inbound hook; empty until SetChannel is called.
  std::function<std::string(const std::string&)> recvCallback_;
然后我们在最底层的通信(Channel 的 Send/Recv)中接入我们的通信:
cpp
void Channel::Send(const std::string& msg_key, ByteContainerView value) {
if (aborting_.load()) {
YACL_THROW_LINK_ABORTED("Send is not allowed when channel is aborting");
}
if (YACL_UNLIKELY(disable_msg_seq_id_)) {
YACL_THROW("Send is not allowed when msg_seq_id is disabled");
}
YACL_ENFORCE(!waiting_finish_.load(),
"Send is not allowed when channel is closing");
NormalMessageKeyEnforce(msg_key);
// 增加我们的通信
if (sendCallback_) {
SPDLOG_DEBUG("GAIA send tag={}, size={} (sync)", msg_key, value.size());
std::string str((char*)value.data(), value.size());
sendCallback_(msg_key, str);
return;
}
size_t seq_id = msg_seq_id_.fetch_add(1) + 1;
auto key = BuildChannelKey(msg_key, seq_id);
send_msgs_.Push(Message(seq_id, std::move(key), value));
send_sync_.WaitSeqIdSendFinished(seq_id);
}
cpp
// Blocks until the message stored under `msg_key` is available and returns it.
//
// With a receive callback installed, non-control keys are fetched through the
// callback. Control keys (ack / fin / "connect_*") and the no-callback case
// use the built-in queue: wait on the condition variable until the entry
// appears, then acknowledge its sequence id.
//
// Throws a link-aborted error if the channel is (or becomes) aborting, and an
// IO error if a single wait exceeds the receive timeout.
Buffer Channel::Recv(const std::string& msg_key) {
  if (aborting_.load()) {
    YACL_THROW_LINK_ABORTED("Recv is not allowed when aborting channel");
  }
  NormalMessageKeyEnforce(msg_key);

  // Control messages are never intercepted by the callback.
  const bool is_control_key = msg_key == kAckKey || msg_key == kFinKey ||
                              msg_key.rfind("connect_", 0) == 0;
  if (recvCallback_ && !is_control_key) {
    const std::string payload = recvCallback_(msg_key);
    return Buffer(payload.data(), payload.size());
  }

  Buffer received;
  size_t seq_id = 0;
  {
    std::unique_lock<bthread::Mutex> lock(msg_mutex_);
    // Moves the entry for msg_key out of recv_msgs_ once it has arrived.
    auto try_take = [&]() -> bool {
      auto hit = recv_msgs_.find(msg_key);
      if (hit == recv_msgs_.end()) {
        return false;
      }
      std::tie(received, seq_id) = std::move(hit->second);
      recv_msgs_.erase(hit);
      return true;
    };
    while (!try_take()) {
      if (aborting_.load()) {
        YACL_THROW_LINK_ABORTED("Aborting channel, skip waiting");
      }
      // wait_for takes microseconds; the configured timeout is milliseconds.
      const auto wait_rc = msg_db_cond_.wait_for(
          lock, static_cast<int64_t>(recv_timeout_ms_) * 1000);
      if (wait_rc == ETIMEDOUT) {
        YACL_THROW_IO_ERROR("Get data timeout, key={}", msg_key);
      }
    }
  }
  SendAck(seq_id);
  return received;
}
需要注意通信的控制消息 kAckKey、kFinKey(以及以 connect_ 开头的消息):它们不能被回调拦截,必须继续走原有的通信通道。
2.PSI层
不用修改,可以做一下测试
3.SPU层的修改
这里我自己封装了一个ecdh接口:
cpp
auto lctx = yacl::link::FactoryBrpc().CreateContext(ctx_desc, role);
if (send_cb && recv_cb) {
lctx->AddChannel(std::move(send_cb), std::move(recv_cb));
}
比如使用内存传输
cpp
// In-process transport between two parties: each party owns one mailbox keyed
// by message tag. One mutex / condition variable pair guards both mailboxes.
static std::mutex mtx;
static std::condition_variable cv;
static std::unordered_map<std::string, std::string> mailbox0;  // read by party 0
static std::unordered_map<std::string, std::string> mailbox1;  // read by party 1

// Party 0 -> party 1: stores the payload (moved out of the argument) under
// `tag` and wakes every waiting receiver. Always returns 0.
auto send_cb0 = [](const std::string& tag, std::string& payload) -> int {
  std::lock_guard<std::mutex> guard(mtx);
  SPDLOG_INFO("send_cb0->1, tag: {}, payload(size): {}", tag, payload.size());
  auto& peer_inbox = mailbox1;
  peer_inbox[tag] = std::move(payload);
  cv.notify_all();
  return 0;
};

// Party 0 receive: blocks until `tag` shows up in mailbox0, then removes and
// returns the stored payload.
auto recv_cb0 = [](const std::string& tag) -> std::string {
  std::unique_lock<std::mutex> guard(mtx);
  auto slot = mailbox0.find(tag);
  while (slot == mailbox0.end()) {
    cv.wait(guard);
    slot = mailbox0.find(tag);
  }
  std::string message = std::move(slot->second);
  SPDLOG_INFO("recv_cb0, tag: {}, payload(size): {}", tag, message.size());
  mailbox0.erase(slot);
  return message;
};

// Party 1 -> party 0: mirror image of send_cb0.
auto send_cb1 = [](const std::string& tag, std::string& payload) -> int {
  std::lock_guard<std::mutex> guard(mtx);
  SPDLOG_INFO("send_cb1->0, tag: {}, payload(size): {}", tag, payload.size());
  auto& peer_inbox = mailbox0;
  peer_inbox[tag] = std::move(payload);
  cv.notify_all();
  return 0;
};

// Party 1 receive: mirror image of recv_cb0, reading from mailbox1.
auto recv_cb1 = [](const std::string& tag) -> std::string {
  std::unique_lock<std::mutex> guard(mtx);
  auto slot = mailbox1.find(tag);
  while (slot == mailbox1.end()) {
    cv.wait(guard);
    slot = mailbox1.find(tag);
  }
  std::string message = std::move(slot->second);
  SPDLOG_INFO("recv_cb1, tag: {}, payload(size): {}", tag, message.size());
  mailbox1.erase(slot);
  return message;
};
4.Python 层将通信回调传入 C++ 层
cpp
// Python binding: CreateChannel(role, send_cb, recv_cb) builds a link context
// whose send/recv are routed through the given Python callables.
m.def(
    "CreateChannel",
    [](size_t role, py::function sendCb, py::function recvCb)
        -> std::shared_ptr<yacl::link::Context> {
      // Keeps a heap copy of a Python callable alive. The custom deleter
      // takes the GIL before destroying it, so dropping the last reference
      // does not trip pybind11's GIL assertion.
      const auto hold_with_gil = [](const py::function& callable) {
        return std::shared_ptr<py::function>(
            new py::function(callable), [](py::function* held) {
              py::gil_scoped_acquire gil;
              delete held;
            });
      };
      auto send_fn = hold_with_gil(sendCb);
      auto recv_fn = hold_with_gil(recvCb);

      // C++ -> Python send adapter: acquires the GIL, passes the payload as
      // bytes, and converts the Python return value to int.
      std::function<int(const std::string&, std::string&)> send_wrapper =
          [send_fn](const std::string& tag, std::string& payload) -> int {
        py::gil_scoped_acquire gil;
        py::bytes payload_bytes(payload);
        py::object status = (*send_fn)(tag, payload_bytes);
        return status.cast<int>();
      };

      // C++ -> Python recv adapter: acquires the GIL and converts the Python
      // return value to std::string.
      std::function<std::string(const std::string&)> recv_wrapper =
          [recv_fn](const std::string& tag) -> std::string {
        py::gil_scoped_acquire gil;
        py::object data = (*recv_fn)(tag);
        return data.cast<std::string>();
      };

      // Context creation runs with the GIL released.
      py::gil_scoped_release unlocked;
      return psi::utils::Createlinks(role, std::move(send_wrapper),
                                     std::move(recv_wrapper));
    },
    py::arg("role"), py::arg("send_cb"), py::arg("recv_cb"),
    "Create Context and inject memory send/recv callbacks");
Python 层传参到 C++ 层:
python
def send_cb(tag: str, payload: bytes) -> int:
    """Store ``payload`` under ``tag`` in the other party's mailbox; returns 0."""
    # Role 0 writes into mailbox1; any other role writes into mailbox0.
    outbox = mailbox1 if role == 0 else mailbox0
    outbox[tag] = payload
    return 0
def recv_cb(tag: str) -> bytes:
    """Wait until ``tag`` appears in this party's mailbox, then remove and return it."""
    # Role 0 reads from mailbox0; any other role reads from mailbox1.
    inbox = mailbox0 if role == 0 else mailbox1
    # Poll at 1 ms intervals until the message has been delivered.
    while tag not in inbox:
        time.sleep(0.001)
    return inbox.pop(tag)
# Hand the Python callbacks to the C++ binding, which returns the link context.
ctx = CreateChannel(role=role, send_cb=send_cb, recv_cb=recv_cb)
可以参考一下
参考