From 4fe58953611ede752e34b67ae785fed28be66465 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Tue, 28 Oct 2025 15:35:54 -0700
Subject: [PATCH] [AsyncScheduling] Make async overlap work with logprobs (#27615)

Signed-off-by: Nick Hill
---
 tests/conftest.py                            | 10 ++++--
 tests/v1/e2e/test_async_sched_and_preempt.py | 37 +++++++++++++++++---
 vllm/v1/outputs.py                           |  9 +++++
 vllm/v1/worker/gpu_model_runner.py           | 19 +++++++---
 4 files changed, 65 insertions(+), 10 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index ec0179b9cd5ab..91155a72b16ca 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -831,8 +831,9 @@ class VllmRunner:
         images: PromptImageInput | None = None,
         videos: PromptVideoInput | None = None,
         audios: PromptAudioInput | None = None,
+        return_logprobs: bool = False,
         **kwargs: Any,
-    ) -> list[tuple[list[list[int]], list[str]]]:
+    ) -> list[tuple[list[list[int]], list[str]]] | tuple[list, list]:
         inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
 
         req_outputs = self.llm.generate(
@@ -840,18 +841,23 @@ class VllmRunner:
         )
 
         outputs: list[tuple[list[list[int]], list[str]]] = []
+        logprobs = []
         for req_output in req_outputs:
             prompt_str = req_output.prompt
             prompt_ids = req_output.prompt_token_ids
             req_sample_output_ids: list[list[int]] = []
             req_sample_output_strs: list[str] = []
+            req_logprobs = []
             for sample in req_output.outputs:
                 output_str = sample.text
                 output_ids = list(sample.token_ids)
                 req_sample_output_ids.append(prompt_ids + output_ids)
                 req_sample_output_strs.append((prompt_str or "") + output_str)
+                if sample.logprobs:
+                    req_logprobs.extend(sample.logprobs)
             outputs.append((req_sample_output_ids, req_sample_output_strs))
-        return outputs
+            logprobs.append(req_logprobs)
+        return outputs if not return_logprobs else (outputs, logprobs)
 
     @staticmethod
     def _final_steps_generate_w_logprobs(
diff --git a/tests/v1/e2e/test_async_sched_and_preempt.py b/tests/v1/e2e/test_async_sched_and_preempt.py
index 7ad9606a66df6..15a1cc2558177 100644
--- a/tests/v1/e2e/test_async_sched_and_preempt.py
+++ b/tests/v1/e2e/test_async_sched_and_preempt.py
@@ -6,6 +6,7 @@ import pytest
 import torch._dynamo.config as dynamo_config
 
 from vllm import SamplingParams
+from vllm.logprobs import Logprob
 
 from ...conftest import VllmRunner
 from ...models.utils import check_outputs_equal
@@ -32,6 +33,8 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
         # dict(min_tokens=20),
         dict(presence_penalty=-1.0),
         dict(bad_words=["the", " the"]),
+        dict(logprobs=2),
+        dict(logprobs=2, presence_penalty=-1.0),
     ]
 
     default_params = dict(
@@ -77,29 +80,33 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
                         sampling_params=SamplingParams(
                             **default_params, **override_params
                         ),
+                        return_logprobs=True,
                     )
                 )
 
             if not outputs:
                 # First check that the different parameter configs
                 # actually result in different output.
-                for other_test, params in zip(
+                for (other_test_outs, other_test_logprobs), params in zip(
                     results[1:], sampling_param_tests[1:]
                 ):
                     with pytest.raises(AssertionError):
                         check_outputs_equal(
-                            outputs_0_lst=results[0],
-                            outputs_1_lst=other_test,
+                            outputs_0_lst=results[0][0],
+                            outputs_1_lst=other_test_outs,
                             name_0=f"baseline params={params}",
                             name_1=f"other params={params}",
                         )
+                        assert _all_logprobs_match(
+                            results[0][1], other_test_logprobs
+                        )
 
             outputs.append((test_config, results))
 
     baseline_config, baseline_tests = outputs[0]
 
     for test_config, test_outputs in outputs[1:]:
-        for base_outs, test_outs, params in zip(
+        for (base_outs, base_logprobs), (test_outs, test_logprobs), params in zip(
             baseline_tests, test_outputs, sampling_param_tests
         ):
             check_outputs_equal(
@@ -108,5 +115,27 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
                 name_0=f"baseline=[{baseline_config}], params={params}",
                 name_1=f"config=[{test_config}], params={params}",
             )
+            assert _all_logprobs_match(base_logprobs, test_logprobs)
 
             print(f"PASSED: config=[{test_config}], params={params}")
+
+
+def _all_logprobs_match(req_a, req_b) -> bool:
+    return (
+        req_a == req_b
+        or len(req_a) == len(req_b)
+        and all(
+            len(seq_a) == len(seq_b)
+            and all(_logprobs_match(a, b) for a, b in zip(seq_a, seq_b))
+            for seq_a, seq_b in zip(req_a, req_b)
+        )
+    )
+
+
+def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
+    return len(lps_a) == len(lps_b) and all(
+        a.decoded_token == b.decoded_token
+        and a.rank == b.rank
+        and a.logprob == pytest.approx(b.logprob, rel=1e-3, abs=1e-6)
+        for a, b in ((lps_a[x], lps_b[x]) for x in lps_a)
+    )
diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py
index 10f97576b60af..e7122ba339681 100644
--- a/vllm/v1/outputs.py
+++ b/vllm/v1/outputs.py
@@ -59,6 +59,15 @@ class LogprobsTensors(NamedTuple):
             cu_num_generated_tokens,
         )
 
+    def to_cpu_nonblocking(self) -> "LogprobsTensors":
+        if self.logprob_token_ids.device.type == "cpu":
+            return self
+        return LogprobsTensors(
+            self.logprob_token_ids.to("cpu", non_blocking=True),
+            self.logprobs.to("cpu", non_blocking=True),
+            self.selected_token_ranks.to("cpu", non_blocking=True),
+        )
+
     @staticmethod
     def empty_cpu(
         num_positions: int, num_tokens_per_position: int
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 129d7e54466ad..e350988456f12 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -164,6 +164,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
         self,
         model_runner_output: ModelRunnerOutput,
         sampled_token_ids: torch.Tensor,
+        logprobs_tensors: LogprobsTensors | None,
         invalid_req_indices: list[int],
         async_output_copy_stream: torch.cuda.Stream,
     ):
@@ -176,6 +177,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
         # Keep a reference to the device tensor to avoid it being
         # deallocated until we finish copying it to the host.
         self._sampled_token_ids = sampled_token_ids
+        self._logprobs_tensors = logprobs_tensors
 
         # Initiate the copy on a separate stream, but do not synchronize it.
         default_stream = torch.cuda.current_stream()
@@ -184,6 +186,11 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
             self.sampled_token_ids_cpu = self._sampled_token_ids.to(
                 "cpu", non_blocking=True
             )
+            self._logprobs_tensors_cpu = (
+                self._logprobs_tensors.to_cpu_nonblocking()
+                if self._logprobs_tensors
+                else None
+            )
             self.async_copy_ready_event.record()
 
     def get_output(self) -> ModelRunnerOutput:
@@ -193,7 +200,8 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
         """
         self.async_copy_ready_event.synchronize()
 
-        # Release the device tensor once the copy has completed
+        # Release the device tensors once the copy has completed.
+        del self._logprobs_tensors
         del self._sampled_token_ids
 
         valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
@@ -202,6 +210,10 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
 
         output = self._model_runner_output
         output.sampled_token_ids = valid_sampled_token_ids
+        if self._logprobs_tensors_cpu:
+            # NOTE(nick): this will need to be updated to use cu_num_accepted_tokens
+            # for async sched + spec decode + logprobs compatibility.
+            output.logprobs = self._logprobs_tensors_cpu.tolists()
 
         return output
 
@@ -2334,11 +2346,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                     cu_num_accepted_tokens[-1] + len(sampled_ids)
                 )
 
-        # NOTE: GPU -> CPU Sync happens here.
-        # Move as many CPU operations as possible before this sync point.
         logprobs_lists = (
             logprobs_tensors.tolists(cu_num_accepted_tokens)
-            if logprobs_tensors is not None
+            if not self.use_async_scheduling and logprobs_tensors is not None
             else None
        )
 
@@ -2664,6 +2674,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             async_output = AsyncGPUModelRunnerOutput(
                 model_runner_output=output,
                 sampled_token_ids=sampler_output.sampled_token_ids,
+                logprobs_tensors=sampler_output.logprobs_tensors,
                 invalid_req_indices=invalid_req_indices,
                 async_output_copy_stream=self.async_output_copy_stream,
             )
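
Note on the pattern used by AsyncGPUModelRunnerOutput above: the logprobs tensors stay on the GPU, a non-blocking device-to-host copy is started on a dedicated stream, and the host only synchronizes on a CUDA event when get_output() is actually called, so the copy overlaps with scheduling the next step. The snippet below is a minimal standalone sketch of that idea, not vLLM code; the names (AsyncCopyToHost, copy_stream) are illustrative, and a fully overlapped transfer would also require a pinned host buffer.

# Standalone sketch (assumed names, not vLLM APIs): async D2H copy gated by an event.
import torch


class AsyncCopyToHost:
    def __init__(self, device_tensor: torch.Tensor, copy_stream: torch.cuda.Stream):
        # Hold a reference so the device tensor is not freed while the copy runs.
        self._device_tensor = device_tensor
        self._ready_event = torch.cuda.Event()
        # Make the copy stream wait for work already queued on the current stream.
        copy_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(copy_stream):
            # Start the copy but do not wait for it; with a pinned host buffer
            # this would overlap fully with other GPU work.
            self._host_tensor = device_tensor.to("cpu", non_blocking=True)
            self._ready_event.record()

    def get(self) -> torch.Tensor:
        # Block only at the point the host actually needs the values.
        self._ready_event.synchronize()
        self._device_tensor = None  # safe to release the device memory now
        return self._host_tensor


if __name__ == "__main__" and torch.cuda.is_available():
    logprobs_gpu = torch.randn(4, 8, device="cuda")
    pending = AsyncCopyToHost(logprobs_gpu, torch.cuda.Stream())
    # ... other GPU work for the next step could be launched here ...
    print(pending.get().shape)  # torch.Size([4, 8])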