diff --git a/vllm/v1/core/kv_cache_coordinator.py b/vllm/v1/core/kv_cache_coordinator.py index 993ce4b484f9..231bad1df922 100644 --- a/vllm/v1/core/kv_cache_coordinator.py +++ b/vllm/v1/core/kv_cache_coordinator.py @@ -164,7 +164,7 @@ class KVCacheCoordinator(ABC): Get the blocks for the request. """ return [ - manager.req_to_blocks[request_id] + manager.req_to_blocks.get(request_id) or [] for manager in self.single_type_managers ] diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index f3b5c74829a9..b3293d9a541f 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -76,6 +76,9 @@ class Scheduler(SchedulerInterface): # KV Connector pushes/pull of remote KVs for P/D and offloading. self.connector = None if self.vllm_config.kv_transfer_config is not None: + assert len(self.kv_cache_config.kv_cache_groups) == 1, ( + "Multiple KV cache groups are not currently supported " + "with KV connectors") self.connector = KVConnectorFactory.create_connector_v1( config=self.vllm_config, role=KVConnectorRole.SCHEDULER) @@ -985,9 +988,8 @@ class Scheduler(SchedulerInterface): """ if self.connector is None: return False, None - assert len(self.kv_cache_config.kv_cache_groups - ) == 1, "KV connector only supports one KV cache group now" - block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0] + + (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) return self.connector.request_finished(request, block_ids) def _update_waiting_for_remote_kv(self, request: Request) -> bool: @@ -1002,12 +1004,12 @@ class Scheduler(SchedulerInterface): and the request state will be moved back to WAITING from WAITING_FOR_REMOTE_KV. """ + assert self.connector is not None if request.request_id not in self.finished_recving_kv_req_ids: return False - assert len(self.kv_cache_config.kv_cache_groups - ) == 1, "KV connector only supports one KV cache group now" + # Now that the blocks are ready, actually cache them. - block_ids = self.kv_cache_manager.get_block_ids(request.request_id)[0] + (block_ids, ) = self.kv_cache_manager.get_block_ids(request.request_id) num_computed_tokens = len(block_ids) * self.block_size # Handle the case where num request tokens less then one block. num_computed_tokens = min(num_computed_tokens, request.num_tokens)