魔改-隐语PSI通信,支持外部通信自定义

经过摸索,终于把隐语的通信层暴露了出来,以后隐语的算子就可以顺畅地使用我们自己的通信方式了。

一、联络结构

SPU <-----PSI<------yacl

二、每层进行修改

1.yacl层

对外暴露的类是 Context,所以我们在这里增加一个新的成员函数,用来传入通信回调函数。

cpp 复制代码
// 使用抽象函数传递
  // Hands the given send/receive callbacks to every peer channel.
  //
  // The slot at index rank_ and any null channel entries are skipped; every
  // remaining channel receives its own copy of both callbacks through
  // SetChannel, which is where the callbacks are actually stored.
  void AddChannel(std::function<int(const std::string &, std::string &)> sendCb, std::function<std::string(const std::string&)> recvCb) {
    const size_t channel_count = channels_.size();
    for (size_t idx = 0; idx != channel_count; ++idx) {
      // Skip our own slot first so the entry is only inspected for peers.
      if (idx == rank_ || !channels_[idx]) {
        continue;
      }
      channels_[idx]->SetChannel(sendCb, recvCb);
    }
  };
cpp 复制代码
  // Stores the custom send/receive callbacks on this channel.
  //
  // sendCb: invoked with the message key and a mutable payload string;
  //         returns an int status.
  // recvCb: invoked with the message key; returns the received payload.
  //
  // Both parameters are taken by value, so they are moved into the members
  // rather than copied a second time.
  void SetChannel(std::function<int(const std::string &, std::string &)> sendCb, std::function<std::string(const std::string&)> recvCb){
    SPDLOG_INFO("Add SetChannel");

    sendCallback_ = std::move(sendCb);
    recvCallback_ = std::move(recvCb);
  };
 protected:
  // Custom send hook; empty by default. Called as (message key, payload) and
  // returns an int status.
  std::function<int(const std::string &, std::string &)> sendCallback_;
  // Custom receive hook; empty by default. Called with the message key and
  // returns the payload as a string.
  std::function<std::string(const std::string&)> recvCallback_;

然后我们在最底层的通信(Channel 的 Send/Recv)中接入我们自己的通信:

cpp 复制代码
// Sends `value` to the peer under `msg_key`.
//
// When a custom send callback is installed, the payload is copied into a
// std::string and handed to that callback; the built-in sequenced send queue
// is bypassed and no sequence id is consumed. Otherwise the message is queued
// with a fresh sequence id and the call blocks until
// send_sync_.WaitSeqIdSendFinished returns.
//
// Errors are raised through the YACL_* macros when the channel is aborting,
// when message sequence ids are disabled, when the channel is closing, or
// when the custom callback returns a non-zero status.
void Channel::Send(const std::string& msg_key, ByteContainerView value) {
  if (aborting_.load()) {
    YACL_THROW_LINK_ABORTED("Send is not allowed when channel is aborting");
  }
  if (YACL_UNLIKELY(disable_msg_seq_id_)) {
    YACL_THROW("Send is not allowed when msg_seq_id is disabled");
  }

  YACL_ENFORCE(!waiting_finish_.load(),
               "Send is not allowed when channel is closing");
  NormalMessageKeyEnforce(msg_key);

  // Custom transport path.
  if (sendCallback_) {
    SPDLOG_DEBUG("GAIA send tag={}, size={} (sync)", msg_key, value.size());
    // The callback takes a mutable std::string&, so an owned copy of the
    // bytes is required. The cast keeps the source bytes const.
    std::string payload(reinterpret_cast<const char*>(value.data()),
                        value.size());
    // NOTE(review): assumes 0 means success, as the sample callbacks in this
    // file all return 0 - confirm for other transports. Ignoring a failure
    // here would leave the peer waiting for data that was never sent.
    const int status = sendCallback_(msg_key, payload);
    if (status != 0) {
      YACL_THROW("Custom send callback failed, key={}, status={}", msg_key,
                 status);
    }
    return;
  }

  // Built-in path: queue the message and wait for this sequence id.
  size_t seq_id = msg_seq_id_.fetch_add(1) + 1;
  auto key = BuildChannelKey(msg_key, seq_id);
  send_msgs_.Push(Message(seq_id, std::move(key), value));
  send_sync_.WaitSeqIdSendFinished(seq_id);
}
cpp 复制代码
// Receives the message stored under `msg_key`.
//
// If a custom receive callback is installed and the key is not a control key
// (ack, fin, or anything starting with "connect_"), the callback's result is
// copied into a Buffer and returned directly; no ack is sent on that path.
// Otherwise the call waits on the received-message table until the key shows
// up, sends an ack for its sequence id and returns the stored value.
//
// A link-aborted error is raised when the channel is aborting, and an IO
// error when a single wait hits the receive timeout.
Buffer Channel::Recv(const std::string& msg_key) {
  if (aborting_.load()) {
    YACL_THROW_LINK_ABORTED("Recv is not allowed when aborting channel");
  }
  NormalMessageKeyEnforce(msg_key);

  // Control messages (ack/fin/connect_*) are never routed to the callback.
  const bool use_callback = recvCallback_ && msg_key != kAckKey &&
                            msg_key != kFinKey &&
                            msg_key.rfind("connect_", 0) != 0;
  if (use_callback) {
    std::string payload = recvCallback_(msg_key);
    Buffer wrapped(payload.c_str(), payload.length());
    return wrapped;
  }

  Buffer received;
  size_t ack_seq_id = 0;
  {
    std::unique_lock<bthread::Mutex> guard(msg_mutex_);

    // Moves the entry for msg_key out of the table when it is present.
    auto take_if_ready = [&]() -> bool {
      auto entry = recv_msgs_.find(msg_key);
      if (entry == recv_msgs_.end()) {
        return false;
      }
      std::tie(received, ack_seq_id) = std::move(entry->second);
      recv_msgs_.erase(entry);
      return true;
    };

    for (;;) {
      if (take_if_ready()) {
        break;
      }
      if (aborting_.load()) {
        YACL_THROW_LINK_ABORTED("Aborting channel, skip waiting");
      }
      // wait_for takes its timeout in microseconds.
      const auto timeout_us = static_cast<int64_t>(recv_timeout_ms_) * 1000;
      if (msg_db_cond_.wait_for(guard, timeout_us) == ETIMEDOUT) {
        YACL_THROW_IO_ERROR("Get data timeout, key={}", msg_key);
      }
    }
  }
  SendAck(ack_seq_id);
  return received;
}

需要注意通信的控制消息:kAckKey、kFinKey(以及以 connect_ 开头的消息)不要交给自定义的接收回调处理。

2.PSI层

不用修改,可以做一下测试

3.SPU层的修改

这里我自己封装了一个ecdh接口:

cpp 复制代码
  // Create the link context through the brpc factory from ctx_desc and role.
  auto lctx = yacl::link::FactoryBrpc().CreateContext(ctx_desc, role);
  // Custom callbacks are installed only when both are non-empty; if either
  // one is missing the context is left as created.
  if (send_cb && recv_cb) {
    lctx->AddChannel(std::move(send_cb), std::move(recv_cb));
  }

比如使用内存传输

cpp 复制代码
  // Shared state for the in-memory transport. A single mutex and condition
  // variable guard both mailboxes.
  static std::mutex mtx;
  static std::condition_variable cv;
  // tag -> payload. Filled by send_cb1, drained by recv_cb0.
  static std::unordered_map<std::string, std::string> mailbox0;
  // tag -> payload. Filled by send_cb0, drained by recv_cb1.
  static std::unordered_map<std::string, std::string> mailbox1;

  // Send callback for side 0: moves the payload into mailbox1 under `tag`
  // (replacing any entry already stored there), wakes all waiters and
  // returns 0.
  auto send_cb0 = [](const std::string& tag, std::string& payload) -> int {
    const std::lock_guard<std::mutex> guard(mtx);
    // Read the size before the payload is moved out.
    const auto payload_size = payload.size();
    SPDLOG_INFO("send_cb0->1, tag: {}, payload(size): {}", tag, payload_size);

    auto& slot = mailbox1[tag];
    slot = std::move(payload);
    cv.notify_all();
    return 0;
  };

  // Receive callback for side 0: blocks until mailbox0 holds an entry for
  // `tag`, then removes that entry and returns its payload.
  auto recv_cb0 = [](const std::string& tag) -> std::string {
    std::unique_lock<std::mutex> guard(mtx);
    auto entry = mailbox0.find(tag);
    // Re-check after every wake-up; other tags may have been delivered.
    while (entry == mailbox0.end()) {
      cv.wait(guard);
      entry = mailbox0.find(tag);
    }
    std::string received = std::move(entry->second);
    SPDLOG_INFO("recv_cb0, tag: {}, payload(size): {}", tag, received.size());
    mailbox0.erase(entry);
    return received;
  };

  // Send callback for side 1: moves the payload into mailbox0 under `tag`
  // (replacing any entry already stored there), wakes all waiters and
  // returns 0.
  auto send_cb1 = [](const std::string& tag, std::string& payload) -> int {
    const std::lock_guard<std::mutex> guard(mtx);
    // Read the size before the payload is moved out.
    const auto payload_size = payload.size();
    SPDLOG_INFO("send_cb1->0, tag: {}, payload(size): {}", tag, payload_size);

    auto& slot = mailbox0[tag];
    slot = std::move(payload);
    cv.notify_all();
    return 0;
  };

  // Receive callback for side 1: blocks until mailbox1 holds an entry for
  // `tag`, then removes that entry and returns its payload.
  auto recv_cb1 = [](const std::string& tag) -> std::string {
    std::unique_lock<std::mutex> guard(mtx);
    auto entry = mailbox1.find(tag);
    // Re-check after every wake-up; other tags may have been delivered.
    while (entry == mailbox1.end()) {
      cv.wait(guard);
      entry = mailbox1.find(tag);
    }
    std::string received = std::move(entry->second);
    SPDLOG_INFO("recv_cb1, tag: {}, payload(size): {}", tag, received.size());
    mailbox1.erase(entry);
    return received;
  };
4.Python 层将通信回调传入 C++ 层
cpp 复制代码
  // Python entry point "CreateChannel": wraps the two Python callables into
  // C++ std::function callbacks and passes them to psi::utils::Createlinks.
  m.def("CreateChannel",
        [](size_t role,
           py::function sendCb, py::function recvCb)
            -> std::shared_ptr<yacl::link::Context> {
          // Keeps a heap copy of a Python callable alive. The deleter takes
          // the GIL before dropping the reference, because the last owner
          // may be released on a thread that does not hold it.
          const auto hold_with_gil = [](const py::function& fn) {
            return std::shared_ptr<py::function>(
                new py::function(fn), [](py::function* held) {
                  py::gil_scoped_acquire gil;
                  delete held;
                });
          };
          auto send_fn = hold_with_gil(sendCb);
          auto recv_fn = hold_with_gil(recvCb);

          using SendFn = std::function<int(const std::string&, std::string&)>;
          using RecvFn = std::function<std::string(const std::string&)>;

          // Each wrapper takes the GIL for the duration of the Python call.
          SendFn send_wrapper = [send_fn](const std::string& tag,
                                          std::string& payload) -> int {
            py::gil_scoped_acquire gil;
            py::bytes payload_bytes(payload);
            py::object status = (*send_fn)(tag, payload_bytes);
            return status.cast<int>();
          };
          RecvFn recv_wrapper =
              [recv_fn](const std::string& tag) -> std::string {
            py::gil_scoped_acquire gil;
            py::object received = (*recv_fn)(tag);
            return received.cast<std::string>();
          };

          // Release the GIL while the context is created.
          py::gil_scoped_release unlocked;
          auto link_ctx = psi::utils::Createlinks(
              role, std::move(send_wrapper), std::move(recv_wrapper));
          return link_ctx;
        },
        py::arg("role"),
        py::arg("send_cb"),
        py::arg("recv_cb"),
        "Create Context and inject memory send/recv callbacks");

Python 层向 C++ 层传入回调

python 复制代码
    def send_cb(tag: str, payload: bytes) -> int:
        """Store ``payload`` under ``tag`` in the other side's mailbox.

        Role 0 writes into ``mailbox1``; any other role writes into
        ``mailbox0``. Always returns 0.
        """
        target = mailbox1 if role == 0 else mailbox0
        target[tag] = payload
        return 0

    def recv_cb(tag: str) -> bytes:
        """Wait for ``tag`` in this side's mailbox, then remove and return it.

        Role 0 reads from ``mailbox0``; any other role reads from
        ``mailbox1``. The mailbox is polled every millisecond until the tag
        appears.
        """
        inbox = mailbox0 if role == 0 else mailbox1
        while tag not in inbox:
            time.sleep(0.001)
        return inbox.pop(tag)

    # Create the link context for this role, passing in the two callbacks
    # defined above.
    ctx = CreateChannel(role, send_cb, recv_cb)

可以参考一下

Github-SPL

参考

https://github.com/secretflow/psi

https://github.com/secretflow/yacl

相关推荐
Pedantic1 小时前
SwiftUI 手势层级(Gesture Hierarchy)详解
前端
飘尘2 小时前
前端转型全栈(Java后端)的快速上手指引
前端·后端·全栈
一颗烂土豆2 小时前
Meshopt 压缩深度解析,为什么它比 Draco 更快
前端·javascript·webgl
浏览器工程师3 小时前
AI Agent 接浏览器任务,先别让它一路点到底
前端·后端
雨季mo浅忆3 小时前
VSCode自动格式化三要素
前端
爱勇宝4 小时前
深扒 Anthropic 1680 位工程师简历:应届生几乎没机会,AI 公司最缺的不是博士
前端·后端·程序员
kyriewen4 小时前
同事每天催我 Code Review,我写了个脚本让 AI 替我 review PR——现在他反过来催 AI 了
前端·javascript·ai编程
user20585561518136 小时前
Windows 项目安装时报 `node-sass` 错误,如何快速处理
前端
LiaCode6 小时前
Redis 在生产项目的使用
前端·后端