Related materials:

1. [How PD disaggregation is used in vLLM](https://blog.csdn.net/qq_44319972/article/details/155456852)
2. [Getting started with vLLM PD disaggregation: core concepts, advantages, and applicable scenarios](https://blog.csdn.net/u012605037/article/details/154681594)
## DisaggPDScheduler
In xllm's PD-disaggregated (prefill/decode separation) mode, the core logic is concentrated in DisaggPDScheduler.
```cpp
DisaggPDScheduler::DisaggPDScheduler(Engine* engine, const Options& options)
: ContinuousScheduler(engine, options), server_name_("DisaggPDServer") {
if (!options_.instance_role().has_value()) {
LOG(FATAL) << "Instance type is not set in disagg pd mode.";
}
// Only initialize for non-OOC mode
// OOC mode (PDOOCScheduler) will handle initialization in its own constructor
if (!options_.enable_pd_ooc()) {
// Start dispatch thread for prefill instance
    dispatch_thread_ = std::make_unique<std::thread>(
        &DisaggPDScheduler::dispatch_requests, this);
// Start RPC server thread
server_name_.append(std::to_string(options.server_idx()));
    rpc_server_thread_ = std::make_unique<std::thread>(
        &DisaggPDScheduler::start_rpc_server, this);
initialize_rpc_server_and_client(server_name_);
register_instance_info(server_name_, engine);
    // Profile ttft & tpot and update instance info (for mix instances)
if (!options_.disable_ttft_profiling() &&
options_.instance_role().value() == InstanceRole::MIX) {
profile_ttft();
profile_tpot();
}
}
}
```
On the P (prefill) instance, dispatch_thread_ is responsible for prefill request dispatch.
```cpp
bool DisaggPDScheduler::add_request(std::shared_ptr<Request>& request) {
CHECK(request != nullptr);
CHECK(!request->sequences().empty());
kv_cache_manager_->prefetch_from_storage(request);
if (request->offline()) {
// offline request, push to offline queue
prefill_request_queue_offline_.enqueue(request);
return true;
}
// push and wait
prefill_request_queue_.enqueue(request);
return true;
}
// prefill send new request to remote instance
void DisaggPDScheduler::dispatch_requests() {
while (true) {
const auto timeout = std::chrono::milliseconds(100);
// Wait for online request until timeout.
// If timeout, try to get offline request once. If no offline request,
// continue to wait for online request. This can avoid offline request
// blocking online request for too long time.
    std::shared_ptr<Request> request;
if (!prefill_request_queue_.wait_dequeue_timed(request, timeout)) {
if (!prefill_request_queue_offline_.try_dequeue(request)) {
continue;
}
}
if (request == nullptr) {
// nullptr is a signal to exit
break;
}
    std::vector<std::shared_ptr<Request>> requests;
requests.emplace_back(request);
std::string selected_instance = "";
proto::DisaggPDService_Stub* stub = nullptr;
if (selected_instance.empty() && !stub) {
// get allocated decode instance list from Master
while (decode_inst_names_.empty()) {
decode_inst_names_ = xservice_client_->get_static_decode_list();
if (!decode_inst_names_.empty()) {
LOG(INFO) << "Get PD decode instance list: "
<< absl::StrJoin(decode_inst_names_, "; ");
break;
}
sleep(1);
}
// select a D instance use RR currently.
// TODO: use better decode selection strategy later. maybe different
// strategy for offline and online request. or implement in xllm service.
int try_decode_count = 0;
while (!stub) {
if (try_decode_count == decode_inst_names_.size()) {
LOG(FATAL) << "Can not connect to all decode instances.";
}
++try_decode_count;
selected_instance = decode_inst_names_[current_decode_idx_];
current_decode_idx_ =
(++current_decode_idx_) % decode_inst_names_.size();
stub = create_rpc_channel(selected_instance);
}
}
{
      std::lock_guard<std::mutex> lock(req_to_channel_map_mutex_);
for (auto& req : requests) {
req_to_channel_map_[req->request_id()] = stub;
}
}
// TODO: send the request to the selected D instance
// Send 'DisaggRequests' and recv 'DisaggResponses'
xllm::proto::DisaggRequests reqs;
xllm::proto::DisaggResponses resps;
// prefill name (ID)
reqs.set_prefill_name(xservice_client_->get_instance_name());
reqs.mutable_reqs()->Reserve(requests.size());
// currently we only support one request once.
for (size_t i = 0; i < requests.size(); ++i) {
// proto::DisaggRequest req;
auto req = reqs.mutable_reqs()->Add();
req->set_req_id(requests[i]->request_id());
req->set_service_req_id(requests[i]->service_request_id());
req->set_tokens_num(requests[i]->state().prompt_tokens.size());
req->set_prompt(requests[i]->state().prompt);
ADD_VECTOR_TO_PROTO(req->mutable_prompt_tokens(),
requests[i]->state().prompt_tokens);
req->set_stream(requests[i]->state().stream);
req->set_x_request_id(requests[i]->x_request_id());
req->set_x_request_time(requests[i]->x_request_time());
req->set_seq_capacity(requests[i]->state().seq_capacity);
req->set_max_tokens(
requests[i]->state().stopping_checker.get_max_generated_tokens());
req->set_max_context_len(
requests[i]->state().stopping_checker.get_max_context_len());
req->set_ignore_eos(
requests[i]->state().stopping_checker.get_ignore_eos());
req->set_eos_token_id(
requests[i]->state().stopping_checker.get_eos_token());
if (requests[i]->state().stopping_checker.get_stop_tokens().size() > 0) {
ADD_VECTOR_TO_PROTO(
req->mutable_stop_token_ids(),
requests[i]->state().stopping_checker.get_stop_tokens());
}
if (requests[i]->state().stopping_checker.get_stop_sequences().size() >
0) {
for (auto& stop_sequence :
requests[i]->state().stopping_checker.get_stop_sequences()) {
// proto::StopSequence proto_seq;
auto proto_seq = req->mutable_stop_sequences()->Add();
ADD_VECTOR_TO_PROTO(proto_seq->mutable_seq_tokens(), stop_sequence);
//*req->mutable_stop_sequences()->Add() = proto_seq;
}
}
req->set_n(requests[i]->state().n);
req->set_best_of(requests[i]->state().best_of);
req->set_frequency_penalty(
requests[i]->state().sampling_param.frequency_penalty);
req->set_presence_penalty(
requests[i]->state().sampling_param.presence_penalty);
req->set_repetition_penalty(
requests[i]->state().sampling_param.repetition_penalty);
req->set_temperature(requests[i]->state().sampling_param.temperature);
req->set_top_p(requests[i]->state().sampling_param.top_p);
req->set_top_k(requests[i]->state().sampling_param.top_k);
req->set_logprobs(requests[i]->state().sampling_param.logprobs);
req->set_top_logprobs(requests[i]->state().sampling_param.top_logprobs);
req->set_is_embeddings(requests[i]->state().sampling_param.is_embeddings);
req->set_echo(requests[i]->state().echo);
req->set_skip_special_tokens(requests[i]->state().skip_special_tokens);
//*reqs.mutable_reqs()->Add() = req;
}
    std::vector<std::string> device_ips;
    std::vector<int32_t> ports;
engine_->get_device_info(device_ips, ports);
reqs.mutable_cluster_infos()->mutable_cluster_ids()->Add(
instance_info_.cluster_ids.begin(), instance_info_.cluster_ids.end());
reqs.mutable_cluster_infos()->mutable_addrs()->Add(
instance_info_.addrs.begin(), instance_info_.addrs.end());
reqs.mutable_cluster_infos()->mutable_device_ips()->Add(device_ips.begin(),
device_ips.end());
reqs.mutable_cluster_infos()->mutable_ports()->Add(ports.begin(),
ports.end());
reqs.mutable_cluster_infos()->set_dp_size(options_.dp_size());
// TODO: sync rpc here currently
brpc::Controller cntl;
stub->AddNewRequests(&cntl, &reqs, &resps, nullptr);
// TODO: error handler
// if (rpc failed) {
// // push all request back to prefill_request_queue_
//}
// check reqs which can not dispatch to D instance,
// and push back to prefill_request_queue_
CHECK_EQ(requests.size(), resps.resps().size())
<< "selected_instance : " << selected_instance;
// insert instance name to linked_instance_
{
      std::lock_guard<std::mutex> lock(linked_instances_mutex_);
linked_instance_.emplace(selected_instance);
}
for (size_t i = 0; i < requests.size(); ++i) {
if (resps.resps()[i].status_code() != 200) {
// push back to prefill_request_queue_
if (requests[i]->offline()) {
prefill_request_queue_offline_.enqueue(requests[i]);
} else {
prefill_request_queue_.enqueue(requests[i]);
}
} else {
for (auto& sequence : requests[i]->sequences()) {
TransferKVInfo info;
info.request_id = requests[i]->request_id();
for (auto& bid : resps.resps()[i].blocks_ids()) {
info.remote_blocks_ids.emplace_back(bid);
}
info.dp_rank = resps.resps()[i].dp_rank();
// TODO: remote_instances_info_ is not multi-thread safe.
info.remote_instance_info = remote_instances_info_[selected_instance];
sequence->kv_state().set_transfer_kv_info(std::move(info));
}
// push to request_queue_, and will be executed by engine.
request_queue_.write(requests[i]);
}
}
}
}
```
add_request enqueues the request into prefill_request_queue_; offline requests go to prefill_request_queue_offline_ instead.
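On the dequeue side, dispatch_requests gives online traffic strict priority: it blocks on the online queue with a 100 ms timeout and polls the offline queue only when that wait times out, so offline requests cannot starve online ones. A minimal sketch of the pattern, assuming the queues are moodycamel::BlockingConcurrentQueue (whose wait_dequeue_timed / try_dequeue API the code matches):
```cpp
#include <chrono>

#include "blockingconcurrentqueue.h"  // moodycamel; assumed dependency

moodycamel::BlockingConcurrentQueue<int> online_q;
moodycamel::BlockingConcurrentQueue<int> offline_q;

// Returns false only if both queues stayed empty for the whole timeout;
// the caller just loops and waits again, as dispatch_requests() does.
bool next_item(int& item) {
  const auto timeout = std::chrono::milliseconds(100);
  if (online_q.wait_dequeue_timed(item, timeout)) {
    return true;  // online work always wins while it keeps arriving
  }
  return offline_q.try_dequeue(item);  // opportunistic, non-blocking
}
```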
In dispatch_requests, the channel to the decode instance is created with stub = create_rpc_channel(selected_instance). The decode address has two sources: (1) carried in the request itself, or (2) fetched via xservice_client_->get_static_decode_list().
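The decode-instance selection above is plain round-robin with a single connect-retry pass over the list. Distilled into a standalone helper (try_connect is a hypothetical stand-in for create_rpc_channel):
```cpp
#include <functional>
#include <string>
#include <vector>

#include <glog/logging.h>

// Advance the round-robin cursor on every attempt; after one full pass
// without a reachable instance, fail hard, just as the scheduler does.
std::string pick_decode_instance(
    const std::vector<std::string>& names,
    size_t& rr_idx,  // persists across calls, like current_decode_idx_
    const std::function<bool(const std::string&)>& try_connect) {
  for (size_t tries = 0; tries < names.size(); ++tries) {
    const std::string& candidate = names[rr_idx];
    rr_idx = (rr_idx + 1) % names.size();
    if (try_connect(candidate)) {
      return candidate;
    }
  }
  LOG(FATAL) << "Can not connect to all decode instances.";
  return "";
}
```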
xservice_client_->get_static_decode_list() itself requires interacting with etcd.
XllmRpcService defines the following interface:
```cpp
service XllmRpcService {
rpc RegisterInstance(InstanceMetaInfo) returns (StatusCode) {}
rpc GetInstanceInfo(InstanceID) returns (InstanceMetaInfo) {}
rpc Heartbeat(HeartbeatRequest) returns (Status) {}
rpc GetStaticDecodeList(InstanceID) returns (InstanceIDs) {}
rpc GetStaticPrefillList(InstanceID) returns (InstanceIDs) {}
// xllm service receive response from decode instance directly in disagg pd mode.
// This can eliminate the cost brought by forwarding through prefill.
rpc Generations(xllm.proto.DisaggStreamGenerations) returns (xllm.proto.StatusSet) {}
}
```
The service itself is provided by a separate project, [xllm-service](https://github.com/jd-opensource/xllm-service/).
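xservice_client_ reaches that service over brpc. A hedged sketch of what such a call could look like; the generated header name and the proto field accessors (set_name, names) are assumptions for illustration, not xllm-service's actual definitions:
```cpp
#include <string>
#include <vector>

#include <brpc/channel.h>
#include <brpc/controller.h>
#include "xllm_rpc_service.pb.h"  // assumed generated header name

// Synchronous GetStaticDecodeList call; returns false on any RPC failure.
bool get_static_decode_list(const std::string& server_addr,
                            const std::string& my_instance_id,
                            std::vector<std::string>* decode_names) {
  brpc::Channel channel;
  brpc::ChannelOptions options;
  options.timeout_ms = 1000;
  if (channel.Init(server_addr.c_str(), &options) != 0) {
    return false;
  }
  proto::XllmRpcService_Stub stub(&channel);
  proto::InstanceID req;
  proto::InstanceIDs resp;
  req.set_name(my_instance_id);  // field name assumed
  brpc::Controller cntl;
  stub.GetStaticDecodeList(&cntl, &req, &resp, nullptr);
  if (cntl.Failed()) {
    return false;
  }
  for (const auto& name : resp.names()) {  // field name assumed
    decode_names->push_back(name);
  }
  return true;
}
```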
stub->AddNewRequests(&cntl, &reqs, &resps, nullptr) issues the remote call; the corresponding service endpoint is DisaggPDService::AddNewRequests.
reqs.mutable_cluster_infos() carries the P-instance cluster's device_ips and ports (along with cluster_ids, addrs, and dp_size).
On the D (decode) instance, rpc_server_thread_ is responsible for starting DisaggPDService.
```cpp
void DisaggPDScheduler::start_rpc_server() {
  std::unique_ptr<DisaggPDService> service =
      std::make_unique<DisaggPDService>(this, engine_);
auto rpc_server =
ServerRegistry::get_instance().register_server(server_name_);
if (!rpc_server->start(std::move(service))) {
LOG(ERROR) << "Failed to start brpc disagg pd server on port "
<< FLAGS_disagg_pd_port;
return;
}
}
class DisaggPDService : public proto::DisaggPDService {};
```
## DisaggPDService
The RPC interface it defines:
```cpp
service DisaggPDService {
rpc Generation(DisaggStreamGeneration) returns (Status) {}
rpc Generations(DisaggStreamGenerations) returns (StatusSet) {}
rpc AddNewRequests(DisaggRequests) returns (DisaggResponses) {}
rpc FirstGeneration(DisaggGenerationsRequests) returns (Status) {}
rpc MultiGenerations(DisaggGenerationsRequests) returns (Status) {}
rpc SendPullSignal(PullSignal) returns (Status) {}
}
```
DisaggPDService::AddNewRequests:
```cpp
void DisaggPDService::AddNewRequests(
::google::protobuf::RpcController* controller,
const proto::DisaggRequests* request,
proto::DisaggResponses* response,
::google::protobuf::Closure* done) {
brpc::ClosureGuard done_guard(done);
// try to allocate blocks for new requests
disagg_pd_service_impl_->decode_recv_new_requests(request, response);
}
void DisaggPDServiceImpl::decode_recv_new_requests(
const proto::DisaggRequests* request,
proto::DisaggResponses* response) {
// link prefill instance
if (!scheduler_->is_instance_linked(request->prefill_name())) {
    std::vector<uint64_t> cluster_ids(
        request->cluster_infos().cluster_ids().begin(),
        request->cluster_infos().cluster_ids().end());
    std::vector<std::string> addrs(request->cluster_infos().addrs().begin(),
                                   request->cluster_infos().addrs().end());
    std::vector<std::string> device_ips(
        request->cluster_infos().device_ips().begin(),
        request->cluster_infos().device_ips().end());
    std::vector<int32_t> ports(request->cluster_infos().ports().begin(),
                               request->cluster_infos().ports().end());
int32_t dp_size = request->cluster_infos().dp_size();
if (!scheduler_->link_instance(request->prefill_name(),
cluster_ids,
addrs,
device_ips,
ports,
dp_size)) {
LOG(ERROR) << "Link instance failed, instance name : "
<< request->prefill_name();
return;
}
}
for (auto& req : request->reqs()) {
auto resp = response->add_resps();
resp->set_req_id(req.req_id());
auto new_request = generate_request(req);
if (new_request == nullptr) {
resp->set_status_code(500);
continue;
}
auto& sequences = new_request->sequences();
Sequence* sequence = sequences[0].get();
if (!scheduler_->try_allocate(sequence)) {
// FIXME: set status code
resp->set_status_code(404);
} else {
// push the request to scheduler request buffer
bool success =
scheduler_->decode_schedule(new_request, request->prefill_name());
if (!success) {
LOG(ERROR) << "Failed to schedule new decode instance request: "
<< req.req_id();
// request and blocks are released in scheduler
resp->set_status_code(500);
}
auto dp_rank = sequence->dp_rank();
resp->set_dp_rank(dp_rank);
size_t shared_num = sequence->kv_state().shared_kv_blocks_num();
auto blocks = sequence->kv_state().kv_blocks();
for (size_t i = shared_num; i < blocks.size(); i++) {
*(resp->mutable_blocks_ids()->Add()) = blocks[i].id();
}
resp->set_status_code(200);
}
}
}
```
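Note that the block ids returned to P start at shared_kv_blocks_num: blocks already matched on D (presumably prefix-cache hits) hold valid KV data and need no transfer, so only the remaining ids are reported back. A toy illustration with hypothetical numbers:
```cpp
#include <cstdint>
#include <vector>

int main() {
  // Hypothetical: the sequence owns 6 KV blocks on D, the first 2 of which
  // are already populated locally; only the last 4 ids go back to P.
  std::vector<int32_t> block_ids = {10, 11, 12, 13, 14, 15};
  size_t shared_num = 2;
  std::vector<int32_t> to_transfer(block_ids.begin() + shared_num,
                                   block_ids.end());
  // to_transfer == {12, 13, 14, 15}
}
```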
### scheduler_->link_instance
The D instance establishes links to the P instance:
```cpp
bool DisaggPDScheduler::link_instance(
    const std::string& instance_name,
    const std::vector<uint64_t>& cluster_ids,
    const std::vector<std::string>& addrs,
    const std::vector<std::string>& device_ips,
    const std::vector<int32_t>& ports,
    const int32_t dp_size) {
  std::lock_guard<std::mutex> lock(linked_instances_mutex_);
if (!engine_->link_cluster(cluster_ids, addrs, device_ips, ports, dp_size)) {
LOG(ERROR) << "Link cluster failed!";
return false;
}
linked_instance_.emplace(instance_name);
return true;
}
bool LLMEngine::link_cluster(const std::vector<uint64_t>& cluster_ids,
                             const std::vector<std::string>& addrs,
                             const std::vector<std::string>& device_ips,
                             const std::vector<int32_t>& ports,
                             const int32_t src_dp_size) {
// Indicate which worker in the dp group in prefill the current worker needs
// to connect to. First, we connect the rank 0 workers in each DP. Then,
// increment the ranks sequentially.
int32_t src_dp_worker_index = 0;
int32_t src_world_size = cluster_ids.size();
int32_t src_tp_size = src_world_size / src_dp_size;
  std::vector<folly::SemiFuture<bool>> futures;
futures.reserve(worker_clients_num_);
for (size_t worker_rank = 0; worker_rank < worker_clients_num_;
++worker_rank) {
// The worker for decoding needs to establish a connection for each dp group
// in prefill.
    std::vector<uint64_t> dp_cluster_ids;
    std::vector<std::string> dp_addrs;
    std::vector<std::string> dp_device_ips;
    std::vector<int32_t> dp_ports;
dp_cluster_ids.reserve(src_dp_size);
dp_addrs.reserve(src_dp_size);
dp_device_ips.reserve(src_dp_size);
dp_ports.reserve(src_dp_size);
for (int32_t i = 0; i < src_dp_size; ++i) {
int32_t src_worker_index = i * src_tp_size + src_dp_worker_index;
dp_cluster_ids.emplace_back(cluster_ids[src_worker_index]);
dp_addrs.emplace_back(addrs[src_worker_index]);
dp_device_ips.emplace_back(device_ips[src_worker_index]);
dp_ports.emplace_back(ports[src_worker_index]);
}
// Increment the rank.
src_dp_worker_index = (src_dp_worker_index + 1) % src_tp_size;
    folly::Promise<bool> promise;
auto future = promise.getSemiFuture();
link_threadpool_->schedule([this,
promise = std::move(promise),
worker_rank,
dp_cluster_ids = std::move(dp_cluster_ids),
dp_addrs = std::move(dp_addrs),
dp_device_ips = std::move(dp_device_ips),
dp_ports = std::move(dp_ports)]() mutable {
promise.setValue(worker_clients_[worker_rank]->link_cluster(
dp_cluster_ids, dp_addrs, dp_device_ips, dp_ports));
});
futures.emplace_back(std::move(future));
}
// wait for all futures to complete
auto results = folly::collectAll(futures).get();
for (const auto& result : results) {
if (!result.value()) {
LOG(ERROR) << "Link cluster failed.";
return false;
}
}
return true;
}
```
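To make the connection topology concrete, here is a toy run of the same index arithmetic with hypothetical sizes: a P instance with src_dp_size = 2 DP groups of src_tp_size = 4 workers each (8 workers total), and a D instance with 4 local workers:
```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int32_t src_dp_size = 2;   // prefill DP groups
  const int32_t src_tp_size = 4;   // workers per prefill DP group
  const int32_t decode_workers = 4;
  int32_t src_dp_worker_index = 0;
  for (int32_t worker_rank = 0; worker_rank < decode_workers; ++worker_rank) {
    std::printf("decode worker %d -> prefill workers:", worker_rank);
    for (int32_t i = 0; i < src_dp_size; ++i) {
      // one link per prefill DP group, mirroring LLMEngine::link_cluster
      std::printf(" %d", i * src_tp_size + src_dp_worker_index);
    }
    std::printf("\n");
    src_dp_worker_index = (src_dp_worker_index + 1) % src_tp_size;
  }
  // Output: decode worker 0 -> prefill workers 0 4
  //         decode worker 1 -> prefill workers 1 5
  //         decode worker 2 -> prefill workers 2 6
  //         decode worker 3 -> prefill workers 3 7
}
```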
## prefill_send_first_generation
After prefill finishes, the P instance notifies D to pull the KV cache.
```cpp
void DisaggPDScheduler::prefill_send_first_generation() {
if (running_sequences_.size() == 0) {
return;
}
  std::vector<std::shared_ptr<Request>> requests;
  std::vector<std::shared_ptr<Request>> non_stream_requests;
requests.reserve(running_requests_.size());
non_stream_requests.reserve(running_requests_.size());
for (size_t i = 0; i < running_requests_.size(); ++i) {
auto request = running_requests_[i];
// Check if the request is a recently completed prefill request
if (request->sequences()[0]->num_generated_tokens() == 1) {
requests.emplace_back(request);
if (!request->state().stream) {
non_stream_requests.emplace_back(request);
}
running_requests_[i] = nullptr;
}
}
// call non_stream_request's callback in P instance when its prefill ends
response_processor_->process_completed_requests(non_stream_requests);
// No prefill request needs to be transferred to decode.
if (requests.size() == 0) {
return;
}
prefill_threadpool_.schedule([this,
requests = std::move(requests)]() mutable {
// send request first token to remote instance
// TODO: here we only support one sequence for now.
for (auto& request : requests) {
// TODO: support batch request later
proto::DisaggGenerationsRequests gens;
auto gen = gens.mutable_multi_gens()->Add();
gen->set_req_id(request->request_id());
if (request->sequences()[0]->first_token().has_value()) {
auto token = gen->mutable_tokens()->Add();
token->set_token_id(
request->sequences()[0]->first_token().value().token_id);
if (request->sequences()[0]
->first_token()
.value()
.token_logprob.has_value()) {
token->set_logprob(request->sequences()[0]
->first_token()
.value()
.token_logprob.value());
token->set_has_logprob(true);
} else {
token->set_has_logprob(false);
}
ADD_VECTOR_TO_PROTO(
token->mutable_top_tokens(),
request->sequences()[0]->first_token().value().token_top_tokens);
ADD_VECTOR_TO_PROTO(
token->mutable_top_logprobs(),
request->sequences()[0]->first_token().value().token_top_logprobs);
}
gen->set_kv_cache_transfer_mode(options_.kv_cache_transfer_mode());
if (options_.kv_cache_transfer_mode() == "PULL") {
ADD_VECTOR_TO_PROTO(gen->mutable_cluster_ids(),
instance_info_.cluster_ids);
ADD_VECTOR_TO_PROTO(gen->mutable_addrs(), instance_info_.addrs);
ADD_VECTOR_TO_PROTO(gen->mutable_k_cache_ids(),
instance_info_.k_cache_ids);
ADD_VECTOR_TO_PROTO(gen->mutable_v_cache_ids(),
instance_info_.v_cache_ids);
const auto blocks = request->sequences()[0]->kv_state().kv_blocks();
        std::vector<int32_t> block_ids;
block_ids.reserve(blocks.size());
for (const auto& block : blocks) {
block_ids.push_back(block.id());
}
ADD_VECTOR_TO_PROTO(gen->mutable_block_ids(), block_ids);
gen->set_dp_size(instance_info_.dp_size);
gen->set_dp_rank(request->sequences()[0]->dp_rank());
}
// send first gens to remote instance
proto::DisaggPDService_Stub* stub = nullptr;
{
        std::lock_guard<std::mutex> lock(req_to_channel_map_mutex_);
// now we only support one request once.
stub = req_to_channel_map_[request->request_id()];
}
// TODO: Async call later
proto::Status resp;
brpc::Controller cntl;
stub->FirstGeneration(&cntl, &gens, &resp, nullptr);
if (cntl.Failed() || !resp.ok()) {
LOG(ERROR) << "Failed to send first generation, " << cntl.ErrorText()
<< ", staus: " << resp.ok();
}
{
        std::lock_guard<std::mutex> lock(req_to_channel_map_mutex_);
req_to_channel_map_.erase(request->request_id());
}
kv_cache_manager_->deallocate(request.get());
}
});
}
```
stub->FirstGeneration issues the remote call to the D instance.
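The call is synchronous for now (hence the "Async call later" TODO in the code). For reference, brpc's standard asynchronous form hands a heap-allocated controller and response to a done callback; a hedged sketch under that assumption, not xllm's actual code:
```cpp
#include <memory>

#include <brpc/channel.h>
#include <brpc/controller.h>
#include <glog/logging.h>
#include "disagg_pd.pb.h"  // assumed generated header name

// Completion callback: reclaim the heap objects, then log any failure.
static void HandleFirstGeneration(brpc::Controller* cntl,
                                  proto::Status* resp) {
  std::unique_ptr<brpc::Controller> cntl_guard(cntl);
  std::unique_ptr<proto::Status> resp_guard(resp);
  if (cntl->Failed() || !resp->ok()) {
    LOG(ERROR) << "FirstGeneration failed: " << cntl->ErrorText();
  }
}

// Fire-and-forget send: brpc serializes the request inside the call, so
// `gens` may be destroyed as soon as this function returns.
void send_first_generation_async(
    proto::DisaggPDService_Stub* stub,
    const proto::DisaggGenerationsRequests& gens) {
  auto* cntl = new brpc::Controller();
  auto* resp = new proto::Status();
  stub->FirstGeneration(
      cntl, &gens, resp, brpc::NewCallback(&HandleFirstGeneration, cntl, resp));
}
```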
### DisaggPDService::FirstGeneration
On the D instance, the KV-cache pull proceeds as follows.
```cpp
// TODO: support embedding later, now we only support tokens
void DisaggPDService::FirstGeneration(
::google::protobuf::RpcController* controller,
const proto::DisaggGenerationsRequests* request,
proto::Status* response,
::google::protobuf::Closure* done) {
// Receive first token from Prefill, schedule the request to running queue
brpc::ClosureGuard done_guard(done);
disagg_pd_service_impl_->decode_recv_first_generation(request, response);
}
// TODO: support embedding later, now we only support tokens
void DisaggPDServiceImpl::decode_recv_first_generation(
const proto::DisaggGenerationsRequests* request,
proto::Status* response) {
// TODO: we only support one request generation currently
  for (auto& gen : request->multi_gens()) {
    // Process the first token from the tokens array
    if (gen.tokens().empty()) {
      response->set_ok(false);
      return;
    }
    // Unpack the proto message into locals (this part was elided from the
    // excerpt; reconstructed here so the call below is self-contained).
    const auto& first_token = gen.tokens()[0];
    std::vector<int64_t> top_tokens(first_token.top_tokens().begin(),
                                    first_token.top_tokens().end());
    std::vector<float> top_logprobs(first_token.top_logprobs().begin(),
                                    first_token.top_logprobs().end());
    std::vector<uint64_t> cluster_ids(gen.cluster_ids().begin(),
                                      gen.cluster_ids().end());
    std::vector<std::string> addrs(gen.addrs().begin(), gen.addrs().end());
    std::vector<int64_t> k_cache_ids(gen.k_cache_ids().begin(),
                                     gen.k_cache_ids().end());
    std::vector<int64_t> v_cache_ids(gen.v_cache_ids().begin(),
                                     gen.v_cache_ids().end());
    std::vector<int32_t> block_ids(gen.block_ids().begin(),
                                   gen.block_ids().end());
    bool success =
        scheduler_->decode_recv_first_generation(gen.req_id(),
                                                 first_token.token_id(),
                                                 first_token.has_logprob(),
                                                 first_token.logprob(),
                                                 std::move(top_tokens),
                                                 std::move(top_logprobs),
                                                 gen.kv_cache_transfer_mode(),
                                                 std::move(cluster_ids),
                                                 std::move(addrs),
                                                 std::move(k_cache_ids),
                                                 std::move(v_cache_ids),
                                                 std::move(block_ids),
                                                 gen.dp_size(),
                                                 gen.dp_rank());
    response->set_ok(success);
  }
}
bool DisaggPDScheduler::decode_recv_first_generation(
    const std::string& req_id,
    int64_t token_id,
    bool has_logprob,
    float logprob,
    std::vector<int64_t> top_tokens,
    std::vector<float> top_logprobs,
    const std::string& kv_cache_transfer_mode,
    std::vector<uint64_t> src_cluster_ids,
    std::vector<std::string> src_addrs,
    std::vector<int64_t> src_k_cache_ids,
    std::vector<int64_t> src_v_cache_ids,
    std::vector<int32_t> src_block_ids,
    int32_t src_dp_size,
    int32_t src_dp_rank) {
// push to request_queue_, and will be executed by engine.
  std::shared_ptr<Request> request = nullptr;
  {
    std::lock_guard<std::mutex> lock(received_request_map_mutex_);
auto it = received_request_map_.find(req_id);
if (it == received_request_map_.end()) {
LOG(ERROR) << "Failed to find request, request id: " << req_id;
return false;
}
request = it->second;
received_request_map_.erase(it);
}
// pull kv cache
if (kv_cache_transfer_mode == "PULL") {
const auto blocks = request->sequences()[0]->kv_state().kv_blocks();
    std::vector<int32_t> dst_block_ids;
dst_block_ids.reserve(blocks.size());
for (const auto& block : blocks) {
dst_block_ids.push_back(block.id());
}
int32_t dst_dp_rank = request->sequences()[0]->dp_rank();
engine_->pull_kv_blocks(src_dp_size,
src_dp_rank,
src_cluster_ids,
src_addrs,
src_k_cache_ids,
src_v_cache_ids,
src_block_ids,
dst_dp_rank,
dst_block_ids);
  }
  return true;
}
```
In LLMEngine::pull_kv_blocks, within each DP group I believe src_tp_size on the P instance and dst_tp_size on the D instance are expected to be equal; that is, the P and D instances involved in a KV-cache transfer should have the same number of workers per DP group. A worked example of the rank mapping follows the code block below.
```cpp
bool LLMEngine::pull_kv_blocks(const int32_t src_dp_size,
                               const int32_t src_dp_rank,
                               const std::vector<uint64_t>& src_cluster_ids,
                               const std::vector<std::string>& src_addrs,
                               const std::vector<int64_t>& src_k_cache_ids,
                               const std::vector<int64_t>& src_v_cache_ids,
                               const std::vector<int32_t>& src_blocks,
                               const int32_t dst_dp_rank,
                               const std::vector<int32_t>& dst_blocks) {
int32_t src_world_size = src_cluster_ids.size();
int32_t src_tp_size = src_world_size / src_dp_size;
int32_t dst_world_size = options_.nnodes();
int32_t dst_tp_size = dst_world_size / dp_size_;
  std::vector<bool> results;
results.reserve(dst_tp_size);
// Pull the KV cache for all workers in the current DP rank.
for (size_t tp_rank = 0; tp_rank < dst_tp_size; ++tp_rank) {
int32_t dst_worker_rank = dst_dp_rank * dst_tp_size + tp_rank;
// Determine the ranks of the remote workers connected to the current
// worker.
int32_t src_dp_worker_rank = dst_worker_rank % src_tp_size;
int32_t src_worker_rank = src_dp_rank * src_tp_size + src_dp_worker_rank;
results.push_back(worker_clients_[dst_worker_rank]->pull_kv_blocks(
src_cluster_ids[src_worker_rank],
src_addrs[src_worker_rank],
src_k_cache_ids[src_worker_rank],
src_v_cache_ids[src_worker_rank],
src_blocks,
dst_blocks));
}
for (bool result : results) {
if (!result) {
return false;
}
}
return true;
}
bool RemoteWorker::pull_kv_blocks(const uint64_t src_cluster_id,
const std::string& src_addr,
const int64_t src_k_cache_id,
const int64_t src_v_cache_id,
                                  const std::vector<int32_t>& src_blocks,
                                  const std::vector<int32_t>& dst_blocks) {
return channel_->pull_kv_blocks(src_cluster_id,
src_addr,
src_k_cache_id,
src_v_cache_id,
src_blocks,
dst_blocks);
}
bool CommChannel::pull_kv_blocks(const uint64_t src_cluster_id,
const std::string& src_addr,
const int64_t src_k_cache_id,
const int64_t src_v_cache_id,
                                 const std::vector<int32_t>& src_blocks,
                                 const std::vector<int32_t>& dst_blocks) {
proto::PullKVCacheRequest request;
request.set_cluster_id(src_cluster_id);
request.set_addr(src_addr);
request.set_k_cache_id(src_k_cache_id);
request.set_v_cache_id(src_v_cache_id);
ADD_VECTOR_TO_PROTO(request.mutable_src_blocks(), src_blocks);
ADD_VECTOR_TO_PROTO(request.mutable_dst_blocks(), dst_blocks);
proto::Status s;
brpc::Controller cntl;
stub_->PullKVCache(&cntl, &request, &s, nullptr);
return !cntl.Failed() && s.ok();
}
void WorkerService::PullKVCache(::google::protobuf::RpcController* controller,
const proto::PullKVCacheRequest* req,
proto::Status* resp,
::google::protobuf::Closure* done) {
threadpool_->schedule([this, controller, req, resp, done]() mutable {
brpc::ClosureGuard done_guard(done);
uint64_t src_cluster_id = req->cluster_id();
std::string addr = req->addr();
int64_t src_k_cache_id = req->k_cache_id();
int64_t src_v_cache_id = req->v_cache_id();
    std::vector<int32_t> src_blocks(req->src_blocks().begin(),
                                    req->src_blocks().end());
    std::vector<int32_t> dst_blocks(req->dst_blocks().begin(),
                                    req->dst_blocks().end());
auto future = worker_->pull_kv_blocks_async(src_cluster_id,
addr,
src_k_cache_id,
src_v_cache_id,
src_blocks,
dst_blocks);
bool status = std::move(future).get();
resp->set_ok(status);
});
return;
}
folly::SemiFuture<bool> WorkerImpl::pull_kv_blocks_async(
uint64_t src_cluster_id,
const std::string& src_addr,
int64_t src_k_cache_id,
int64_t src_v_cache_id,
    const std::vector<int32_t>& src_blocks,
    const std::vector<int32_t>& dst_blocks) {
#if defined(USE_NPU)
return kv_cache_transfer_->pull_kv_blocks_async(src_cluster_id,
src_addr,
src_k_cache_id,
src_v_cache_id,
src_blocks,
dst_blocks);
#endif
return false;
}
```
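As promised above, a toy run of the worker mapping in pull_kv_blocks, with hypothetical sizes: both sides use a TP size of 4, the blocks live in prefill DP rank 1, and the request runs in decode DP rank 0:
```cpp
#include <cstdint>
#include <cstdio>

int main() {
  const int32_t src_tp_size = 4, dst_tp_size = 4;  // equal TP sizes assumed
  const int32_t src_dp_rank = 1, dst_dp_rank = 0;
  for (int32_t tp_rank = 0; tp_rank < dst_tp_size; ++tp_rank) {
    int32_t dst_worker_rank = dst_dp_rank * dst_tp_size + tp_rank;
    // Same arithmetic as LLMEngine::pull_kv_blocks.
    int32_t src_dp_worker_rank = dst_worker_rank % src_tp_size;
    int32_t src_worker_rank = src_dp_rank * src_tp_size + src_dp_worker_rank;
    std::printf("decode worker %d pulls from prefill worker %d\n",
                dst_worker_rank, src_worker_rank);
  }
  // Output: decode workers 0..3 pull from prefill workers 4..7, i.e. a
  // one-to-one TP-rank pairing, which is why matching TP sizes make sense.
}
```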
kv_cache_transfer_ is created by KVCacheTransferFactory::create, which has three implementations:
* LlmDataDist
* MooncakeKVCacheTransfer
* HcclKVCacheTransfer
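A minimal sketch of what that factory dispatch could look like; the stub classes, the string key, and the function signature here are assumptions for illustration, not xllm's actual definitions:
```cpp
#include <memory>
#include <stdexcept>
#include <string>

// Hedged sketch only: stub types stand in for xllm's real transfer
// backends; the actual classes and factory signature live in xllm.
struct KVCacheTransfer {
  virtual ~KVCacheTransfer() = default;
};
struct LlmDataDist : KVCacheTransfer {};
struct MooncakeKVCacheTransfer : KVCacheTransfer {};
struct HcclKVCacheTransfer : KVCacheTransfer {};

std::unique_ptr<KVCacheTransfer> create_kv_cache_transfer(
    const std::string& backend) {
  if (backend == "LlmDataDist") {
    return std::make_unique<LlmDataDist>();
  }
  if (backend == "Mooncake") {
    return std::make_unique<MooncakeKVCacheTransfer>();
  }
  if (backend == "HCCL") {
    return std::make_unique<HcclKVCacheTransfer>();
  }
  throw std::invalid_argument("unknown kv cache transfer backend: " + backend);
}
```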