From 161010c3847204d441bcc6ec91709d324071e954 Mon Sep 17 00:00:00 2001 From: Tyler Michael Smith Date: Fri, 18 Apr 2025 15:38:00 -0400 Subject: [PATCH] Initial stubs for P/D scheduling changes Signed-off-by: Tyler Michael Smith --- .../kv_transfer/kv_connector/v1/base.py | 6 +++- .../v1/shared_storage_connector.py | 8 +++-- vllm/v1/core/sched/scheduler.py | 31 ++++++++++++++++++- vllm/v1/request.py | 3 ++ vllm/v1/worker/gpu_model_runner.py | 4 +++ 5 files changed, 47 insertions(+), 5 deletions(-) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 95967d2ca9193..a335f43d3ad3a 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -196,7 +196,9 @@ class KVConnectorBase_V1(ABC): @abstractmethod def build_connector_meta( - self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + self, scheduler_output: SchedulerOutput, + sending_KV_req_ids: set[str], + waiting_KV_req_ids: set[str]) -> KVConnectorMetadata: """ Build the connector metadata for this step. @@ -205,5 +207,7 @@ class KVConnectorBase_V1(ABC): Args: scheduler_output (SchedulerOutput): the scheduler output object. + sending_KV_req_ids (set[str]): Request IDs to send + waiting_KV_req_ids (set[str]): Request IDs to receive """ pass diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py index 1d2040784e6cb..fb1f1e24da0a0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py @@ -271,9 +271,9 @@ class SharedStorageConnector(KVConnectorBase_V1): self._requests_need_load[request.request_id] = request def build_connector_meta( - self, - scheduler_output: SchedulerOutput, - ) -> KVConnectorMetadata: + self, scheduler_output: SchedulerOutput, + sending_KV_req_ids: set[str], + waiting_KV_req_ids: set[str]) -> KVConnectorMetadata: """Build the connector metadata for this step. This function should NOT modify any fields in the scheduler_output. @@ -281,6 +281,8 @@ class SharedStorageConnector(KVConnectorBase_V1): Args: scheduler_output (SchedulerOutput): the scheduler output object. + sending_KV_req_ids (set[str]): Request IDs to send + waiting_KV_req_ids (set[str]): Request IDs to receive """ meta = SharedStorageConnectorMetadata() diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 7e658d134cf77..75f42449c2673 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -98,6 +98,10 @@ class Scheduler(SchedulerInterface): # This is flushed at the end of each scheduling step. self.finished_req_ids: set[str] = set() + # Requests in states for tracking KV transfers for P/D disagg + self.sending_KV_req_ids: set[str] = set() + self.waiting_KV_req_ids: set[str] = set() + # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating # them at each scheduling step. # Request id -> CachedRequestData @@ -167,6 +171,21 @@ class Scheduler(SchedulerInterface): # For logging. scheduled_timestamp = time.monotonic() + # Check for new remote decode requests for P/D + if self.connector is not None: + self.waiting_KV_req_ids.update( + self.connector.receive_remote_decode_requests()) + + # Check if any P/D requests have finished sending or receiving + for req_id in list(self.sending_KV_req_ids): + if self.connector.done_sending_remote_decode_request(req_id): + self.sending_KV_req_ids.remove(req_id) + self.finished_req_ids.add(req_id) + for req_id in list(self.waiting_KV_req_ids): + if self.connector.done_waiting_remote_decode_request(req_id): + self.waiting_KV_req_ids.remove(req_id) + self.waiting.append(self.requests[req_id]) + # First, schedule the RUNNING requests. req_index = 0 while req_index < len(self.running) and token_budget > 0: @@ -479,7 +498,9 @@ class Scheduler(SchedulerInterface): # 2. Wrap up all the KV cache load / save ops into an opaque object # 3. Clear the internal states of the connector if self.connector is not None: - meta = self.connector.build_connector_meta(scheduler_output) + meta = self.connector.build_connector_meta(scheduler_output, + self.sending_KV_req_ids, + self.waiting_KV_req_ids) scheduler_output.kv_connector_metadata = meta # Advance the number of computed tokens for the request AFTER @@ -682,6 +703,7 @@ class Scheduler(SchedulerInterface): # Check for stop and update request state. # This must be called before we make the EngineCoreOutput. + # TODO: What if we detect we're done here when doing P/D disagg? stopped = check_stop(request, self.max_model_len) if stopped: self._free_request(request) @@ -718,6 +740,13 @@ class Scheduler(SchedulerInterface): # Invariant: EngineCore returns no partial prefill outputs. assert not prompt_logprobs_tensors + if self.connector is not None and request.do_remote_decode: + stopped = True + + self.sending_KV_req_ids.add(req_id) + self.connector.send_remote_decode_request( + self.kv_cache_manager.req_to_blocks[req_id]) + self.scheduled_req_ids.remove(req_id) if not stopped: new_running.append(request) diff --git a/vllm/v1/request.py b/vllm/v1/request.py index 6be72431dde52..7c7803560bc8e 100644 --- a/vllm/v1/request.py +++ b/vllm/v1/request.py @@ -61,6 +61,9 @@ class Request: self.num_encoder_inputs = len(self.mm_inputs) self.has_encoder_inputs = self.num_encoder_inputs > 0 + # P/D disagg related + self.do_remote_decode = False + # Sanity check assert len(self.mm_inputs) == len(self.mm_positions) if self.mm_hashes: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index ac0701c459860..00026731d5164 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -991,6 +991,10 @@ class GPUModelRunner(LoRAModelRunnerMixin): ) -> Union[ModelRunnerOutput, torch.Tensor]: # Update KVConnector with the KVConnector metadata forward(). if has_kv_transfer_group(): + # Background KV cache transfers can happen here, + # since kv_connector_metadata has the req_ids to send/receive. + # Not sure I like doing it here since this does not have to do + # with model execution but this way we don't do a separate rpc. get_kv_transfer_group().bind_connector_metadata( scheduler_output.kv_connector_metadata)