From 949cb0170d157fb4b1df8c5c82fea6b1362f0308 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Fri, 10 Oct 2025 13:29:57 -0700
Subject: [PATCH] [BugFix] Fix async scheduling + request preemption (#26385)

Signed-off-by: Nick Hill
---
 tests/v1/e2e/test_async_sched_and_preempt.py | 96 ++++++++++++++++++++
 vllm/v1/worker/gpu_model_runner.py           | 11 ++-
 2 files changed, 104 insertions(+), 3 deletions(-)
 create mode 100644 tests/v1/e2e/test_async_sched_and_preempt.py

diff --git a/tests/v1/e2e/test_async_sched_and_preempt.py b/tests/v1/e2e/test_async_sched_and_preempt.py
new file mode 100644
index 0000000000000..54fa9aca381bb
--- /dev/null
+++ b/tests/v1/e2e/test_async_sched_and_preempt.py
@@ -0,0 +1,96 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import Any
+
+import pytest
+
+from vllm import SamplingParams
+
+from ...conftest import VllmRunner
+from ...models.utils import check_outputs_equal
+
+MODEL = "Qwen/Qwen3-0.6B"
+
+
+def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
+    """Test consistency of combos of async scheduling, preemption,
+    uni/multiproc executor, and various sampling parameters."""
+
+    first_prompt = (
+        "The following numbers of the sequence "
+        + ", ".join(str(i) for i in range(10))
+        + " are:"
+    )
+    example_prompts = [first_prompt, "In one word, the capital of France is "] + [
+        f"Tell me about the number {i}: " for i in range(32)
+    ]
+
+    sampling_param_tests: list[dict[str, Any]] = [
+        dict(),
+        # dict(min_tokens=20),
+        # TODO enable these with https://github.com/vllm-project/vllm/pull/26467.
+        # dict(repetition_penalty=0.1),
+        # dict(bad_words=[]),
+    ]
+
+    default_params = dict(
+        temperature=0.0,  # greedy
+        max_tokens=20,
+    )
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
+        # m.setenv("VLLM_KERNEL_OVERRIDE_BATCH_INVARIANT", "1")
+
+        outputs = []
+        for test_preemption in [False, True]:
+            for executor in ["uni", "mp"]:
+                for async_scheduling in [False, True]:
+                    cache_arg: dict[str, Any] = (
+                        dict(num_gpu_blocks_override=32)
+                        if test_preemption
+                        else dict(gpu_memory_utilization=0.7)
+                    )
+                    test_config = (
+                        f"executor={executor}, preemption={test_preemption},"
+                        f" async_sched={async_scheduling}"
+                    )
+                    print("-" * 80)
+                    print(f"---- TESTING: {test_config}")
+                    print("-" * 80)
+                    with VllmRunner(
+                        MODEL,
+                        max_model_len=512,
+                        enforce_eager=True,
+                        async_scheduling=async_scheduling,
+                        distributed_executor_backend=executor,
+                        dtype="float32",  # avoid precision errors
+                        **cache_arg,
+                    ) as vllm_model:
+                        results = []
+                        for override_params in sampling_param_tests:
+                            print(f"----------- RUNNING PARAMS: {override_params}")
+                            results.append(
+                                vllm_model.generate(
+                                    example_prompts,
+                                    sampling_params=SamplingParams(
+                                        **default_params, **override_params
+                                    ),
+                                )
+                            )
+                        outputs.append((test_config, results))
+
+    baseline_config, baseline_tests = outputs[0]
+
+    for test_config, test_outputs in outputs[1:]:
+        for base_outs, test_outs, params in zip(
+            baseline_tests, test_outputs, sampling_param_tests
+        ):
+            check_outputs_equal(
+                outputs_0_lst=base_outs,
+                outputs_1_lst=test_outs,
+                name_0=f"baseline=[{baseline_config}], params={params}",
+                name_1=f"config=[{test_config}], params={params}",
+            )
+
+            print(f"PASSED: config=[{test_config}], params={params}")
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index f5b73e46a239c..ba27385589a7a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -754,6 +754,12 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 # Replace the existing block IDs with the new ones.
                 req_state.block_ids = new_block_ids
 
+            if self.use_async_scheduling and num_output_tokens > 0:
+                # We must recover the output token ids for resumed requests in the
+                # async scheduling case, so that correct input_ids are obtained.
+                resumed_token_ids = req_data.resumed_req_token_ids[i]
+                assert resumed_token_ids is not None
+                req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]
             if req_index is None:
                 # The request is not in the persistent batch.
                 # The request was either preempted and resumed later, or was not
@@ -991,7 +997,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             self.is_token_ids.copy_to_gpu(total_num_scheduled_tokens)
         if num_commmon_tokens == 0:
             # No requests in common with the previous iteration
-            # So input_ids_cpu will have all the input ids.
+            # So input_ids.cpu will have all the input ids.
             return
         if indices_match and max_flattened_index == (num_commmon_tokens - 1):
             # Common-case optimization: the batch is unchanged
@@ -1005,8 +1011,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             if self.enable_prompt_embeds:
                 self.is_token_ids.gpu[:num_commmon_tokens] = True
             return
-        # Upload the index tensors asynchronously
-        # so the scatter can be non-blocking.
+        # Upload the index tensors asynchronously so the scatter can be non-blocking.
         input_ids_index_tensor = torch.tensor(
            flattened_indices, dtype=torch.int64, pin_memory=self.pin_memory
         ).to(self.device, non_blocking=True)
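
Note on the fix: under async scheduling, the model runner's copy of a request's output token ids can lag the scheduler's view, so a request that is preempted and later resumed may carry sampled tokens the runner never recorded. The first hunk rebuilds `req_state.output_token_ids` from the scheduler-provided resumed token ids by taking their last `num_output_tokens` entries, which is what makes the recomposed `input_ids` correct. Below is a minimal standalone sketch of that recovery step; the names (`ReqState`, `recover_resumed_output_tokens`) are illustrative only, not vLLM's internal API.

```python
# Illustrative sketch of the recovery step in the first hunk above.
# ReqState and recover_resumed_output_tokens are hypothetical names.
from dataclasses import dataclass, field


@dataclass
class ReqState:
    prompt_token_ids: list[int]
    output_token_ids: list[int] = field(default_factory=list)


def recover_resumed_output_tokens(
    req_state: ReqState,
    resumed_token_ids: list[int] | None,
    num_output_tokens: int,
) -> None:
    # The resumed sequence is the prompt followed by everything generated
    # so far, so its last num_output_tokens entries are exactly the output
    # tokens the runner must restore before rebuilding input_ids.
    assert resumed_token_ids is not None
    req_state.output_token_ids = resumed_token_ids[-num_output_tokens:]


if __name__ == "__main__":
    # A request with a 4-token prompt that had generated 3 tokens before
    # being preempted and resumed.
    state = ReqState(prompt_token_ids=[11, 12, 13, 14])
    resumed = [11, 12, 13, 14, 50, 51, 52]  # prompt + generated tokens
    recover_resumed_output_tokens(state, resumed, num_output_tokens=3)
    assert state.output_token_ids == [50, 51, 52]
```

The new e2e test exercises this path by forcing preemption (via `num_gpu_blocks_override=32`) and checking that greedy outputs are token-for-token identical across uni/multiproc executors with async scheduling on and off.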