魔改-隐语PSI通信,支持外部通信自定义

经过摸索,终于将隐语通信服务暴露出来,以后就可以将隐语的算子使用我们的通信流畅使用喽。

一.联络结构

SPU <-----PSI<------yacl

二、每层进行修改

1.yacl层

向外暴露的类参数为Context,所以我们再这里增加一个新的内置函数用来传递通信函数。

cpp 复制代码
// 使用抽象函数传递
  void AddChannel(std::function<int(const std::string &, std::string &)> sendCb, std::function<std::string(const std::string&)> recvCb) {
    for (size_t i = 0; i < channels_.size(); ++i) {
      if (i == rank_) {
        continue;
      }
      auto& channel = channels_[i];
      if (!channel) {
        continue;
      }
      channel->SetChannel(sendCb, recvCb); // 更为底层的通信在channel中
    }
  };
cpp 复制代码
  void SetChannel(std::function<int(const std::string &, std::string &)> sendCb, std::function<std::string(const std::string&)> recvCb){
    SPDLOG_INFO("Add SetChannel");
    
    sendCallback_ = sendCb;
    recvCallback_ = recvCb;
  };
 protected:
  std::function<int(const std::string &, std::string &)> sendCallback_;
  std::function<std::string(const std::string&)> recvCallback_;

然后我们将最底层的通信中增加我们的通信:

cpp 复制代码
void Channel::Send(const std::string& msg_key, ByteContainerView value) {
  if (aborting_.load()) {
    YACL_THROW_LINK_ABORTED("Send is not allowed when channel is aborting");
  }
  if (YACL_UNLIKELY(disable_msg_seq_id_)) {
    YACL_THROW("Send is not allowed when msg_seq_id is disabled");
  }

  YACL_ENFORCE(!waiting_finish_.load(),
               "Send is not allowed when channel is closing");
  NormalMessageKeyEnforce(msg_key);

 // 增加我们的通信
  if (sendCallback_) {
        SPDLOG_DEBUG("GAIA send tag={}, size={} (sync)", msg_key, value.size());
        std::string str((char*)value.data(), value.size());
        sendCallback_(msg_key, str);
        return;
  } 


  size_t seq_id = msg_seq_id_.fetch_add(1) + 1;
  auto key = BuildChannelKey(msg_key, seq_id);
  send_msgs_.Push(Message(seq_id, std::move(key), value));
  send_sync_.WaitSeqIdSendFinished(seq_id);
}
cpp 复制代码
Buffer Channel::Recv(const std::string& msg_key) {
  if (aborting_.load()) {
    YACL_THROW_LINK_ABORTED("Recv is not allowed when aborting channel");
  }
  NormalMessageKeyEnforce(msg_key);
  // do not intercept control messages: ack/fin/connect_*
  if (recvCallback_ && msg_key != kAckKey && msg_key != kFinKey && msg_key.rfind("connect_", 0) != 0) {  
      std::string str = recvCallback_(msg_key);
      Buffer future(str.c_str(), str.length());
      return future;
  }
  Buffer value;
  size_t seq_id = 0;
  {
    std::unique_lock<bthread::Mutex> lock(msg_mutex_);
    auto stop_waiting = [&] {
      auto itr = this->recv_msgs_.find(msg_key);
      if (itr == this->recv_msgs_.end()) {
        return false;
      } else {
        std::tie(value, seq_id) = std::move(itr->second);
        this->recv_msgs_.erase(itr);
        return true;
      }
    };
    while (!stop_waiting()) {
      if (aborting_.load()) {
        YACL_THROW_LINK_ABORTED("Aborting channel, skip waiting");
      }
      //                                timeout_us
      if (msg_db_cond_.wait_for(lock, static_cast<int64_t>(recv_timeout_ms_) *
                                          1000) == ETIMEDOUT) {
        YACL_THROW_IO_ERROR("Get data timeout, key={}", msg_key);
      }
    }
  }
  SendAck(seq_id);
  return value;
}

需要注意通信的截止符号,kAckKey、kFinKey

2.PSI层

不用修改,可以做一下测试

3.SPU层的修改

这里我自己封装了一个ecdh接口:

cpp 复制代码
  auto lctx = yacl::link::FactoryBrpc().CreateContext(ctx_desc, role);
  if (send_cb && recv_cb) {
    lctx->AddChannel(std::move(send_cb), std::move(recv_cb));
  }

比如使用内存传输

cpp 复制代码
  static std::mutex mtx;
  static std::condition_variable cv;
  static std::unordered_map<std::string, std::string> mailbox0;
  static std::unordered_map<std::string, std::string> mailbox1;

  auto send_cb0 = [](const std::string& tag, std::string& payload) -> int {
    std::lock_guard<std::mutex> lock(mtx);
    SPDLOG_INFO("send_cb0->1, tag: {}, payload(size): {}", tag, payload.size());

    mailbox1[tag] = std::move(payload);
    cv.notify_all();
    return 0;
  };

  auto recv_cb0 = [](const std::string& tag) -> std::string {
    std::unique_lock<std::mutex> lock(mtx);
    cv.wait(lock, [&]{ return mailbox0.find(tag) != mailbox0.end(); });
    auto it = mailbox0.find(tag);
    std::string out = std::move(it->second);
    SPDLOG_INFO("recv_cb0, tag: {}, payload(size): {}", tag, out.size());
    mailbox0.erase(it);
    return out;
  };

  auto send_cb1 = [](const std::string& tag, std::string& payload) -> int {
    std::lock_guard<std::mutex> lock(mtx);
    SPDLOG_INFO("send_cb1->0, tag: {}, payload(size): {}", tag, payload.size());

    mailbox0[tag] = std::move(payload);
    cv.notify_all();
    return 0;
  };

  auto recv_cb1 = [](const std::string& tag) -> std::string {
    std::unique_lock<std::mutex> lock(mtx);
    cv.wait(lock, [&]{ return mailbox1.find(tag) != mailbox1.end(); });
    auto it = mailbox1.find(tag);
    std::string out = std::move(it->second);
    SPDLOG_INFO("recv_cb1, tag: {}, payload(size): {}", tag, out.size());
    mailbox1.erase(it);
    return out;
  };
4.python层将通信传入c层
cpp 复制代码
  m.def("CreateChannel", 
        [](size_t role,
           py::function sendCb, py::function recvCb)
            -> std::shared_ptr<yacl::link::Context> {
          // 持有 Python 回调的指针,并在析构时获取 GIL 以避免 GIL 断言
          auto send_fn = std::shared_ptr<py::function>(
              new py::function(sendCb),
              [](py::function* f) {
                py::gil_scoped_acquire acquire;
                delete f;
              });
          auto recv_fn = std::shared_ptr<py::function>(
              new py::function(recvCb),
              [](py::function* f) {
                py::gil_scoped_acquire acquire;
                delete f;
              });

          // 将 Python 回调包装为 C++ 回调,调用时获取 GIL
          std::function<int(const std::string&, std::string&)> send_wrapper =
              [send_fn](const std::string& tag, std::string& payload) -> int {
                py::gil_scoped_acquire acquire;
                py::bytes py_payload(payload);
                py::object ret = (*send_fn)(tag, py_payload);
                return ret.cast<int>();
              };
          std::function<std::string(const std::string&)> recv_wrapper =
              [recv_fn](const std::string& tag) -> std::string {
                py::gil_scoped_acquire acquire;
                py::object ret = (*recv_fn)(tag);
                return ret.cast<std::string>();
              };

          py::gil_scoped_release release;
          auto ctx = psi::utils::Createlinks(role, std::move(send_wrapper), std::move(recv_wrapper));
          return ctx;
        },
        py::arg("role"),
        py::arg("send_cb"),
        py::arg("recv_cb"),
        "Create Context and inject memory send/recv callbacks");

python层传参到c

python 复制代码
    def send_cb(tag: str, payload: bytes) -> int:
        if role == 0:
            mailbox1[tag] = payload
        else:
            mailbox0[tag] = payload
        return 0

    def recv_cb(tag: str) -> bytes:
        if role == 0:
            while tag not in mailbox0:
                time.sleep(0.001)
            data = mailbox0[tag]
            del mailbox0[tag]
            return data
        else:
            while tag not in mailbox1:
                time.sleep(0.001)
            data = mailbox1[tag]
            del mailbox1[tag]
            return data

    ctx = CreateChannel(role, send_cb, recv_cb)

可以参考一下

Github-SPL

参考

https://github.com/secretflow/psi

https://github.com/secretflow/yacl

相关推荐
杜子不疼.2 小时前
【Linux】基础IO(二):系统文件IO
linux·运维·服务器
郝学胜-神的一滴2 小时前
深入理解网络IP协议与TTL机制:从原理到实践
linux·服务器·开发语言·网络·网络协议·tcp/ip·程序人生
松涛和鸣2 小时前
DAY61 IMX6ULL UART Serial Communication Practice
linux·服务器·网络·arm开发·数据库·驱动开发
一个不知名程序员www8 小时前
算法学习入门 --- 哈希表和unordered_map、unordered_set(C++)
c++·算法
二哈喇子!8 小时前
BOM模型
开发语言·前端·javascript·bom
二哈喇子!8 小时前
Vue2 监听器 watcher
前端·javascript·vue.js
杨靳言先8 小时前
✨【运维实战】内网服务器无法联网?巧用 SSH 隧道实现反向代理访问公网资源 (Docker/PortForwarding)
服务器·docker·ssh
yanyu-yaya8 小时前
前端面试题
前端·面试·前端框架
Sarvartha8 小时前
C++ STL 栈的便捷使用
c++·算法
二哈喇子!9 小时前
使用NVM下载Node.js管理多版本
前端·npm·node.js