[AsyncScheduling] Make async overlap work with logprobs (#27615)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill 2025-10-28 15:35:54 -07:00 committed by GitHub
parent 111faf1118
commit 4fe5895361
4 changed files with 65 additions and 10 deletions

View File

@@ -831,8 +831,9 @@ class VllmRunner:
         images: PromptImageInput | None = None,
         videos: PromptVideoInput | None = None,
         audios: PromptAudioInput | None = None,
+        return_logprobs: bool = False,
         **kwargs: Any,
-    ) -> list[tuple[list[list[int]], list[str]]]:
+    ) -> list[tuple[list[list[int]], list[str]]] | tuple[list, list]:
         inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
         req_outputs = self.llm.generate(
@@ -840,18 +841,23 @@ class VllmRunner:
         )
         outputs: list[tuple[list[list[int]], list[str]]] = []
+        logprobs = []
         for req_output in req_outputs:
             prompt_str = req_output.prompt
             prompt_ids = req_output.prompt_token_ids
             req_sample_output_ids: list[list[int]] = []
             req_sample_output_strs: list[str] = []
+            req_logprobs = []
             for sample in req_output.outputs:
                 output_str = sample.text
                 output_ids = list(sample.token_ids)
                 req_sample_output_ids.append(prompt_ids + output_ids)
                 req_sample_output_strs.append((prompt_str or "") + output_str)
+                if sample.logprobs:
+                    req_logprobs.extend(sample.logprobs)
             outputs.append((req_sample_output_ids, req_sample_output_strs))
-        return outputs
+            logprobs.append(req_logprobs)
+        return outputs if not return_logprobs else (outputs, logprobs)

     @staticmethod
     def _final_steps_generate_w_logprobs(
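
For context, the new flag only changes the return shape: with return_logprobs=False the method behaves exactly as before, and with return_logprobs=True it returns an (outputs, logprobs) tuple. A hedged usage sketch; the model name, prompt, and the assumption that this is the runner's generate method are illustrative, not part of this diff:

from vllm import SamplingParams
from tests.conftest import VllmRunner  # import path assumed for illustration

prompts = ["Hello, my name is"]
params = SamplingParams(max_tokens=8, logprobs=2)

with VllmRunner("facebook/opt-125m") as runner:
    # Existing callers are unaffected: a single return value.
    outputs = runner.generate(prompts, sampling_params=params)
    # Opting in returns per-request logprobs alongside the outputs.
    outputs, logprobs = runner.generate(
        prompts, sampling_params=params, return_logprobs=True
    )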

View File

@@ -6,6 +6,7 @@ import pytest
 import torch._dynamo.config as dynamo_config

 from vllm import SamplingParams
+from vllm.logprobs import Logprob

 from ...conftest import VllmRunner
 from ...models.utils import check_outputs_equal
@@ -32,6 +33,8 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
         # dict(min_tokens=20),
         dict(presence_penalty=-1.0),
         dict(bad_words=["the", " the"]),
+        dict(logprobs=2),
+        dict(logprobs=2, presence_penalty=-1.0),
     ]

     default_params = dict(
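
For reference, logprobs=2 in SamplingParams asks the engine to report log-probabilities for the two most likely tokens at each generated position (the sampled token is always included as well, so a position may carry up to three entries). Each position is a dict mapping token id to a Logprob record, which is what the helpers added at the end of this file compare. A hedged sketch of one position's shape, with made-up token ids and values:

from vllm.logprobs import Logprob

position_logprobs: dict[int, Logprob] = {
    464: Logprob(logprob=-0.31, rank=1, decoded_token="The"),
    262: Logprob(logprob=-1.52, rank=2, decoded_token=" the"),
}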
@@ -77,29 +80,33 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
                         sampling_params=SamplingParams(
                             **default_params, **override_params
                         ),
+                        return_logprobs=True,
                     )
                 )

             if not outputs:
                 # First check that the different parameter configs
                 # actually result in different output.
-                for other_test, params in zip(
+                for (other_test_outs, other_test_logprobs), params in zip(
                     results[1:], sampling_param_tests[1:]
                 ):
                     with pytest.raises(AssertionError):
                         check_outputs_equal(
-                            outputs_0_lst=results[0],
-                            outputs_1_lst=other_test,
+                            outputs_0_lst=results[0][0],
+                            outputs_1_lst=other_test_outs,
                             name_0=f"baseline params={params}",
                             name_1=f"other params={params}",
                         )
+                        assert _all_logprobs_match(
+                            results[0][1], other_test_logprobs
+                        )

             outputs.append((test_config, results))

     baseline_config, baseline_tests = outputs[0]

     for test_config, test_outputs in outputs[1:]:
-        for base_outs, test_outs, params in zip(
+        for (base_outs, base_logprobs), (test_outs, test_logprobs), params in zip(
             baseline_tests, test_outputs, sampling_param_tests
         ):
             check_outputs_equal(
@@ -108,5 +115,27 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
                 name_0=f"baseline=[{baseline_config}], params={params}",
                 name_1=f"config=[{test_config}], params={params}",
             )
+            assert _all_logprobs_match(base_logprobs, test_logprobs)

             print(f"PASSED: config=[{test_config}], params={params}")
+
+
+def _all_logprobs_match(req_a, req_b) -> bool:
+    return (
+        req_a == req_b
+        or len(req_a) == len(req_b)
+        and all(
+            len(seq_a) == len(seq_b)
+            and all(_logprobs_match(a, b) for a, b in zip(seq_a, seq_b))
+            for seq_a, seq_b in zip(req_a, req_b)
+        )
+    )
+
+
+def _logprobs_match(lps_a: dict[int, Logprob], lps_b: dict[int, Logprob]) -> bool:
+    return len(lps_a) == len(lps_b) and all(
+        a.decoded_token == b.decoded_token
+        and a.rank == b.rank
+        and a.logprob == pytest.approx(b.logprob, rel=1e-3, abs=1e-6)
+        for a, b in ((lps_a[x], lps_b[x]) for x in lps_a)
+    )
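
A note on the helpers just added: since "and" binds tighter than "or", _all_logprobs_match reads as "the structures are identical, or they have equal lengths and every per-position dict matches within tolerance". The tolerance itself comes from pytest.approx; a minimal self-contained illustration with made-up values:

import pytest

# Within rel=1e-3: the values differ by about 0.002%, so they compare equal.
assert -2.30005 == pytest.approx(-2.3000, rel=1e-3, abs=1e-6)

# Well outside the tolerance (about 4% apart): the comparison fails.
assert not (-2.4 == pytest.approx(-2.3, rel=1e-3, abs=1e-6))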

View File

@@ -59,6 +59,15 @@ class LogprobsTensors(NamedTuple):
             cu_num_generated_tokens,
         )

+    def to_cpu_nonblocking(self) -> "LogprobsTensors":
+        if self.logprob_token_ids.device.type == "cpu":
+            return self
+        return LogprobsTensors(
+            self.logprob_token_ids.to("cpu", non_blocking=True),
+            self.logprobs.to("cpu", non_blocking=True),
+            self.selected_token_ranks.to("cpu", non_blocking=True),
+        )
+
     @staticmethod
     def empty_cpu(
         num_positions: int, num_tokens_per_position: int
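
One general CUDA caveat around to_cpu_nonblocking (background, not something this diff changes): a device-to-host copy can only overlap with other GPU work when the host destination is pinned (page-locked) memory; a copy into ordinary pageable memory degrades to a synchronous transfer. A non-blocking copy is also not safe to read until it is synchronized, which is why the caller in the next file guards reads with an event. A minimal sketch of that pattern:

import torch

if torch.cuda.is_available():
    src = torch.randn(4, 8, device="cuda")

    # Pinned destination lets the D2H copy be enqueued asynchronously.
    dst = torch.empty(src.shape, dtype=src.dtype, pin_memory=True)
    dst.copy_(src, non_blocking=True)

    # The copy must be synchronized before dst is read on the host.
    ready = torch.cuda.Event()
    ready.record()
    ready.synchronize()
    print(dst.sum())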

View File

@@ -164,6 +164,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
         self,
         model_runner_output: ModelRunnerOutput,
         sampled_token_ids: torch.Tensor,
+        logprobs_tensors: torch.Tensor | None,
         invalid_req_indices: list[int],
         async_output_copy_stream: torch.cuda.Stream,
     ):
@@ -176,6 +177,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
         # Keep a reference to the device tensor to avoid it being
         # deallocated until we finish copying it to the host.
         self._sampled_token_ids = sampled_token_ids
+        self._logprobs_tensors = logprobs_tensors

         # Initiate the copy on a separate stream, but do not synchronize it.
         default_stream = torch.cuda.current_stream()
@@ -184,6 +186,11 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
             self.sampled_token_ids_cpu = self._sampled_token_ids.to(
                 "cpu", non_blocking=True
             )
+            self._logprobs_tensors_cpu = (
+                self._logprobs_tensors.to_cpu_nonblocking()
+                if self._logprobs_tensors
+                else None
+            )
             self.async_copy_ready_event.record()

     def get_output(self) -> ModelRunnerOutput:
@@ -193,7 +200,8 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
         """
         self.async_copy_ready_event.synchronize()

-        # Release the device tensor once the copy has completed
+        # Release the device tensors once the copy has completed.
+        del self._logprobs_tensors
         del self._sampled_token_ids

         valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
@@ -202,6 +210,10 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
         output = self._model_runner_output
         output.sampled_token_ids = valid_sampled_token_ids
+        if self._logprobs_tensors_cpu:
+            # NOTE(nick): this will need to be updated to use cu_num_accepted_tokens
+            # for async sched + spec decode + logprobs compatibility.
+            output.logprobs = self._logprobs_tensors_cpu.tolists()
         return output
@@ -2334,11 +2346,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 cu_num_accepted_tokens[-1] + len(sampled_ids)
             )

-        # NOTE: GPU -> CPU Sync happens here.
-        # Move as many CPU operations as possible before this sync point.
         logprobs_lists = (
             logprobs_tensors.tolists(cu_num_accepted_tokens)
-            if logprobs_tensors is not None
+            if not self.use_async_scheduling and logprobs_tensors is not None
             else None
         )
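
The new use_async_scheduling guard matters because tolists() materializes Python lists from device tensors, which forces the blocking GPU -> CPU sync that the removed comment warned about. With async scheduling enabled, the tensors are instead handed to AsyncGPUModelRunnerOutput above, which copies them on a side stream and defers the list conversion to get_output(). A toy illustration of why list conversion synchronizes (illustrative only, not vLLM code):

import torch

if torch.cuda.is_available():
    x = torch.randn(1024, device="cuda")
    y = torch.softmax(x, dim=0).log()  # kernels are enqueued asynchronously
    # .tolist() (like .item() or .cpu()) must wait for every queued kernel
    # producing y, stalling the host until the GPU catches up.
    values = y.tolist()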
@@ -2664,6 +2674,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             async_output = AsyncGPUModelRunnerOutput(
                 model_runner_output=output,
                 sampled_token_ids=sampler_output.sampled_token_ids,
+                logprobs_tensors=sampler_output.logprobs_tensors,
                 invalid_req_indices=invalid_req_indices,
                 async_output_copy_stream=self.async_output_copy_stream,
             )
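
Taken together, the pattern this change extends is: snapshot the device tensors, launch their host copies on a dedicated stream that first waits on the producing stream, record an event, and synchronize only when the output is actually consumed. A condensed, self-contained sketch of that pattern, with simplified names that are not the vLLM classes themselves:

import torch

class AsyncHostCopy:
    # Minimal sketch of the copy-on-a-side-stream pattern used above.

    def __init__(self, device_tensor: torch.Tensor, copy_stream: torch.cuda.Stream):
        self._device_tensor = device_tensor  # keep alive until the copy lands
        self.ready = torch.cuda.Event()
        # The side stream must wait for the producers of device_tensor.
        copy_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(copy_stream):
            self.host_tensor = device_tensor.to("cpu", non_blocking=True)
            self.ready.record()

    def get(self) -> torch.Tensor:
        # Block only at consumption time, so the copy overlaps whatever
        # host-side scheduling work happened in between.
        self.ready.synchronize()
        del self._device_tensor  # device memory is safe to release now
        return self.host_tensor

if torch.cuda.is_available():
    copy_stream = torch.cuda.Stream()
    sampled = torch.randint(0, 32000, (8, 1), device="cuda")
    pending = AsyncHostCopy(sampled, copy_stream)
    # ... other host work proceeds here without a GPU sync ...
    tokens = pending.get().tolist()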