From 7ad7adb67f1350b6e9f7cfdd7aacf38eed093bb1 Mon Sep 17 00:00:00 2001 From: Or Ozeri Date: Sat, 9 Aug 2025 09:09:51 +0300 Subject: [PATCH] v1: Pass KVConnectorOutput to scheduler-side (#22157) Signed-off-by: Or Ozeri --- .../distributed/kv_transfer/kv_connector/v1/base.py | 13 +++++++++++++ .../kv_transfer/kv_connector/v1/multi_connector.py | 5 +++++ vllm/v1/core/sched/scheduler.py | 4 ++++ 3 files changed, 22 insertions(+) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 7a2ccb58656f..b72104397822 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -12,6 +12,8 @@ The class provides the following primitives: times for a given request and should be side-effect free. update_state_after_alloc() - update KVConnector state after temporary buffer alloc by the CacheManager. + update_connector_output() - update KVConnector state after + output is received from worker-side connectors. request_finished() - called when a request is finished, with the computed kv cache blocks for the request. Returns whether KV cache should be freed now or will be @@ -38,6 +40,7 @@ import torch from vllm.logger import init_logger from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.outputs import KVConnectorOutput if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -283,6 +286,16 @@ class KVConnectorBase_V1(ABC): """ pass + def update_connector_output(self, connector_output: KVConnectorOutput): + """ + Update KVConnector state from worker-side connectors output. + + Args: + connector_output (KVConnectorOutput): the worker-side + connectors output. + """ + return + def request_finished( self, request: "Request", diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 62a4980bff97..7d67c76e2f05 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -14,6 +14,7 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( from vllm.logger import init_logger from vllm.v1.core.kv_cache_manager import KVCacheBlocks from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.outputs import KVConnectorOutput if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -177,6 +178,10 @@ class MultiConnector(KVConnectorBase_V1): self._extra_async_saves = {} return metadata + def update_connector_output(self, connector_output: KVConnectorOutput): + for c in self._connectors: + c.update_connector_output(connector_output) + def request_finished( self, request: "Request", diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 430085d9c978..85fc1a4a016a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -1150,6 +1150,10 @@ class Scheduler(SchedulerInterface): # if finished_recving: add to state so we can scheduler the request during the next step. """ + + assert self.connector is not None + self.connector.update_connector_output(kv_connector_output) + # KV Connector:: update recv and send status from last step. for req_id in (kv_connector_output.finished_recving or ()): logger.debug("Finished recving KV transfer for request %s", req_id)