[KVConnector] Aggregate finished requests on the scheduler (#19555)

Signed-off-by: Or Ozeri <oro@il.ibm.com>
Or Ozeri 2025-07-10 11:22:18 +03:00 committed by GitHub
parent fdfd409f8f
commit cc876d0f29
5 changed files with 139 additions and 110 deletions

View File

@ -190,7 +190,9 @@ class KVConnectorBase_V1(ABC):
) -> tuple[Optional[set[str]], Optional[set[str]]]:
"""
Notifies worker-side connector ids of requests that have
finished generating tokens.
finished generating tokens on the worker.
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
Returns:
ids of requests that have finished asynchronous transfer
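Under the new contract, a worker-side connector reports only the requests it has locally finished; the scheduler-side executor counts reports until every worker has checked in. A minimal, self-contained sketch of that counting idea (plain Python, not vLLM code; on_worker_report is a hypothetical helper):

from collections import defaultdict

world_size = 4                               # hypothetical TP size
remaining = defaultdict(lambda: world_size)  # req_id -> workers still pending

def on_worker_report(req_id: str) -> bool:
    """True once every worker has reported the request as finished."""
    remaining[req_id] -= 1
    if remaining[req_id] == 0:
        del remaining[req_id]
        return True
    return False

print([on_worker_report("req-0") for _ in range(world_size)])
# [False, False, False, True] -> surfaced to the scheduler on the last report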

View File

@ -408,14 +408,6 @@ class NixlConnectorWorker:
# Track the expiration time of requests that are waiting to be sent.
self._reqs_to_send: dict[ReqId, float] = {}
# Complete transfer tracker. Used by the rank 0 to track finished
# transactions on ranks 1 to N-1.
# [req_id -> count]
self._done_recving_count: defaultdict[ReqId,
int] = defaultdict(lambda: 0)
self._done_sending_count: defaultdict[ReqId,
int] = defaultdict(lambda: 0)
# Background thread for handling new handshake requests.
self._nixl_handshake_listener_t: Optional[threading.Thread] = None
# Background thread for initializing new NIXL handshakes.
@ -830,15 +822,9 @@ class NixlConnectorWorker:
def get_finished(self) -> tuple[set[str], set[str]]:
"""
Get requests that are done sending or recving.
In TP>1 setup, each rank exchanges KVs with its counterpart
ranks independently. get_finished() runs in a worker creates
the done_sending and done_recving sets that are sent to the
scheduler via ModelRunnerOutput by Rank 0. To ensure trnxs
are done before adding to finished, Ranks 1 to N-1 communicate
to Rank 0 once their transaction is done + Rank 0 returns
finished sets to Scheduler only once all ranks are done.
Get requests that are done sending or recving on this specific worker.
The scheduler process (via the MultiprocExecutor) will use this output
to track which workers are done.
"""
done_sending = self._get_new_notifs()
done_recving = self._pop_done_transfers(self._recving_transfers)
@ -858,50 +844,7 @@ class NixlConnectorWorker:
del self._reqs_to_send[req_id]
done_sending.add(req_id)
if self.world_size == 1:
return done_sending, done_recving
# Rank 0: get finished from all other ranks.
if self.tp_rank == 0:
for req_id in done_sending:
self._done_sending_count[req_id] += 1
for req_id in done_recving:
self._done_recving_count[req_id] += 1
# Keep track of how many other ranks have finished.
other_ranks_finished_ids: list[str] = []
for i in range(1, self.world_size):
other_ranks_finished_ids.extend(
self.tp_group.recv_object(src=i))
for req_id in other_ranks_finished_ids:
if (req_id in self._done_recving_count
or req_id in self._recving_transfers):
self._done_recving_count[req_id] += 1
else:
self._done_sending_count[req_id] += 1
# Return ids that finished on all ranks to the scheduler.
all_done_recving: set[str] = set()
for req_id in list(self._done_recving_count.keys()):
if self._done_recving_count[req_id] == self.world_size:
del self._done_recving_count[req_id]
all_done_recving.add(req_id)
all_done_sending: set[str] = set()
for req_id in list(self._done_sending_count.keys()):
if self._done_sending_count[req_id] >= self.world_size:
del self._done_sending_count[req_id]
all_done_sending.add(req_id)
return all_done_sending, all_done_recving
# Ranks 1 to N-1: send finished ids to Rank 0.
else:
finished_req_ids = list(done_recving.union(done_sending))
self.tp_group.send_object(finished_req_ids, dst=0)
# Unused as only Rank 0 results are sent to scheduler.
return done_sending, done_recving
return done_sending, done_recving
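With the rank-0 aggregation gone, each rank's get_finished() only drains its own bookkeeping and hands the local sets to the executor. A rough sketch of that per-rank shape (an illustrative stand-in, not the actual NIXL implementation; attribute names only approximate the real ones):

import time

class PerRankFinishedTracker:
    """Illustrative per-rank bookkeeping: no cross-rank send/recv needed."""

    def __init__(self):
        self._sent_notifs: set[str] = set()        # filled by remote read notifications
        self._done_recving: set[str] = set()       # filled when local reads complete
        self._reqs_to_send: dict[str, float] = {}  # req_id -> expiration timestamp

    def get_finished(self) -> tuple[set[str], set[str]]:
        done_sending, self._sent_notifs = self._sent_notifs, set()
        done_recving, self._done_recving = self._done_recving, set()
        # Requests whose consumer never pulled the KVs are freed on expiry.
        now = time.monotonic()
        for req_id, expires in list(self._reqs_to_send.items()):
            if now >= expires:
                del self._reqs_to_send[req_id]
                done_sending.add(req_id)
        # Purely local sets: no tp_group.send_object()/recv_object() round trip;
        # the MultiprocExecutor merges these across ranks on the scheduler side.
        return done_sending, done_recving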
def _get_new_notifs(self) -> set[str]:
"""

View File

@ -9,7 +9,8 @@ import threading
import time
import traceback
import weakref
from concurrent.futures import Future, ThreadPoolExecutor
from collections import defaultdict
from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
@ -111,10 +112,19 @@ class MultiprocExecutor(Executor):
if self.max_concurrent_batches > 1:
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue
# _async_aggregate_workers_output also assumes a single IO thread
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io")
self.output_rank = self._get_output_rank()
self.has_connector = self.vllm_config.kv_transfer_config is not None
# Complete transfer tracker. Used to track finished requests
# [req_id -> n_finished_workers]
self._recv_remaining_count = defaultdict[str,
int](lambda: self.world_size)
self._send_remaining_count = defaultdict[str,
int](lambda: self.world_size)
def start_worker_monitor(self):
workers = self.workers
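Both counters rely on the defaultdict factory to lazily start every request at world_size, so the executor never needs to pre-register requests before counting down. A tiny demo of the idiom (subscripted defaultdict calls require Python 3.9+):

from collections import defaultdict

world_size = 2
send_remaining = defaultdict[str, int](lambda: world_size)

send_remaining["req-A"] -= 1
print(send_remaining["req-A"])  # 1 -> one worker still pending
send_remaining["req-A"] -= 1
print(send_remaining["req-A"])  # 0 -> finished on all workers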
@ -155,13 +165,29 @@ class MultiprocExecutor(Executor):
self,
scheduler_output,
) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
(output, ) = self.collective_rpc(
non_block = self.max_concurrent_batches > 1
if not self.has_connector:
# get output only from a single worker (output_rank)
(output, ) = self.collective_rpc(
"execute_model",
args=(scheduler_output, ),
unique_reply_rank=self.output_rank,
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
return output
# get output from all workers
outputs = self.collective_rpc(
"execute_model",
args=(scheduler_output, ),
unique_reply_rank=self.output_rank,
non_block=self.max_concurrent_batches > 1,
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
return output
# aggregate all workers output to a single output
if non_block:
return self._async_aggregate_workers_output(outputs)
return self._aggregate_workers_output(outputs)
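Callers of execute_model now get back either a ModelRunnerOutput or, when batches overlap, a Future that resolves once all worker replies have been merged. A minimal usage sketch (resolve_step_output is a hypothetical helper, not vLLM code):

from concurrent.futures import Future
from typing import Union

def resolve_step_output(result: Union[object, Future]) -> object:
    """Handle both return shapes of the executor's execute_model."""
    if isinstance(result, Future):
        # Non-blocking path: resolves only after every worker's reply
        # has been aggregated into a single output.
        return result.result()
    return result

f: Future = Future()
f.set_result("merged-output")
print(resolve_step_output(f))        # merged-output
print(resolve_step_output("plain"))  # plain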
def collective_rpc(self,
method: Union[str, Callable],
@ -220,6 +246,80 @@ class MultiprocExecutor(Executor):
except TimeoutError as e:
raise TimeoutError(f"RPC call to {method} timed out.") from e
def _aggregate_workers_output(
self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
# aggregate finished_sending, finished_recving from all workers
finished_sending = set[str]()
finished_recving = set[str]()
for output in outputs:
# update finished_sending
for req_id in output.finished_sending or []:
new_count = self._send_remaining_count[req_id] - 1
if new_count == 0:
# got response from all workers, report back to scheduler
finished_sending.add(req_id)
del self._send_remaining_count[req_id]
else:
self._send_remaining_count[req_id] = new_count
# update finished_recving
for req_id in output.finished_recving or []:
new_count = self._recv_remaining_count[req_id] - 1
if new_count == 0:
# got response from all workers, report back to scheduler
finished_recving.add(req_id)
del self._recv_remaining_count[req_id]
else:
self._recv_remaining_count[req_id] = new_count
# select output of the worker specified by output_rank
output = outputs[self.output_rank]
# set the aggregated finished_sending / finished_recving
if finished_sending:
output.finished_sending = finished_sending
if finished_recving:
output.finished_recving = finished_recving
return output
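A worked toy trace of the countdown merge: with two workers, a request id reaches the scheduler only on the step where the second worker reports it. FakeOutput and ToyAggregator are stand-ins; the real ModelRunnerOutput carries many more fields:

from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeOutput:                   # stand-in for ModelRunnerOutput
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None

class ToyAggregator:
    """Counts down per request until all workers have reported it."""

    def __init__(self, world_size: int, output_rank: int = 0):
        self.output_rank = output_rank
        self._send_remaining = defaultdict(lambda: world_size)
        self._recv_remaining = defaultdict(lambda: world_size)

    def aggregate(self, outputs: list[FakeOutput]) -> FakeOutput:
        finished_sending: set[str] = set()
        finished_recving: set[str] = set()
        for out in outputs:
            for req_id in out.finished_sending or []:
                self._send_remaining[req_id] -= 1
                if self._send_remaining[req_id] == 0:
                    del self._send_remaining[req_id]
                    finished_sending.add(req_id)
            for req_id in out.finished_recving or []:
                self._recv_remaining[req_id] -= 1
                if self._recv_remaining[req_id] == 0:
                    del self._recv_remaining[req_id]
                    finished_recving.add(req_id)
        result = outputs[self.output_rank]
        if finished_sending:
            result.finished_sending = finished_sending
        if finished_recving:
            result.finished_recving = finished_recving
        return result

agg = ToyAggregator(world_size=2)
# Step 1: only worker 1 has finished sending "req-X"; nothing reported yet.
out = agg.aggregate([FakeOutput(), FakeOutput(finished_sending={"req-X"})])
print(out.finished_sending)   # None
# Step 2: worker 0 finishes too; the id is now surfaced to the scheduler.
out = agg.aggregate([FakeOutput(finished_sending={"req-X"}), FakeOutput()])
print(out.finished_sending)   # {'req-X'}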
def _async_aggregate_workers_output(
self, output_futures: list[Future[ModelRunnerOutput]]
) -> (Future[ModelRunnerOutput]):
"""Takes a list of futures and returns a single future which resolves
to the respective list of outputs."""
result_future: Future[ModelRunnerOutput] = Future()
outputs: list[Optional[ModelRunnerOutput]] = [None
] * len(output_futures)
def make_callback(idx):
def callback(fut):
if result_future.done():
return
try:
outputs[idx] = fut.result()
except CancelledError:
result_future.cancel()
except Exception as e:
result_future.set_exception(e)
# this check assumes io_thread_pool uses a single thread
if all(outputs):
result_future.set_result(
self._aggregate_workers_output(
cast(list[ModelRunnerOutput], outputs)))
return callback
for i, output_future in enumerate(output_futures):
output_future.add_done_callback(make_callback(i))
return result_future
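The async variant fills one slot per worker future and resolves the combined future when the last slot arrives; the single-threaded IO pool is what makes the unsynchronized slot check safe. A self-contained toy with concurrent.futures (it resolves to the raw list instead of a merged output, to keep the example short):

from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional
import time

def aggregate_futures(futures: list[Future]) -> Future:
    """Resolve once all input futures have resolved (illustrative only)."""
    result: Future = Future()
    slots: list[Optional[object]] = [None] * len(futures)

    def make_callback(idx: int):
        def callback(fut: Future):
            if result.done():
                return
            try:
                slots[idx] = fut.result()
            except Exception as exc:
                result.set_exception(exc)
                return
            # Safe without locks only if callbacks never run concurrently,
            # e.g. when a single IO thread completes all the futures.
            if all(s is not None for s in slots):
                result.set_result(list(slots))
        return callback

    for i, fut in enumerate(futures):
        fut.add_done_callback(make_callback(i))
    return result

def fake_worker(rank: int) -> str:
    time.sleep(0.01 * rank)          # simulate uneven completion times
    return f"out-{rank}"

with ThreadPoolExecutor(max_workers=1) as pool:  # single IO thread, as in the PR
    futs = [pool.submit(fake_worker, r) for r in range(3)]
    print(aggregate_futures(futs).result())      # ['out-0', 'out-1', 'out-2']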
@staticmethod
def _ensure_worker_termination(worker_procs: list[BaseProcess]):
"""Ensure that all worker processes are terminated. Assumes workers have

View File

@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import gc
import time
import weakref
@ -1234,8 +1233,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states: torch.Tensor,
num_scheduled_tokens: int,
num_scheduled_tokens_np: np.ndarray,
finished_sending: Optional[set[str]],
finished_recving: Optional[set[str]],
) -> ModelRunnerOutput:
assert self.input_batch.num_reqs ==\
len(self.input_batch.pooling_params), \
@ -1270,8 +1267,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logprobs=None,
prompt_logprobs_dict={},
pooler_output=pooler_output,
finished_sending=finished_sending,
finished_recving=finished_recving,
)
@torch.inference_mode()
@ -1282,11 +1277,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
) -> Union[ModelRunnerOutput, IntermediateTensors]:
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
if has_kv_transfer_group():
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
return self.kv_connector_no_forward(scheduler_output)
# Return empty ModelRunnerOutput if there's no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
# Prepare the decoder inputs.
(attn_metadata, attention_cuda_graphs, logits_indices,
@ -1379,8 +1375,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
)
self.maybe_wait_for_kv_save()
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if self.use_aux_hidden_state_outputs:
hidden_states, aux_hidden_states = model_output
@ -1406,8 +1400,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
else:
if self.input_batch.pooling_params:
return self._pool(hidden_states, num_scheduled_tokens,
num_scheduled_tokens_np, finished_sending,
finished_recving)
num_scheduled_tokens_np)
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
@ -1560,8 +1553,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
finished_sending=finished_sending,
finished_recving=finished_recving,
num_nans_in_logits=num_nans_in_logits,
)
@ -1686,22 +1677,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
def kv_connector_no_forward(
self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
# KV send/recv even if no work to do.
with set_forward_context(None, self.vllm_config):
self.maybe_setup_kv_connector(scheduler_output)
finished_sending, finished_recving = (
self.get_finished_kv_transfers(scheduler_output))
if not finished_sending and not finished_recving:
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending
output.finished_recving = finished_recving
return output
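One detail worth noting in this removed helper: EMPTY_MODEL_RUNNER_OUTPUT is a shared module-level sentinel, so it is copied before finished ids are attached to it, and the same pattern now lives in the worker (next file). A small illustration with a stand-in output type:

import copy
from dataclasses import dataclass
from typing import Optional

@dataclass
class Output:                       # stand-in for ModelRunnerOutput
    finished_sending: Optional[set[str]] = None

EMPTY_OUTPUT = Output()             # shared sentinel returned on idle steps

def attach_finished(output: Output, finished: set[str]) -> Output:
    if output is EMPTY_OUTPUT:
        output = copy.copy(output)  # never mutate the shared sentinel
    output.finished_sending = finished
    return output

out = attach_finished(EMPTY_OUTPUT, {"req-7"})
print(out.finished_sending)           # {'req-7'}
print(EMPTY_OUTPUT.finished_sending)  # None -> sentinel left untouched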
@staticmethod
def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
# Update KVConnector with the KVConnector metadata forward().
@ -1723,15 +1698,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if has_kv_transfer_group():
get_kv_transfer_group().wait_for_save()
@staticmethod
def get_finished_kv_transfers(
scheduler_output: "SchedulerOutput",
) -> tuple[Optional[set[str]], Optional[set[str]]]:
if has_kv_transfer_group():
return get_kv_transfer_group().get_finished(
scheduler_output.finished_req_ids)
return None, None
def propose_ngram_draft_token_ids(
self,
sampled_token_ids: list[list[int]],

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A GPU worker class."""
import copy
import gc
import os
from typing import TYPE_CHECKING, Optional
@ -14,7 +15,9 @@ from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
get_kv_transfer_group,
has_kv_transfer_group)
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -23,7 +26,7 @@ from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
@ -316,14 +319,29 @@ class Worker(WorkerBase):
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
parallel_config = self.vllm_config.parallel_config
if parallel_config.distributed_executor_backend != "external_launcher" \
and not get_pp_group().is_last_rank:
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
return None
output = EMPTY_MODEL_RUNNER_OUTPUT
assert isinstance(output, ModelRunnerOutput)
if has_kv_transfer_group():
finished_sending, finished_recving = (
get_kv_transfer_group().get_finished(
scheduler_output.finished_req_ids))
if finished_sending or finished_recving:
if output is EMPTY_MODEL_RUNNER_OUTPUT:
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.finished_sending = finished_sending
output.finished_recving = finished_recving
# with a connector, the scheduler expects output from all workers
return output
# return output only from the driver worker
return output if self.is_driver_worker else None
def profile(self, is_start: bool = True):
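Taken together, the worker now finishes each step by asking its connector for locally completed transfers, attaching them to the output, and returning the output from every rank whenever a connector is configured. A condensed, illustrative mirror of that flow (the connector and output objects below are SimpleNamespace stand-ins, not vLLM types):

import copy
from types import SimpleNamespace

def finalize_step_output(output, empty_output, connector,
                         finished_req_ids: set[str], is_driver_worker: bool):
    """Illustrative mirror of the worker's end-of-step handling."""
    if connector is not None:
        done_sending, done_recving = connector.get_finished(finished_req_ids)
        if done_sending or done_recving:
            if output is empty_output:
                output = copy.copy(empty_output)   # don't mutate the sentinel
            output.finished_sending = done_sending
            output.finished_recving = done_recving
        # With a connector, the executor aggregates across all workers,
        # so every rank returns its output.
        return output
    # Without a connector, only the driver worker's output is consumed.
    return output if is_driver_worker else None

conn = SimpleNamespace(get_finished=lambda ids: ({"req-1"}, set()))
empty = SimpleNamespace(finished_sending=None, finished_recving=None)
out = finalize_step_output(empty, empty, conn, set(), is_driver_worker=False)
print(out.finished_sending)   # {'req-1'}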