魔改-隐语PSI通信,支持外部通信自定义

经过摸索,终于将隐语通信服务暴露出来,以后就可以将隐语的算子使用我们的通信流畅使用喽。

一.联络结构

SPU <-----PSI<------yacl

二、每层进行修改

1.yacl层

向外暴露的类参数为Context,所以我们再这里增加一个新的内置函数用来传递通信函数。

cpp 复制代码
// 使用抽象函数传递
  void AddChannel(std::function<int(const std::string &, std::string &)> sendCb, std::function<std::string(const std::string&)> recvCb) {
    for (size_t i = 0; i < channels_.size(); ++i) {
      if (i == rank_) {
        continue;
      }
      auto& channel = channels_[i];
      if (!channel) {
        continue;
      }
      channel->SetChannel(sendCb, recvCb); // 更为底层的通信在channel中
    }
  };
cpp 复制代码
  void SetChannel(std::function<int(const std::string &, std::string &)> sendCb, std::function<std::string(const std::string&)> recvCb){
    SPDLOG_INFO("Add SetChannel");
    
    sendCallback_ = sendCb;
    recvCallback_ = recvCb;
  };
 protected:
  std::function<int(const std::string &, std::string &)> sendCallback_;
  std::function<std::string(const std::string&)> recvCallback_;

然后我们将最底层的通信中增加我们的通信:

cpp 复制代码
void Channel::Send(const std::string& msg_key, ByteContainerView value) {
  if (aborting_.load()) {
    YACL_THROW_LINK_ABORTED("Send is not allowed when channel is aborting");
  }
  if (YACL_UNLIKELY(disable_msg_seq_id_)) {
    YACL_THROW("Send is not allowed when msg_seq_id is disabled");
  }

  YACL_ENFORCE(!waiting_finish_.load(),
               "Send is not allowed when channel is closing");
  NormalMessageKeyEnforce(msg_key);

 // 增加我们的通信
  if (sendCallback_) {
        SPDLOG_DEBUG("GAIA send tag={}, size={} (sync)", msg_key, value.size());
        std::string str((char*)value.data(), value.size());
        sendCallback_(msg_key, str);
        return;
  } 


  size_t seq_id = msg_seq_id_.fetch_add(1) + 1;
  auto key = BuildChannelKey(msg_key, seq_id);
  send_msgs_.Push(Message(seq_id, std::move(key), value));
  send_sync_.WaitSeqIdSendFinished(seq_id);
}
cpp 复制代码
Buffer Channel::Recv(const std::string& msg_key) {
  if (aborting_.load()) {
    YACL_THROW_LINK_ABORTED("Recv is not allowed when aborting channel");
  }
  NormalMessageKeyEnforce(msg_key);
  // do not intercept control messages: ack/fin/connect_*
  if (recvCallback_ && msg_key != kAckKey && msg_key != kFinKey && msg_key.rfind("connect_", 0) != 0) {  
      std::string str = recvCallback_(msg_key);
      Buffer future(str.c_str(), str.length());
      return future;
  }
  Buffer value;
  size_t seq_id = 0;
  {
    std::unique_lock<bthread::Mutex> lock(msg_mutex_);
    auto stop_waiting = [&] {
      auto itr = this->recv_msgs_.find(msg_key);
      if (itr == this->recv_msgs_.end()) {
        return false;
      } else {
        std::tie(value, seq_id) = std::move(itr->second);
        this->recv_msgs_.erase(itr);
        return true;
      }
    };
    while (!stop_waiting()) {
      if (aborting_.load()) {
        YACL_THROW_LINK_ABORTED("Aborting channel, skip waiting");
      }
      //                                timeout_us
      if (msg_db_cond_.wait_for(lock, static_cast<int64_t>(recv_timeout_ms_) *
                                          1000) == ETIMEDOUT) {
        YACL_THROW_IO_ERROR("Get data timeout, key={}", msg_key);
      }
    }
  }
  SendAck(seq_id);
  return value;
}

需要注意通信的截止符号,kAckKey、kFinKey

2.PSI层

不用修改,可以做一下测试

3.SPU层的修改

这里我自己封装了一个ecdh接口:

cpp 复制代码
  auto lctx = yacl::link::FactoryBrpc().CreateContext(ctx_desc, role);
  if (send_cb && recv_cb) {
    lctx->AddChannel(std::move(send_cb), std::move(recv_cb));
  }

比如使用内存传输

cpp 复制代码
  static std::mutex mtx;
  static std::condition_variable cv;
  static std::unordered_map<std::string, std::string> mailbox0;
  static std::unordered_map<std::string, std::string> mailbox1;

  auto send_cb0 = [](const std::string& tag, std::string& payload) -> int {
    std::lock_guard<std::mutex> lock(mtx);
    SPDLOG_INFO("send_cb0->1, tag: {}, payload(size): {}", tag, payload.size());

    mailbox1[tag] = std::move(payload);
    cv.notify_all();
    return 0;
  };

  auto recv_cb0 = [](const std::string& tag) -> std::string {
    std::unique_lock<std::mutex> lock(mtx);
    cv.wait(lock, [&]{ return mailbox0.find(tag) != mailbox0.end(); });
    auto it = mailbox0.find(tag);
    std::string out = std::move(it->second);
    SPDLOG_INFO("recv_cb0, tag: {}, payload(size): {}", tag, out.size());
    mailbox0.erase(it);
    return out;
  };

  auto send_cb1 = [](const std::string& tag, std::string& payload) -> int {
    std::lock_guard<std::mutex> lock(mtx);
    SPDLOG_INFO("send_cb1->0, tag: {}, payload(size): {}", tag, payload.size());

    mailbox0[tag] = std::move(payload);
    cv.notify_all();
    return 0;
  };

  auto recv_cb1 = [](const std::string& tag) -> std::string {
    std::unique_lock<std::mutex> lock(mtx);
    cv.wait(lock, [&]{ return mailbox1.find(tag) != mailbox1.end(); });
    auto it = mailbox1.find(tag);
    std::string out = std::move(it->second);
    SPDLOG_INFO("recv_cb1, tag: {}, payload(size): {}", tag, out.size());
    mailbox1.erase(it);
    return out;
  };
4.python层将通信传入c层
cpp 复制代码
  m.def("CreateChannel", 
        [](size_t role,
           py::function sendCb, py::function recvCb)
            -> std::shared_ptr<yacl::link::Context> {
          // 持有 Python 回调的指针,并在析构时获取 GIL 以避免 GIL 断言
          auto send_fn = std::shared_ptr<py::function>(
              new py::function(sendCb),
              [](py::function* f) {
                py::gil_scoped_acquire acquire;
                delete f;
              });
          auto recv_fn = std::shared_ptr<py::function>(
              new py::function(recvCb),
              [](py::function* f) {
                py::gil_scoped_acquire acquire;
                delete f;
              });

          // 将 Python 回调包装为 C++ 回调,调用时获取 GIL
          std::function<int(const std::string&, std::string&)> send_wrapper =
              [send_fn](const std::string& tag, std::string& payload) -> int {
                py::gil_scoped_acquire acquire;
                py::bytes py_payload(payload);
                py::object ret = (*send_fn)(tag, py_payload);
                return ret.cast<int>();
              };
          std::function<std::string(const std::string&)> recv_wrapper =
              [recv_fn](const std::string& tag) -> std::string {
                py::gil_scoped_acquire acquire;
                py::object ret = (*recv_fn)(tag);
                return ret.cast<std::string>();
              };

          py::gil_scoped_release release;
          auto ctx = psi::utils::Createlinks(role, std::move(send_wrapper), std::move(recv_wrapper));
          return ctx;
        },
        py::arg("role"),
        py::arg("send_cb"),
        py::arg("recv_cb"),
        "Create Context and inject memory send/recv callbacks");

python层传参到c

python 复制代码
    def send_cb(tag: str, payload: bytes) -> int:
        if role == 0:
            mailbox1[tag] = payload
        else:
            mailbox0[tag] = payload
        return 0

    def recv_cb(tag: str) -> bytes:
        if role == 0:
            while tag not in mailbox0:
                time.sleep(0.001)
            data = mailbox0[tag]
            del mailbox0[tag]
            return data
        else:
            while tag not in mailbox1:
                time.sleep(0.001)
            data = mailbox1[tag]
            del mailbox1[tag]
            return data

    ctx = CreateChannel(role, send_cb, recv_cb)

可以参考一下

Github-SPL

参考

https://github.com/secretflow/psi

https://github.com/secretflow/yacl

相关推荐
by__csdn2 小时前
JavaScript性能优化实战:异步与延迟加载全方位攻略
开发语言·前端·javascript·vue.js·react.js·typescript·ecmascript
良木生香2 小时前
【数据结构-初阶】详解线性表(2)---单链表
c语言·数据结构·算法
菜鸟233号2 小时前
力扣106 从中序与后序遍历序列构造二叉树 java实现
java·算法·leetcode
杨超越luckly2 小时前
HTML应用指南:利用GET请求获取全国瑞思教育门店位置信息
前端·python·arcgis·html·门店数据
尘缘浮梦2 小时前
chrome英文翻译插件
前端·chrome
HIT_Weston2 小时前
58、【Ubuntu】【Gitlab】拉出内网 Web 服务:Gitlab 配置审视(二)
前端·ubuntu·gitlab
Donald_wsn2 小时前
牛客 栈和排序 C++
数据结构·c++·算法
wanhengidc2 小时前
云计算环境中的数据安全防护策略
运维·服务器·科技·游戏·智能手机·云计算
沃达德软件2 小时前
智慧警务实战模型与算法
大数据·人工智能·算法·数据挖掘·数据分析