[Core] Support logprobs with spec decode + async scheduling (#29223)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-11-25 12:55:24 -08:00 committed by GitHub
parent e7d776273d
commit 4e57c6587f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 35 additions and 25 deletions

View File

@ -87,6 +87,11 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
# Set small draft model len to force doesn't-fit-in-drafter case.
spec_config_short = spec_config | {"max_model_len": 50}
test_sampling_params = [
dict(),
dict(logprobs=2),
]
# test_preemption, executor, async_scheduling,
# spec_config, test_prefill_chunking
test_configs = [
@ -103,7 +108,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
(True, "uni", True, spec_config_short, True),
]
run_tests(monkeypatch, MTP_MODEL, test_configs, [{}])
run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params)
@dynamo_config.patch(cache_size_limit=16)

View File

@ -1089,8 +1089,6 @@ class Scheduler(SchedulerInterface):
and request.sampling_params.logprobs is not None
and logprobs
):
# NOTE: once we support N tokens per step (spec decode),
# the outer lists can be of length > 1.
new_logprobs = logprobs.slice(req_index, req_index + 1)
if new_token_ids and self.structured_output_manager.should_advance(request):

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from dataclasses import replace
import torch
@ -204,7 +205,9 @@ class RejectionSampler(nn.Module):
def parse_output(
output_token_ids: torch.Tensor,
vocab_size: int,
) -> list[list[int]]:
discard_req_indices: Sequence[int] = (),
return_cu_num_tokens: bool = False,
) -> tuple[list[list[int]], list[int] | None]:
"""Parse the output of the rejection sampler.
Args:
output_token_ids: The sampled token IDs in shape
@ -212,6 +215,8 @@ class RejectionSampler(nn.Module):
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
and will be filtered out in this function.
vocab_size: The size of the vocabulary.
discard_req_indices: Optional row indices to discard tokens in.
return_cu_num_tokens: Whether to also return cumulative token counts.
Returns:
A list of lists of token IDs.
"""
@ -220,10 +225,15 @@ class RejectionSampler(nn.Module):
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
output_token_ids_np < vocab_size
)
cu_num_tokens = None
if return_cu_num_tokens:
cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
if len(discard_req_indices) > 0:
valid_mask[discard_req_indices] = False
outputs = [
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
]
return outputs
return outputs, cu_num_tokens
def apply_logits_processors(
self,

View File

@ -183,7 +183,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
self,
model_runner_output: ModelRunnerOutput,
sampled_token_ids: torch.Tensor,
logprobs_tensors: torch.Tensor | None,
logprobs_tensors: LogprobsTensors | None,
invalid_req_indices: list[int],
async_output_copy_stream: torch.cuda.Stream,
vocab_size: int,
@ -219,28 +219,29 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
This function blocks until the copy is finished.
"""
max_gen_len = self.sampled_token_ids_cpu.shape[-1]
self.async_copy_ready_event.synchronize()
# Release the device tensors once the copy has completed.
del self._logprobs_tensors
del self._sampled_token_ids
max_gen_len = self.sampled_token_ids_cpu.shape[-1]
if max_gen_len == 1:
valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear()
cu_num_tokens = None
else:
valid_sampled_token_ids = RejectionSampler.parse_output(
valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
self.sampled_token_ids_cpu,
self.vocab_size,
self._invalid_req_indices,
return_cu_num_tokens=self._logprobs_tensors_cpu is not None,
)
for i in self._invalid_req_indices:
valid_sampled_token_ids[i].clear()
output = self._model_runner_output
output.sampled_token_ids = valid_sampled_token_ids
if self._logprobs_tensors_cpu:
# NOTE(nick): this will need to be updated to use cu_num_accepted_tokens
# for async sched + spec decode + logprobs compatibility.
output.logprobs = self._logprobs_tensors_cpu.tolists()
output.logprobs = self._logprobs_tensors_cpu.tolists(cu_num_tokens)
return output
@ -2597,28 +2598,24 @@ class GPUModelRunner(
sampled_token_ids = sampler_output.sampled_token_ids
logprobs_tensors = sampler_output.logprobs_tensors
invalid_req_indices = []
cu_num_new_tokens: list[int] | None = None
cu_num_tokens: list[int] | None = None
if not self.use_async_scheduling:
# Get the valid generated tokens.
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
# No spec decode tokens.
valid_sampled_token_ids = self._to_list(sampled_token_ids)
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[int(i)].clear()
else:
# Includes spec decode tokens.
valid_sampled_token_ids = self.rejection_sampler.parse_output(
valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
sampled_token_ids,
self.input_batch.vocab_size,
discard_sampled_tokens_req_indices,
return_cu_num_tokens=logprobs_tensors is not None,
)
if logprobs_tensors:
# Needed for extracting logprobs when spec decoding.
# This must be done prior to discarding sampled tokens.
cu_num_new_tokens = [0]
for toks in valid_sampled_token_ids:
cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
# Mask out the sampled tokens that should not be sampled.
for i in discard_sampled_tokens_req_indices:
valid_sampled_token_ids[int(i)].clear()
else:
valid_sampled_token_ids = []
invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
@ -2672,7 +2669,7 @@ class GPUModelRunner(
req_state.output_token_ids.extend(sampled_ids)
logprobs_lists = (
logprobs_tensors.tolists(cu_num_new_tokens)
logprobs_tensors.tolists(cu_num_tokens)
if not self.use_async_scheduling and logprobs_tensors is not None
else None
)