diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/base.py b/vllm/distributed/kv_transfer/kv_connector/v1/base.py index 8f9d70eec038b..f80b5eba235dd 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/base.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/base.py @@ -8,9 +8,15 @@ The class provides the following primitives: Scheduler-side: runs in the scheduler, binds metadata, which is used by the worker-side to load/save KV cache. get_num_new_matched_tokens() - get number of new tokens - that exist in the remote KV cache + that exist in the remote KV cache. Might be called multiple + times for a given request and should be side-effect free. update_state_after_alloc() - update KVConnector state after temporary buffer alloc by the CacheManager. + request_finished() - called when a request is finished, with + the computed kv cache blocks for the request. + Returns whether KV cache should be freed now or will be + freed asynchronously and optionally returns KV transfer + params. Worker-side: runs in each worker, loads/saves KV cache to/from the Connector based on the metadata. @@ -19,6 +25,9 @@ The class provides the following primitives: save_kv_layer() - starts saving KV for layer i (maybe async) wait_for_save() - blocks until all saves are done + + get_finished() - called with ids of finished requests, returns + ids of requests that have completed async sending/recving. """ import enum @@ -184,7 +193,8 @@ class KVConnectorBase_V1(ABC): finished generating tokens. Returns: - ids of requests that have finished asynchronous transfer, + ids of requests that have finished asynchronous transfer + (requests that previously returned True from request_finished()), tuple of (sending/saving ids, recving/loading ids). The finished saves/sends req ids must belong to a set provided in a call to this method (this call or a prior one). @@ -215,7 +225,8 @@ class KVConnectorBase_V1(ABC): - The number of tokens that can be loaded from the external KV cache beyond what is already computed. - `True` if external KV cache tokens will be loaded - asynchronously (between scheduler steps). + asynchronously (between scheduler steps). Must be + 'False' if the first element is 0. """ pass @@ -225,6 +236,18 @@ class KVConnectorBase_V1(ABC): num_external_tokens: int): """ Update KVConnector state after block allocation. + + If get_num_new_matched_tokens previously returned True for a + request, this function may be called twice for that same request - + first when blocks are allocated for the connector tokens to be + asynchronously loaded into, and second when any additional blocks + are allocated, after the load/transfer is complete. + + Args: + request (Request): the request object. + blocks (KVCacheBlocks): the blocks allocated for the request. + num_external_tokens (int): the number of tokens that will be + loaded from the external KV cache. """ pass diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index cb9fa61dbaba4..f3b5c74829a9a 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -101,7 +101,7 @@ class Scheduler(SchedulerInterface): # This is flushed at the end of each scheduling step. self.finished_req_ids: set[str] = set() - # P/D: requests in process of recving KV transfers + # KV Connector: requests in process of async KV loading or recving self.finished_recving_kv_req_ids: set[str] = set() # OPTIMIZATION: Cache the CachedRequestData objects to avoid creating @@ -822,7 +822,7 @@ class Scheduler(SchedulerInterface): if not stopped: new_running.append(request) - # P/D: update state for finished KV Transfers. + # KV Connector: update state for finished KV Transfers. self._update_from_kv_xfer_finished(model_runner_output) # Return the cached request data to the queue so they can be reused. @@ -969,7 +969,7 @@ class Scheduler(SchedulerInterface): self.kv_event_publisher.shutdown() ######################################################################## - # P/D Related Methods + # KV Connector Related Methods ######################################################################## def get_kv_connector(self) -> Optional[KVConnectorBase_V1]: @@ -992,7 +992,7 @@ class Scheduler(SchedulerInterface): def _update_waiting_for_remote_kv(self, request: Request) -> bool: """ - P/D: check if the request_id is finished_recving. + KV Connector: check if the request_id is finished_recving. The finished_recving_kv_req_ids list is populated on the previous steps()'s update_from_output based @@ -1029,7 +1029,7 @@ class Scheduler(SchedulerInterface): def _update_from_kv_xfer_finished(self, model_runner_output: ModelRunnerOutput): """ - P/D: update the scheduler state based on the output. + KV Connector: update the scheduler state based on the output. The Worker side connectors add finished_recving and finished_sending reqs to the output. @@ -1037,7 +1037,7 @@ class Scheduler(SchedulerInterface): # if finished_recving: add to state so we can scheduler the request during the next step. """ - # P/D: update recv and send status from last step. + # KV Connector:: update recv and send status from last step. for req_id in (model_runner_output.finished_recving or ()): logger.debug("Finished recving KV transfer for request %s", req_id) self.finished_recving_kv_req_ids.add(req_id)