From 2b10ba749177513e6423ff26bbb6d45fe17ee62b Mon Sep 17 00:00:00 2001
From: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Date: Fri, 23 May 2025 19:30:16 -0400
Subject: [PATCH] [Bugfix][Nixl] Fix Preemption Bug (#18631)

Signed-off-by: rshaw@neuralmagic.com
---
 .../unit/test_remote_prefill_lifecycle.py | 83 +++++++++++++++++++
 vllm/v1/core/sched/scheduler.py           | 31 +++++-----
 2 files changed, 99 insertions(+), 15 deletions(-)

diff --git a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
index fc4928f9ebd19..6fcff0d620452 100644
--- a/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
+++ b/tests/v1/kv_connector/unit/test_remote_prefill_lifecycle.py
@@ -340,3 +340,86 @@ def test_full_block_prompt():
     output = outputs[0]
     assert output.finish_reason == FinishReason.STOP
     assert_scheduler_empty(scheduler)
+
+
+def test_cannot_schedule_after_recv():
+    """
+    Test that the scheduler handles a request that cannot be scheduled
+    right after its remote KV recv finishes (not enough free KV blocks).
+    """
+
+    # NOTE: the KVCacheManager will use 1 null block.
+    # So there are 5 total working blocks.
+    TOTAL_NUM_BLOCKS = 6
+    vllm_config = create_vllm_config()
+    scheduler = create_scheduler(vllm_config, num_blocks=TOTAL_NUM_BLOCKS)
+
+    # Prime the KVCache.
+    NUM_PROMPT_BLOCKS = 2
+    BLOCK_SIZE = vllm_config.cache_config.block_size
+    # Prompt will use 2 blocks + 1 block after we schedule.
+    NUM_TOKENS_LOCAL = int(BLOCK_SIZE * NUM_PROMPT_BLOCKS)
+    NUM_TOKENS_REMOTE = int(BLOCK_SIZE * (NUM_PROMPT_BLOCKS + 0.5))
+
+    request_normal = create_request(request_id=1, num_tokens=NUM_TOKENS_LOCAL)
+    request_remote = create_request(request_id=2,
+                                    num_tokens=NUM_TOKENS_REMOTE,
+                                    do_remote_prefill=True)
+
+    # Step 1: 3 blocks are in use (2 for prompt, 1 for decode).
+    scheduler.add_request(request_normal)
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_normal])
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 1
+    assert len(scheduler.waiting) == 0
+
+    # Step 2: 5 blocks are in use (2 new blocks for the remote prefill).
+    scheduler.add_request(request_remote)
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_normal])
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 1
+    assert len(scheduler.waiting) == 1
+
+    # Step 3: finish recving (5 blocks still in use).
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(
+        reqs=[request_normal], finished_recving=[request_remote.request_id])
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 1
+    assert len(scheduler.waiting) == 1
+
+    # Step 4: try to schedule; not enough free blocks.
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_normal])
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 1
+    assert len(scheduler.waiting) == 1
+
+    # Step 5: finish the running request, freeing its blocks.
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_normal],
+                                                     use_eos=True)
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 0
+    assert len(scheduler.waiting) == 1
+
+    # Step 6: now we can schedule (with 2 blocks computed).
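+    # (request_normal was freed in Step 5, so blocks are available again,
+    # and request_remote still holds its 2 remotely prefilled blocks.)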
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_remote])
+    assert (scheduler_output.scheduled_new_reqs[0].num_computed_tokens ==
+            NUM_PROMPT_BLOCKS * BLOCK_SIZE)
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    assert len(scheduler.running) == 1
+    assert len(scheduler.waiting) == 0
+
+    # Step 7: free everything.
+    scheduler_output = scheduler.schedule()
+    model_runner_output = create_model_runner_output(reqs=[request_remote],
+                                                     use_eos=True)
+    scheduler.update_from_output(scheduler_output, model_runner_output)
+    _ = scheduler.schedule()
+    assert_scheduler_empty(scheduler)
diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py
index 1f54560a10a77..efc0de350fba7 100644
--- a/vllm/v1/core/sched/scheduler.py
+++ b/vllm/v1/core/sched/scheduler.py
@@ -310,15 +310,16 @@ class Scheduler(SchedulerInterface):
                 break
 
             request = self.waiting[0]
-            num_prealloc_computed_tokens = 0
-            # P/D: skip request if still waiting for remote kvs.
+
+            # KVTransfer: skip request if still waiting for remote kvs.
             if request.status == RequestStatus.WAITING_FOR_REMOTE_KVS:
                 is_ready = self._update_waiting_for_remote_kv(request)
                 if is_ready:
                     request.status = RequestStatus.WAITING
-                    num_prealloc_computed_tokens = (
-                        request.num_computed_tokens)
                 else:
+                    logger.debug(
+                        "%s is still in WAITING_FOR_REMOTE_KVS state.",
+                        request.request_id)
                     self.waiting.popleft()
                     skipped_waiting_requests.appendleft(request)
                     continue
@@ -349,8 +350,9 @@ class Scheduler(SchedulerInterface):
                 load_kv_async = False
 
             # Get already-cached tokens.
-            if num_prealloc_computed_tokens == 0:
-                new_computed_blocks, num_native_computed_tokens = \
+            if request.num_computed_tokens == 0:
+                # Get locally-cached tokens.
+                new_computed_blocks, num_new_local_computed_tokens = \
                     self.kv_cache_manager.get_computed_blocks(
                         request)
 
@@ -358,23 +360,22 @@ class Scheduler(SchedulerInterface):
                 if self.connector is not None:
                     num_external_computed_tokens, load_kv_async = (
                         self.connector.get_num_new_matched_tokens(
-                            request, num_native_computed_tokens))
+                            request, num_new_local_computed_tokens))
 
                 # Total computed tokens (local + external).
-                num_computed_tokens = (num_native_computed_tokens +
+                num_computed_tokens = (num_new_local_computed_tokens +
                                        num_external_computed_tokens)
+            # KVTransfer: WAITING reqs have num_computed_tokens > 0
+            # after async KV recvs are completed.
             else:
-                # P/D: skip checking prefix cache if loaded from remote kvs.
                 new_computed_blocks = KVCacheBlocks.create_empty()
-                num_native_computed_tokens = 0
-
-                # Total computed tokens (allocated in prior step).
-                num_computed_tokens = num_prealloc_computed_tokens
+                num_new_local_computed_tokens = 0
+                num_computed_tokens = request.num_computed_tokens
 
             encoder_inputs_to_schedule = None
             new_encoder_budget = encoder_budget
 
-            # P/D: loading remote KV, do not allocate for new work.
+            # KVTransfer: loading remote KV, do not allocate for new work.
             if load_kv_async:
                 assert num_external_computed_tokens > 0
                 num_new_tokens = 0
@@ -405,7 +406,7 @@ class Scheduler(SchedulerInterface):
                 new_blocks = self.kv_cache_manager.allocate_slots(
                     request,
                     num_new_tokens + num_external_computed_tokens,
-                    num_native_computed_tokens,
+                    num_new_local_computed_tokens,
                     new_computed_blocks,
                     num_lookahead_tokens=self.num_lookahead_tokens,
                     delay_cache_blocks=load_kv_async,
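
Why the fix works: before this patch, the remotely prefilled token count was
held in a per-pass local (num_prealloc_computed_tokens), set only on the pass
where the remote recv completed. If block allocation failed on that same pass
(Step 4 in the test above), the next pass saw the local reset to 0, took the
fresh-request path through the prefix cache, and lost the remote-prefill
accounting. Branching on request.num_computed_tokens instead works because it
is stored on the request and survives a failed pass. The sketch below is a
minimal, self-contained illustration of that failure mode and the fix, not
vLLM code; every name other than num_computed_tokens is a hypothetical
simplification.

    from dataclasses import dataclass


    @dataclass
    class Request:
        request_id: int
        # Set when the async remote KV transfer finishes; persists across
        # scheduling passes (the property the patch relies on).
        num_computed_tokens: int = 0


    def try_schedule(req: Request, free_blocks: int, blocks_needed: int) -> bool:
        """One simplified scheduling pass; returns True if req is scheduled."""
        if req.num_computed_tokens == 0:
            # Fresh request: a real scheduler would consult the prefix
            # cache here for locally computed tokens.
            num_computed_tokens = 0
        else:
            # Remote KVs already arrived (possibly on an earlier, failed
            # pass): skip the prefix cache and reuse the persisted counter.
            num_computed_tokens = req.num_computed_tokens
        if free_blocks < blocks_needed:
            # Allocation failed: the request stays WAITING, and its
            # num_computed_tokens is untouched for the next pass.
            return False
        print(f"scheduled request {req.request_id} with "
              f"{num_computed_tokens} precomputed tokens")
        return True


    req = Request(request_id=2)
    req.num_computed_tokens = 32  # recv finished: 2 blocks x 16 tokens

    assert not try_schedule(req, free_blocks=0, blocks_needed=1)  # like Step 4
    assert try_schedule(req, free_blocks=3, blocks_needed=1)      # like Step 6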