diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
index f80b5eba235d..b5199d85d5ae 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py
@@ -190,7 +190,9 @@ class KVConnectorBase_V1(ABC):
     ) -> tuple[Optional[set[str]], Optional[set[str]]]:
         """
         Notifies worker-side connector ids of requests that have
-        finished generating tokens.
+        finished generating tokens on the worker.
+        The scheduler process (via the MultiprocExecutor) will use this output
+        to track which workers are done.
 
         Returns:
             ids of requests that have finished asynchronous transfer
diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
index d2d3e88eabce..0c5986bfafaa 100644
--- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
+++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -408,14 +408,6 @@ class NixlConnectorWorker:
         # Track the expiration time of requests that are waiting to be sent.
         self._reqs_to_send: dict[ReqId, float] = {}
 
-        # Complete transfer tracker. Used by the rank 0 to track finished
-        # transactions on ranks 1 to N-1.
-        # [req_id -> count]
-        self._done_recving_count: defaultdict[ReqId,
-                                              int] = defaultdict(lambda: 0)
-        self._done_sending_count: defaultdict[ReqId,
-                                              int] = defaultdict(lambda: 0)
-
         # Background thread for handling new handshake requests.
         self._nixl_handshake_listener_t: Optional[threading.Thread] = None
         # Background thread for initializing new NIXL handshakes.
@@ -830,15 +822,9 @@ class NixlConnectorWorker:
 
     def get_finished(self) -> tuple[set[str], set[str]]:
         """
-        Get requests that are done sending or recving.
-
-        In TP>1 setup, each rank exchanges KVs with its counterpart
-        ranks independently. get_finished() runs in a worker creates
-        the done_sending and done_recving sets that are sent to the
-        scheduler via ModelRunnerOutput by Rank 0. To ensure trnxs
-        are done before adding to finished, Ranks 1 to N-1 communicate
-        to Rank 0 once their transaction is done + Rank 0 returns
-        finished sets to Scheduler only once all ranks are done.
+        Get requests that are done sending or recving on this specific worker.
+        The scheduler process (via the MultiprocExecutor) will use this output
+        to track which workers are done.
         """
         done_sending = self._get_new_notifs()
         done_recving = self._pop_done_transfers(self._recving_transfers)
@@ -858,50 +844,7 @@ class NixlConnectorWorker:
             del self._reqs_to_send[req_id]
             done_sending.add(req_id)
 
-        if self.world_size == 1:
-            return done_sending, done_recving
-
-        # Rank 0: get finished from all other ranks.
-        if self.tp_rank == 0:
-            for req_id in done_sending:
-                self._done_sending_count[req_id] += 1
-            for req_id in done_recving:
-                self._done_recving_count[req_id] += 1
-
-            # Keep track of how many other ranks have finished.
-            other_ranks_finished_ids: list[str] = []
-            for i in range(1, self.world_size):
-                other_ranks_finished_ids.extend(
-                    self.tp_group.recv_object(src=i))
-            for req_id in other_ranks_finished_ids:
-                if (req_id in self._done_recving_count
-                        or req_id in self._recving_transfers):
-                    self._done_recving_count[req_id] += 1
-                else:
-                    self._done_sending_count[req_id] += 1
-
-            # Return ids that finished on all ranks to the scheduler.
-            all_done_recving: set[str] = set()
-            for req_id in list(self._done_recving_count.keys()):
-                if self._done_recving_count[req_id] == self.world_size:
-                    del self._done_recving_count[req_id]
-                    all_done_recving.add(req_id)
-
-            all_done_sending: set[str] = set()
-            for req_id in list(self._done_sending_count.keys()):
-                if self._done_sending_count[req_id] >= self.world_size:
-                    del self._done_sending_count[req_id]
-                    all_done_sending.add(req_id)
-
-            return all_done_sending, all_done_recving
-
-        # Ranks 1 to N-1: send finished ids to Rank 0.
-        else:
-            finished_req_ids = list(done_recving.union(done_sending))
-            self.tp_group.send_object(finished_req_ids, dst=0)
-
-            # Unused as only Rank 0 results are sent to scheduler.
-            return done_sending, done_recving
+        return done_sending, done_recving
 
     def _get_new_notifs(self) -> set[str]:
         """
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index b06b7cc804d5..52812c5859fa 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -9,7 +9,8 @@ import threading
 import time
 import traceback
 import weakref
-from concurrent.futures import Future, ThreadPoolExecutor
+from collections import defaultdict
+from concurrent.futures import CancelledError, Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
@@ -111,10 +112,19 @@ class MultiprocExecutor(Executor):
         if self.max_concurrent_batches > 1:
             # Note: must use only 1 IO thread to keep dequeue sequence
             # from the response queue
+            # _async_aggregate_workers_output also assumes a single IO thread
            self.io_thread_pool = ThreadPoolExecutor(
                max_workers=1, thread_name_prefix="mp_exec_io")
 
         self.output_rank = self._get_output_rank()
+        self.has_connector = self.vllm_config.kv_transfer_config is not None
+
+        # Complete transfer tracker. Used to track finished requests
+        # [req_id -> n_remaining_workers]
+        self._recv_remaining_count = defaultdict[str,
+                                                 int](lambda: self.world_size)
+        self._send_remaining_count = defaultdict[str,
+                                                 int](lambda: self.world_size)
 
     def start_worker_monitor(self):
         workers = self.workers
@@ -155,13 +165,29 @@ class MultiprocExecutor(Executor):
         self,
         scheduler_output,
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
-        (output, ) = self.collective_rpc(
+        non_block = self.max_concurrent_batches > 1
+
+        if not self.has_connector:
+            # get output only from a single worker (output_rank)
+            (output, ) = self.collective_rpc(
+                "execute_model",
+                args=(scheduler_output, ),
+                unique_reply_rank=self.output_rank,
+                non_block=non_block,
+                timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
+            return output
+
+        # get output from all workers
+        outputs = self.collective_rpc(
             "execute_model",
             args=(scheduler_output, ),
-            unique_reply_rank=self.output_rank,
-            non_block=self.max_concurrent_batches > 1,
+            non_block=non_block,
             timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS)
-        return output
+
+        # aggregate all workers' output into a single output
+        if non_block:
+            return self._async_aggregate_workers_output(outputs)
+        return self._aggregate_workers_output(outputs)
 
     def collective_rpc(self,
                        method: Union[str, Callable],
@@ -220,6 +246,80 @@ class MultiprocExecutor(Executor):
         except TimeoutError as e:
             raise TimeoutError(f"RPC call to {method} timed out.") from e
 
+    def _aggregate_workers_output(
+            self, outputs: list[ModelRunnerOutput]) -> ModelRunnerOutput:
+        # aggregate finished_sending, finished_recving from all workers
+
+        finished_sending = set[str]()
+        finished_recving = set[str]()
+        for output in outputs:
+            # update finished_sending
+            for req_id in output.finished_sending or []:
+                new_count = self._send_remaining_count[req_id] - 1
+                if new_count == 0:
+                    # got response from all workers, report back to scheduler
+                    finished_sending.add(req_id)
+                    del self._send_remaining_count[req_id]
+                else:
+                    self._send_remaining_count[req_id] = new_count
+
+            # update finished_recving
+            for req_id in output.finished_recving or []:
+                new_count = self._recv_remaining_count[req_id] - 1
+                if new_count == 0:
+                    # got response from all workers, report back to scheduler
+                    finished_recving.add(req_id)
+                    del self._recv_remaining_count[req_id]
+                else:
+                    self._recv_remaining_count[req_id] = new_count
+
+        # select output of the worker specified by output_rank
+        output = outputs[self.output_rank]
+
+        # set the aggregated finished_sending / finished_recving
+        if finished_sending:
+            output.finished_sending = finished_sending
+        if finished_recving:
+            output.finished_recving = finished_recving
+
+        return output
+
+    def _async_aggregate_workers_output(
+            self, output_futures: list[Future[ModelRunnerOutput]]
+    ) -> (Future[ModelRunnerOutput]):
+        """Takes a list of futures and returns a single future which resolves
+        to the aggregated output once all of the workers' futures complete."""
+        result_future: Future[ModelRunnerOutput] = Future()
+
+        outputs: list[Optional[ModelRunnerOutput]] = [None
+                                                      ] * len(output_futures)
+
+        def make_callback(idx):
+
+            def callback(fut):
+                if result_future.done():
+                    return
+
+                try:
+                    outputs[idx] = fut.result()
+                except CancelledError:
+                    result_future.cancel()
+                except Exception as e:
+                    result_future.set_exception(e)
+
+                # this check assumes io_thread_pool uses a single thread
+                if all(outputs):
+                    result_future.set_result(
+                        self._aggregate_workers_output(
+                            cast(list[ModelRunnerOutput], outputs)))
+
+            return callback
+
+        for i, output_future in enumerate(output_futures):
+            output_future.add_done_callback(make_callback(i))
+
+        return result_future
+
     @staticmethod
     def _ensure_worker_termination(worker_procs: list[BaseProcess]):
         """Ensure that all worker processes are terminated. Assumes workers have
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index ef03626cf14d..9cda4dbb9615 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1,7 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import copy
 import gc
 import time
 import weakref
@@ -1234,8 +1233,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         hidden_states: torch.Tensor,
         num_scheduled_tokens: int,
         num_scheduled_tokens_np: np.ndarray,
-        finished_sending: Optional[set[str]],
-        finished_recving: Optional[set[str]],
     ) -> ModelRunnerOutput:
         assert self.input_batch.num_reqs ==\
             len(self.input_batch.pooling_params), \
@@ -1270,8 +1267,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
         )
 
     @torch.inference_mode()
@@ -1282,11 +1277,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
     ) -> Union[ModelRunnerOutput, IntermediateTensors]:
         self._update_states(scheduler_output)
         if not scheduler_output.total_num_scheduled_tokens:
-            if not has_kv_transfer_group():
-                # Return empty ModelRunnerOutput if there's no work to do.
-                return EMPTY_MODEL_RUNNER_OUTPUT
+            if has_kv_transfer_group():
+                with set_forward_context(None, self.vllm_config):
+                    self.maybe_setup_kv_connector(scheduler_output)
 
-            return self.kv_connector_no_forward(scheduler_output)
+            # Return empty ModelRunnerOutput if there's no work to do.
+            return EMPTY_MODEL_RUNNER_OUTPUT
 
         # Prepare the decoder inputs.
         (attn_metadata, attention_cuda_graphs, logits_indices,
@@ -1379,8 +1375,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             )
 
             self.maybe_wait_for_kv_save()
-            finished_sending, finished_recving = (
-                self.get_finished_kv_transfers(scheduler_output))
 
         if self.use_aux_hidden_state_outputs:
             hidden_states, aux_hidden_states = model_output
@@ -1406,8 +1400,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         else:
             if self.input_batch.pooling_params:
                 return self._pool(hidden_states, num_scheduled_tokens,
-                                  num_scheduled_tokens_np, finished_sending,
-                                  finished_recving)
+                                  num_scheduled_tokens_np)
 
             sample_hidden_states = hidden_states[logits_indices]
             logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1560,8 +1553,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
-            finished_sending=finished_sending,
-            finished_recving=finished_recving,
             num_nans_in_logits=num_nans_in_logits,
         )
 
@@ -1686,22 +1677,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             spec_token_ids = draft_token_ids.tolist()
         return spec_token_ids
 
-    def kv_connector_no_forward(
-            self, scheduler_output: "SchedulerOutput") -> ModelRunnerOutput:
-        # KV send/recv even if no work to do.
-        with set_forward_context(None, self.vllm_config):
-            self.maybe_setup_kv_connector(scheduler_output)
-            finished_sending, finished_recving = (
-                self.get_finished_kv_transfers(scheduler_output))
-
-        if not finished_sending and not finished_recving:
-            return EMPTY_MODEL_RUNNER_OUTPUT
-
-        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-        output.finished_sending = finished_sending
-        output.finished_recving = finished_recving
-        return output
-
     @staticmethod
     def maybe_setup_kv_connector(scheduler_output: "SchedulerOutput"):
         # Update KVConnector with the KVConnector metadata forward().
@@ -1723,15 +1698,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         if has_kv_transfer_group():
             get_kv_transfer_group().wait_for_save()
 
-    @staticmethod
-    def get_finished_kv_transfers(
-        scheduler_output: "SchedulerOutput",
-    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
-        if has_kv_transfer_group():
-            return get_kv_transfer_group().get_finished(
-                scheduler_output.finished_req_ids)
-        return None, None
-
     def propose_ngram_draft_token_ids(
         self,
         sampled_token_ids: list[list[int]],
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 38c9545e3747..6b30acee1d90 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """A GPU worker class."""
+import copy
 import gc
 import os
 from typing import TYPE_CHECKING, Optional
@@ -14,7 +15,9 @@ from vllm.config import VllmConfig
 from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
-from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
+from vllm.distributed.kv_transfer import (ensure_kv_transfer_initialized,
+                                          get_kv_transfer_group,
+                                          has_kv_transfer_group)
 from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
@@ -23,7 +26,7 @@ from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
 from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
 from vllm.v1.utils import report_usage_stats
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 from vllm.v1.worker.worker_base import WorkerBase
@@ -316,14 +319,29 @@ class Worker(WorkerBase):
 
         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)
+
         parallel_config = self.vllm_config.parallel_config
         if parallel_config.distributed_executor_backend != "external_launcher" \
             and not get_pp_group().is_last_rank:
             assert isinstance(output, IntermediateTensors)
             get_pp_group().send_tensor_dict(output.tensors,
                                             all_gather_group=get_tp_group())
-            return None
+            output = EMPTY_MODEL_RUNNER_OUTPUT
+
         assert isinstance(output, ModelRunnerOutput)
+        if has_kv_transfer_group():
+            finished_sending, finished_recving = (
+                get_kv_transfer_group().get_finished(
+                    scheduler_output.finished_req_ids))
+            if finished_sending or finished_recving:
+                if output is EMPTY_MODEL_RUNNER_OUTPUT:
+                    output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+                output.finished_sending = finished_sending
+                output.finished_recving = finished_recving
+            # with a connector, the scheduler expects output from all workers
+            return output
+
+        # return output only from the driver worker
         return output if self.is_driver_worker else None
 
     def profile(self, is_start: bool = True):
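
For reference, below is a minimal standalone sketch (not part of the diff) of the countdown aggregation that the new MultiprocExecutor._aggregate_workers_output performs: each request id starts with world_size outstanding reports and is surfaced to the scheduler only after every worker has reported it, even when those reports arrive in different execute_model calls. The WorkerOutput dataclass and the request ids below are hypothetical stand-ins for vLLM's ModelRunnerOutput.

from collections import defaultdict
from dataclasses import dataclass
from typing import Optional


@dataclass
class WorkerOutput:
    # Hypothetical stand-in for ModelRunnerOutput's connector fields.
    finished_sending: Optional[set[str]] = None
    finished_recving: Optional[set[str]] = None


class FinishedRequestAggregator:
    """Report a request as finished only once every worker has reported it."""

    def __init__(self, world_size: int):
        # req_id -> number of workers that still have to report it
        self._send_remaining = defaultdict(lambda: world_size)
        self._recv_remaining = defaultdict(lambda: world_size)

    def aggregate(self, outputs: list[WorkerOutput]) -> WorkerOutput:
        finished_sending: set[str] = set()
        finished_recving: set[str] = set()
        for out in outputs:
            for req_id in out.finished_sending or ():
                self._send_remaining[req_id] -= 1
                if self._send_remaining[req_id] == 0:
                    del self._send_remaining[req_id]
                    finished_sending.add(req_id)
            for req_id in out.finished_recving or ():
                self._recv_remaining[req_id] -= 1
                if self._recv_remaining[req_id] == 0:
                    del self._recv_remaining[req_id]
                    finished_recving.add(req_id)
        return WorkerOutput(finished_sending or None, finished_recving or None)


if __name__ == "__main__":
    agg = FinishedRequestAggregator(world_size=2)
    # Step 1: only worker 0 reports "req-1"; nothing is surfaced yet.
    step1 = agg.aggregate(
        [WorkerOutput(finished_sending={"req-1"}), WorkerOutput()])
    assert step1.finished_sending is None
    # Step 2: worker 1 catches up; "req-1" is now reported exactly once.
    step2 = agg.aggregate(
        [WorkerOutput(), WorkerOutput(finished_sending={"req-1"})])
    assert step2.finished_sending == {"req-1"}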