mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-09 09:41:19 +08:00
[Core] Support logprobs with spec decode + async scheduling (#29223)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
e7d776273d
commit
4e57c6587f
@ -87,6 +87,11 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
|
|||||||
# Set small draft model len to force doesn't-fit-in-drafter case.
|
# Set small draft model len to force doesn't-fit-in-drafter case.
|
||||||
spec_config_short = spec_config | {"max_model_len": 50}
|
spec_config_short = spec_config | {"max_model_len": 50}
|
||||||
|
|
||||||
|
test_sampling_params = [
|
||||||
|
dict(),
|
||||||
|
dict(logprobs=2),
|
||||||
|
]
|
||||||
|
|
||||||
# test_preemption, executor, async_scheduling,
|
# test_preemption, executor, async_scheduling,
|
||||||
# spec_config, test_prefill_chunking
|
# spec_config, test_prefill_chunking
|
||||||
test_configs = [
|
test_configs = [
|
||||||
@ -103,7 +108,7 @@ def test_with_spec_decoding(monkeypatch: pytest.MonkeyPatch):
|
|||||||
(True, "uni", True, spec_config_short, True),
|
(True, "uni", True, spec_config_short, True),
|
||||||
]
|
]
|
||||||
|
|
||||||
run_tests(monkeypatch, MTP_MODEL, test_configs, [{}])
|
run_tests(monkeypatch, MTP_MODEL, test_configs, test_sampling_params)
|
||||||
|
|
||||||
|
|
||||||
@dynamo_config.patch(cache_size_limit=16)
|
@dynamo_config.patch(cache_size_limit=16)
|
||||||
|
|||||||
@ -1089,8 +1089,6 @@ class Scheduler(SchedulerInterface):
|
|||||||
and request.sampling_params.logprobs is not None
|
and request.sampling_params.logprobs is not None
|
||||||
and logprobs
|
and logprobs
|
||||||
):
|
):
|
||||||
# NOTE: once we support N tokens per step (spec decode),
|
|
||||||
# the outer lists can be of length > 1.
|
|
||||||
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
||||||
|
|
||||||
if new_token_ids and self.structured_output_manager.should_advance(request):
|
if new_token_ids and self.structured_output_manager.should_advance(request):
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -204,7 +205,9 @@ class RejectionSampler(nn.Module):
|
|||||||
def parse_output(
|
def parse_output(
|
||||||
output_token_ids: torch.Tensor,
|
output_token_ids: torch.Tensor,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
) -> list[list[int]]:
|
discard_req_indices: Sequence[int] = (),
|
||||||
|
return_cu_num_tokens: bool = False,
|
||||||
|
) -> tuple[list[list[int]], list[int] | None]:
|
||||||
"""Parse the output of the rejection sampler.
|
"""Parse the output of the rejection sampler.
|
||||||
Args:
|
Args:
|
||||||
output_token_ids: The sampled token IDs in shape
|
output_token_ids: The sampled token IDs in shape
|
||||||
@ -212,6 +215,8 @@ class RejectionSampler(nn.Module):
|
|||||||
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
|
replaced with `PLACEHOLDER_TOKEN_ID` by the rejection sampler
|
||||||
and will be filtered out in this function.
|
and will be filtered out in this function.
|
||||||
vocab_size: The size of the vocabulary.
|
vocab_size: The size of the vocabulary.
|
||||||
|
discard_req_indices: Optional row indices to discard tokens in.
|
||||||
|
return_cu_num_tokens: Whether to also return cumulative token counts.
|
||||||
Returns:
|
Returns:
|
||||||
A list of lists of token IDs.
|
A list of lists of token IDs.
|
||||||
"""
|
"""
|
||||||
@ -220,10 +225,15 @@ class RejectionSampler(nn.Module):
|
|||||||
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
|
valid_mask = (output_token_ids_np != PLACEHOLDER_TOKEN_ID) & (
|
||||||
output_token_ids_np < vocab_size
|
output_token_ids_np < vocab_size
|
||||||
)
|
)
|
||||||
|
cu_num_tokens = None
|
||||||
|
if return_cu_num_tokens:
|
||||||
|
cu_num_tokens = [0] + valid_mask.sum(axis=1).cumsum().tolist()
|
||||||
|
if len(discard_req_indices) > 0:
|
||||||
|
valid_mask[discard_req_indices] = False
|
||||||
outputs = [
|
outputs = [
|
||||||
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
|
row[valid_mask[i]].tolist() for i, row in enumerate(output_token_ids_np)
|
||||||
]
|
]
|
||||||
return outputs
|
return outputs, cu_num_tokens
|
||||||
|
|
||||||
def apply_logits_processors(
|
def apply_logits_processors(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -183,7 +183,7 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
|
|||||||
self,
|
self,
|
||||||
model_runner_output: ModelRunnerOutput,
|
model_runner_output: ModelRunnerOutput,
|
||||||
sampled_token_ids: torch.Tensor,
|
sampled_token_ids: torch.Tensor,
|
||||||
logprobs_tensors: torch.Tensor | None,
|
logprobs_tensors: LogprobsTensors | None,
|
||||||
invalid_req_indices: list[int],
|
invalid_req_indices: list[int],
|
||||||
async_output_copy_stream: torch.cuda.Stream,
|
async_output_copy_stream: torch.cuda.Stream,
|
||||||
vocab_size: int,
|
vocab_size: int,
|
||||||
@ -219,28 +219,29 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
|
|||||||
|
|
||||||
This function blocks until the copy is finished.
|
This function blocks until the copy is finished.
|
||||||
"""
|
"""
|
||||||
|
max_gen_len = self.sampled_token_ids_cpu.shape[-1]
|
||||||
self.async_copy_ready_event.synchronize()
|
self.async_copy_ready_event.synchronize()
|
||||||
|
|
||||||
# Release the device tensors once the copy has completed.
|
# Release the device tensors once the copy has completed.
|
||||||
del self._logprobs_tensors
|
del self._logprobs_tensors
|
||||||
del self._sampled_token_ids
|
del self._sampled_token_ids
|
||||||
max_gen_len = self.sampled_token_ids_cpu.shape[-1]
|
|
||||||
if max_gen_len == 1:
|
if max_gen_len == 1:
|
||||||
valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
|
valid_sampled_token_ids = self.sampled_token_ids_cpu.tolist()
|
||||||
|
for i in self._invalid_req_indices:
|
||||||
|
valid_sampled_token_ids[i].clear()
|
||||||
|
cu_num_tokens = None
|
||||||
else:
|
else:
|
||||||
valid_sampled_token_ids = RejectionSampler.parse_output(
|
valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
|
||||||
self.sampled_token_ids_cpu,
|
self.sampled_token_ids_cpu,
|
||||||
self.vocab_size,
|
self.vocab_size,
|
||||||
|
self._invalid_req_indices,
|
||||||
|
return_cu_num_tokens=self._logprobs_tensors_cpu is not None,
|
||||||
)
|
)
|
||||||
for i in self._invalid_req_indices:
|
|
||||||
valid_sampled_token_ids[i].clear()
|
|
||||||
|
|
||||||
output = self._model_runner_output
|
output = self._model_runner_output
|
||||||
output.sampled_token_ids = valid_sampled_token_ids
|
output.sampled_token_ids = valid_sampled_token_ids
|
||||||
if self._logprobs_tensors_cpu:
|
if self._logprobs_tensors_cpu:
|
||||||
# NOTE(nick): this will need to be updated to use cu_num_accepted_tokens
|
output.logprobs = self._logprobs_tensors_cpu.tolists(cu_num_tokens)
|
||||||
# for async sched + spec decode + logprobs compatibility.
|
|
||||||
output.logprobs = self._logprobs_tensors_cpu.tolists()
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
@ -2597,28 +2598,24 @@ class GPUModelRunner(
|
|||||||
sampled_token_ids = sampler_output.sampled_token_ids
|
sampled_token_ids = sampler_output.sampled_token_ids
|
||||||
logprobs_tensors = sampler_output.logprobs_tensors
|
logprobs_tensors = sampler_output.logprobs_tensors
|
||||||
invalid_req_indices = []
|
invalid_req_indices = []
|
||||||
cu_num_new_tokens: list[int] | None = None
|
cu_num_tokens: list[int] | None = None
|
||||||
if not self.use_async_scheduling:
|
if not self.use_async_scheduling:
|
||||||
# Get the valid generated tokens.
|
# Get the valid generated tokens.
|
||||||
max_gen_len = sampled_token_ids.shape[-1]
|
max_gen_len = sampled_token_ids.shape[-1]
|
||||||
if max_gen_len == 1:
|
if max_gen_len == 1:
|
||||||
# No spec decode tokens.
|
# No spec decode tokens.
|
||||||
valid_sampled_token_ids = self._to_list(sampled_token_ids)
|
valid_sampled_token_ids = self._to_list(sampled_token_ids)
|
||||||
|
# Mask out the sampled tokens that should not be sampled.
|
||||||
|
for i in discard_sampled_tokens_req_indices:
|
||||||
|
valid_sampled_token_ids[int(i)].clear()
|
||||||
else:
|
else:
|
||||||
# Includes spec decode tokens.
|
# Includes spec decode tokens.
|
||||||
valid_sampled_token_ids = self.rejection_sampler.parse_output(
|
valid_sampled_token_ids, cu_num_tokens = RejectionSampler.parse_output(
|
||||||
sampled_token_ids,
|
sampled_token_ids,
|
||||||
self.input_batch.vocab_size,
|
self.input_batch.vocab_size,
|
||||||
|
discard_sampled_tokens_req_indices,
|
||||||
|
return_cu_num_tokens=logprobs_tensors is not None,
|
||||||
)
|
)
|
||||||
if logprobs_tensors:
|
|
||||||
# Needed for extracting logprobs when spec decoding.
|
|
||||||
# This must be done prior to discarding sampled tokens.
|
|
||||||
cu_num_new_tokens = [0]
|
|
||||||
for toks in valid_sampled_token_ids:
|
|
||||||
cu_num_new_tokens.append(cu_num_new_tokens[-1] + len(toks))
|
|
||||||
# Mask out the sampled tokens that should not be sampled.
|
|
||||||
for i in discard_sampled_tokens_req_indices:
|
|
||||||
valid_sampled_token_ids[int(i)].clear()
|
|
||||||
else:
|
else:
|
||||||
valid_sampled_token_ids = []
|
valid_sampled_token_ids = []
|
||||||
invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
|
invalid_req_indices = discard_sampled_tokens_req_indices.tolist()
|
||||||
@ -2672,7 +2669,7 @@ class GPUModelRunner(
|
|||||||
req_state.output_token_ids.extend(sampled_ids)
|
req_state.output_token_ids.extend(sampled_ids)
|
||||||
|
|
||||||
logprobs_lists = (
|
logprobs_lists = (
|
||||||
logprobs_tensors.tolists(cu_num_new_tokens)
|
logprobs_tensors.tolists(cu_num_tokens)
|
||||||
if not self.use_async_scheduling and logprobs_tensors is not None
|
if not self.use_async_scheduling and logprobs_tensors is not None
|
||||||
else None
|
else None
|
||||||
)
|
)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user