魔改-隐语PSI通信,支持外部通信自定义

经过摸索,终于将隐语通信服务暴露出来,以后就可以将隐语的算子使用我们的通信流畅使用喽。

一.联络结构

SPU <-----PSI<------yacl

二、每层进行修改

1.yacl层

向外暴露的类参数为Context,所以我们再这里增加一个新的内置函数用来传递通信函数。

cpp 复制代码
// 使用抽象函数传递
  void AddChannel(std::function<int(const std::string &, std::string &)> sendCb, std::function<std::string(const std::string&)> recvCb) {
    for (size_t i = 0; i < channels_.size(); ++i) {
      if (i == rank_) {
        continue;
      }
      auto& channel = channels_[i];
      if (!channel) {
        continue;
      }
      channel->SetChannel(sendCb, recvCb); // 更为底层的通信在channel中
    }
  };
cpp 复制代码
  void SetChannel(std::function<int(const std::string &, std::string &)> sendCb, std::function<std::string(const std::string&)> recvCb){
    SPDLOG_INFO("Add SetChannel");
    
    sendCallback_ = sendCb;
    recvCallback_ = recvCb;
  };
 protected:
  std::function<int(const std::string &, std::string &)> sendCallback_;
  std::function<std::string(const std::string&)> recvCallback_;

然后我们将最底层的通信中增加我们的通信:

cpp 复制代码
void Channel::Send(const std::string& msg_key, ByteContainerView value) {
  if (aborting_.load()) {
    YACL_THROW_LINK_ABORTED("Send is not allowed when channel is aborting");
  }
  if (YACL_UNLIKELY(disable_msg_seq_id_)) {
    YACL_THROW("Send is not allowed when msg_seq_id is disabled");
  }

  YACL_ENFORCE(!waiting_finish_.load(),
               "Send is not allowed when channel is closing");
  NormalMessageKeyEnforce(msg_key);

 // 增加我们的通信
  if (sendCallback_) {
        SPDLOG_DEBUG("GAIA send tag={}, size={} (sync)", msg_key, value.size());
        std::string str((char*)value.data(), value.size());
        sendCallback_(msg_key, str);
        return;
  } 


  size_t seq_id = msg_seq_id_.fetch_add(1) + 1;
  auto key = BuildChannelKey(msg_key, seq_id);
  send_msgs_.Push(Message(seq_id, std::move(key), value));
  send_sync_.WaitSeqIdSendFinished(seq_id);
}
cpp 复制代码
Buffer Channel::Recv(const std::string& msg_key) {
  if (aborting_.load()) {
    YACL_THROW_LINK_ABORTED("Recv is not allowed when aborting channel");
  }
  NormalMessageKeyEnforce(msg_key);
  // do not intercept control messages: ack/fin/connect_*
  if (recvCallback_ && msg_key != kAckKey && msg_key != kFinKey && msg_key.rfind("connect_", 0) != 0) {  
      std::string str = recvCallback_(msg_key);
      Buffer future(str.c_str(), str.length());
      return future;
  }
  Buffer value;
  size_t seq_id = 0;
  {
    std::unique_lock<bthread::Mutex> lock(msg_mutex_);
    auto stop_waiting = [&] {
      auto itr = this->recv_msgs_.find(msg_key);
      if (itr == this->recv_msgs_.end()) {
        return false;
      } else {
        std::tie(value, seq_id) = std::move(itr->second);
        this->recv_msgs_.erase(itr);
        return true;
      }
    };
    while (!stop_waiting()) {
      if (aborting_.load()) {
        YACL_THROW_LINK_ABORTED("Aborting channel, skip waiting");
      }
      //                                timeout_us
      if (msg_db_cond_.wait_for(lock, static_cast<int64_t>(recv_timeout_ms_) *
                                          1000) == ETIMEDOUT) {
        YACL_THROW_IO_ERROR("Get data timeout, key={}", msg_key);
      }
    }
  }
  SendAck(seq_id);
  return value;
}

需要注意通信的截止符号,kAckKey、kFinKey

2.PSI层

不用修改,可以做一下测试

3.SPU层的修改

这里我自己封装了一个ecdh接口:

cpp 复制代码
  auto lctx = yacl::link::FactoryBrpc().CreateContext(ctx_desc, role);
  if (send_cb && recv_cb) {
    lctx->AddChannel(std::move(send_cb), std::move(recv_cb));
  }

比如使用内存传输

cpp 复制代码
  static std::mutex mtx;
  static std::condition_variable cv;
  static std::unordered_map<std::string, std::string> mailbox0;
  static std::unordered_map<std::string, std::string> mailbox1;

  auto send_cb0 = [](const std::string& tag, std::string& payload) -> int {
    std::lock_guard<std::mutex> lock(mtx);
    SPDLOG_INFO("send_cb0->1, tag: {}, payload(size): {}", tag, payload.size());

    mailbox1[tag] = std::move(payload);
    cv.notify_all();
    return 0;
  };

  auto recv_cb0 = [](const std::string& tag) -> std::string {
    std::unique_lock<std::mutex> lock(mtx);
    cv.wait(lock, [&]{ return mailbox0.find(tag) != mailbox0.end(); });
    auto it = mailbox0.find(tag);
    std::string out = std::move(it->second);
    SPDLOG_INFO("recv_cb0, tag: {}, payload(size): {}", tag, out.size());
    mailbox0.erase(it);
    return out;
  };

  auto send_cb1 = [](const std::string& tag, std::string& payload) -> int {
    std::lock_guard<std::mutex> lock(mtx);
    SPDLOG_INFO("send_cb1->0, tag: {}, payload(size): {}", tag, payload.size());

    mailbox0[tag] = std::move(payload);
    cv.notify_all();
    return 0;
  };

  auto recv_cb1 = [](const std::string& tag) -> std::string {
    std::unique_lock<std::mutex> lock(mtx);
    cv.wait(lock, [&]{ return mailbox1.find(tag) != mailbox1.end(); });
    auto it = mailbox1.find(tag);
    std::string out = std::move(it->second);
    SPDLOG_INFO("recv_cb1, tag: {}, payload(size): {}", tag, out.size());
    mailbox1.erase(it);
    return out;
  };
4.python层将通信传入c层
cpp 复制代码
  m.def("CreateChannel", 
        [](size_t role,
           py::function sendCb, py::function recvCb)
            -> std::shared_ptr<yacl::link::Context> {
          // 持有 Python 回调的指针,并在析构时获取 GIL 以避免 GIL 断言
          auto send_fn = std::shared_ptr<py::function>(
              new py::function(sendCb),
              [](py::function* f) {
                py::gil_scoped_acquire acquire;
                delete f;
              });
          auto recv_fn = std::shared_ptr<py::function>(
              new py::function(recvCb),
              [](py::function* f) {
                py::gil_scoped_acquire acquire;
                delete f;
              });

          // 将 Python 回调包装为 C++ 回调,调用时获取 GIL
          std::function<int(const std::string&, std::string&)> send_wrapper =
              [send_fn](const std::string& tag, std::string& payload) -> int {
                py::gil_scoped_acquire acquire;
                py::bytes py_payload(payload);
                py::object ret = (*send_fn)(tag, py_payload);
                return ret.cast<int>();
              };
          std::function<std::string(const std::string&)> recv_wrapper =
              [recv_fn](const std::string& tag) -> std::string {
                py::gil_scoped_acquire acquire;
                py::object ret = (*recv_fn)(tag);
                return ret.cast<std::string>();
              };

          py::gil_scoped_release release;
          auto ctx = psi::utils::Createlinks(role, std::move(send_wrapper), std::move(recv_wrapper));
          return ctx;
        },
        py::arg("role"),
        py::arg("send_cb"),
        py::arg("recv_cb"),
        "Create Context and inject memory send/recv callbacks");

python层传参到c

python 复制代码
    def send_cb(tag: str, payload: bytes) -> int:
        if role == 0:
            mailbox1[tag] = payload
        else:
            mailbox0[tag] = payload
        return 0

    def recv_cb(tag: str) -> bytes:
        if role == 0:
            while tag not in mailbox0:
                time.sleep(0.001)
            data = mailbox0[tag]
            del mailbox0[tag]
            return data
        else:
            while tag not in mailbox1:
                time.sleep(0.001)
            data = mailbox1[tag]
            del mailbox1[tag]
            return data

    ctx = CreateChannel(role, send_cb, recv_cb)

可以参考一下

Github-SPL

参考

https://github.com/secretflow/psi

https://github.com/secretflow/yacl

相关推荐
梦想很大很大3 小时前
使用 Go + Gin + Fx 构建工程化后端服务模板(gin-app 实践)
前端·后端·go
We་ct3 小时前
LeetCode 56. 合并区间:区间重叠问题的核心解法与代码解析
前端·算法·leetcode·typescript
Lionel6893 小时前
分步实现 Flutter 鸿蒙轮播图核心功能(搜索框 + 指示灯)
算法·图搜索算法
张3蜂3 小时前
深入理解 Python 的 frozenset:为什么要有“不可变集合”?
前端·python·spring
无小道3 小时前
Qt——事件简单介绍
开发语言·前端·qt
小妖6663 小时前
js 实现快速排序算法
数据结构·算法·排序算法
广州华水科技3 小时前
GNSS与单北斗变形监测技术的应用现状分析与未来发展方向
前端
xsyaaaan3 小时前
代码随想录Day30动态规划:背包问题二维_背包问题一维_416分割等和子集
算法·动态规划
code_YuJun4 小时前
corepack 作用
前端
千寻girling4 小时前
Koa.js 教程 | 一份不可多得的 Node.js 的 Web 框架 Koa.js 教程
前端·后端·面试