魔改-隐语PSI通信,支持外部通信自定义

经过摸索,终于将隐语通信服务暴露出来,以后就可以将隐语的算子使用我们的通信流畅使用喽。

一.联络结构

SPU <-----PSI<------yacl

二、每层进行修改

1.yacl层

向外暴露的类参数为Context,所以我们再这里增加一个新的内置函数用来传递通信函数。

cpp 复制代码
// 使用抽象函数传递
  void AddChannel(std::function<int(const std::string &, std::string &)> sendCb, std::function<std::string(const std::string&)> recvCb) {
    for (size_t i = 0; i < channels_.size(); ++i) {
      if (i == rank_) {
        continue;
      }
      auto& channel = channels_[i];
      if (!channel) {
        continue;
      }
      channel->SetChannel(sendCb, recvCb); // 更为底层的通信在channel中
    }
  };
cpp 复制代码
  void SetChannel(std::function<int(const std::string &, std::string &)> sendCb, std::function<std::string(const std::string&)> recvCb){
    SPDLOG_INFO("Add SetChannel");
    
    sendCallback_ = sendCb;
    recvCallback_ = recvCb;
  };
 protected:
  std::function<int(const std::string &, std::string &)> sendCallback_;
  std::function<std::string(const std::string&)> recvCallback_;

然后我们将最底层的通信中增加我们的通信:

cpp 复制代码
void Channel::Send(const std::string& msg_key, ByteContainerView value) {
  if (aborting_.load()) {
    YACL_THROW_LINK_ABORTED("Send is not allowed when channel is aborting");
  }
  if (YACL_UNLIKELY(disable_msg_seq_id_)) {
    YACL_THROW("Send is not allowed when msg_seq_id is disabled");
  }

  YACL_ENFORCE(!waiting_finish_.load(),
               "Send is not allowed when channel is closing");
  NormalMessageKeyEnforce(msg_key);

 // 增加我们的通信
  if (sendCallback_) {
        SPDLOG_DEBUG("GAIA send tag={}, size={} (sync)", msg_key, value.size());
        std::string str((char*)value.data(), value.size());
        sendCallback_(msg_key, str);
        return;
  } 


  size_t seq_id = msg_seq_id_.fetch_add(1) + 1;
  auto key = BuildChannelKey(msg_key, seq_id);
  send_msgs_.Push(Message(seq_id, std::move(key), value));
  send_sync_.WaitSeqIdSendFinished(seq_id);
}
cpp 复制代码
Buffer Channel::Recv(const std::string& msg_key) {
  if (aborting_.load()) {
    YACL_THROW_LINK_ABORTED("Recv is not allowed when aborting channel");
  }
  NormalMessageKeyEnforce(msg_key);
  // do not intercept control messages: ack/fin/connect_*
  if (recvCallback_ && msg_key != kAckKey && msg_key != kFinKey && msg_key.rfind("connect_", 0) != 0) {  
      std::string str = recvCallback_(msg_key);
      Buffer future(str.c_str(), str.length());
      return future;
  }
  Buffer value;
  size_t seq_id = 0;
  {
    std::unique_lock<bthread::Mutex> lock(msg_mutex_);
    auto stop_waiting = [&] {
      auto itr = this->recv_msgs_.find(msg_key);
      if (itr == this->recv_msgs_.end()) {
        return false;
      } else {
        std::tie(value, seq_id) = std::move(itr->second);
        this->recv_msgs_.erase(itr);
        return true;
      }
    };
    while (!stop_waiting()) {
      if (aborting_.load()) {
        YACL_THROW_LINK_ABORTED("Aborting channel, skip waiting");
      }
      //                                timeout_us
      if (msg_db_cond_.wait_for(lock, static_cast<int64_t>(recv_timeout_ms_) *
                                          1000) == ETIMEDOUT) {
        YACL_THROW_IO_ERROR("Get data timeout, key={}", msg_key);
      }
    }
  }
  SendAck(seq_id);
  return value;
}

需要注意通信的截止符号,kAckKey、kFinKey

2.PSI层

不用修改,可以做一下测试

3.SPU层的修改

这里我自己封装了一个ecdh接口:

cpp 复制代码
  auto lctx = yacl::link::FactoryBrpc().CreateContext(ctx_desc, role);
  if (send_cb && recv_cb) {
    lctx->AddChannel(std::move(send_cb), std::move(recv_cb));
  }

比如使用内存传输

cpp 复制代码
  static std::mutex mtx;
  static std::condition_variable cv;
  static std::unordered_map<std::string, std::string> mailbox0;
  static std::unordered_map<std::string, std::string> mailbox1;

  auto send_cb0 = [](const std::string& tag, std::string& payload) -> int {
    std::lock_guard<std::mutex> lock(mtx);
    SPDLOG_INFO("send_cb0->1, tag: {}, payload(size): {}", tag, payload.size());

    mailbox1[tag] = std::move(payload);
    cv.notify_all();
    return 0;
  };

  auto recv_cb0 = [](const std::string& tag) -> std::string {
    std::unique_lock<std::mutex> lock(mtx);
    cv.wait(lock, [&]{ return mailbox0.find(tag) != mailbox0.end(); });
    auto it = mailbox0.find(tag);
    std::string out = std::move(it->second);
    SPDLOG_INFO("recv_cb0, tag: {}, payload(size): {}", tag, out.size());
    mailbox0.erase(it);
    return out;
  };

  auto send_cb1 = [](const std::string& tag, std::string& payload) -> int {
    std::lock_guard<std::mutex> lock(mtx);
    SPDLOG_INFO("send_cb1->0, tag: {}, payload(size): {}", tag, payload.size());

    mailbox0[tag] = std::move(payload);
    cv.notify_all();
    return 0;
  };

  auto recv_cb1 = [](const std::string& tag) -> std::string {
    std::unique_lock<std::mutex> lock(mtx);
    cv.wait(lock, [&]{ return mailbox1.find(tag) != mailbox1.end(); });
    auto it = mailbox1.find(tag);
    std::string out = std::move(it->second);
    SPDLOG_INFO("recv_cb1, tag: {}, payload(size): {}", tag, out.size());
    mailbox1.erase(it);
    return out;
  };
4.python层将通信传入c层
cpp 复制代码
  m.def("CreateChannel", 
        [](size_t role,
           py::function sendCb, py::function recvCb)
            -> std::shared_ptr<yacl::link::Context> {
          // 持有 Python 回调的指针,并在析构时获取 GIL 以避免 GIL 断言
          auto send_fn = std::shared_ptr<py::function>(
              new py::function(sendCb),
              [](py::function* f) {
                py::gil_scoped_acquire acquire;
                delete f;
              });
          auto recv_fn = std::shared_ptr<py::function>(
              new py::function(recvCb),
              [](py::function* f) {
                py::gil_scoped_acquire acquire;
                delete f;
              });

          // 将 Python 回调包装为 C++ 回调,调用时获取 GIL
          std::function<int(const std::string&, std::string&)> send_wrapper =
              [send_fn](const std::string& tag, std::string& payload) -> int {
                py::gil_scoped_acquire acquire;
                py::bytes py_payload(payload);
                py::object ret = (*send_fn)(tag, py_payload);
                return ret.cast<int>();
              };
          std::function<std::string(const std::string&)> recv_wrapper =
              [recv_fn](const std::string& tag) -> std::string {
                py::gil_scoped_acquire acquire;
                py::object ret = (*recv_fn)(tag);
                return ret.cast<std::string>();
              };

          py::gil_scoped_release release;
          auto ctx = psi::utils::Createlinks(role, std::move(send_wrapper), std::move(recv_wrapper));
          return ctx;
        },
        py::arg("role"),
        py::arg("send_cb"),
        py::arg("recv_cb"),
        "Create Context and inject memory send/recv callbacks");

python层传参到c

python 复制代码
    def send_cb(tag: str, payload: bytes) -> int:
        if role == 0:
            mailbox1[tag] = payload
        else:
            mailbox0[tag] = payload
        return 0

    def recv_cb(tag: str) -> bytes:
        if role == 0:
            while tag not in mailbox0:
                time.sleep(0.001)
            data = mailbox0[tag]
            del mailbox0[tag]
            return data
        else:
            while tag not in mailbox1:
                time.sleep(0.001)
            data = mailbox1[tag]
            del mailbox1[tag]
            return data

    ctx = CreateChannel(role, send_cb, recv_cb)

可以参考一下

Github-SPL

参考

https://github.com/secretflow/psi

https://github.com/secretflow/yacl

相关推荐
chenyuhao202437 分钟前
Linux网络编程:传输层协议UDP
linux·服务器·网络·后端·udp
小鹏linux38 分钟前
【linux】进程与服务管理命令 - batch
linux·运维·服务器
夜思红尘7 小时前
算法--双指针
python·算法·剪枝
人工智能训练7 小时前
OpenEnler等Linux系统中安装git工具的方法
linux·运维·服务器·git·vscode·python·ubuntu
散峰而望7 小时前
【算法竞赛】C++函数详解:从定义、调用到高级用法
c语言·开发语言·数据结构·c++·算法·github
CoderCodingNo7 小时前
【GESP】C++五级真题(贪心思想考点) luogu-B4071 [GESP202412 五级] 武器强化
开发语言·c++·算法
我有一些感想……7 小时前
An abstract way to solve Luogu P1001
c++·算法·ai·洛谷·mlp
郭涤生7 小时前
第十章_信号_《UNIX环境高级编程(第三版)》_笔记
服务器·笔记·unix
前端小L7 小时前
双指针专题(三):去重的艺术——「三数之和」
javascript·算法·双指针与滑动窗口
0和1的舞者7 小时前
Spring AOP详解(一)
java·开发语言·前端·spring·aop·面向切面