[Spec Decode] Make propose_draft_token_ids non-blocking for lower TTFT (#23041)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-08-18 17:20:38 -07:00 committed by GitHub
parent 0dd3f4f5ab
commit c9b38be8aa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 100 additions and 64 deletions

View File

@ -22,7 +22,6 @@ def _make_model_runner_output(
for i, req_id in enumerate(req_ids)
},
sampled_token_ids=[[i] for i in range(len(req_ids))],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],

View File

@ -14,7 +14,7 @@ from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.request import StructuredOutputRequest
@ -158,7 +158,6 @@ def test_schedule_partial_requests():
# Only the first request has a sampled token id because
# the rest requests are still being prefilled.
sampled_token_ids=[[0], [], []],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@ -209,7 +208,6 @@ def test_no_mm_input_chunking():
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[] for _ in range(len(requests))],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@ -273,7 +271,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[] for _ in range(len(requests))],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@ -298,7 +295,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@ -355,7 +351,6 @@ def test_stop_via_update_from_output():
sampled_token_ids=[[EOS_TOKEN_ID],
[10,
11]], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
@ -409,7 +404,6 @@ def test_stop_via_update_from_output():
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
@ -462,7 +456,6 @@ def test_stop_via_update_from_output():
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
@ -505,7 +498,6 @@ def test_stop_via_update_from_output():
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
@ -554,7 +546,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@ -572,7 +563,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@ -608,7 +598,6 @@ def test_preempt_during_execution():
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@ -626,7 +615,6 @@ def test_preempt_during_execution():
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[42]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@ -682,13 +670,14 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
spec_token_ids=spec_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
scheduler.update_draft_token_ids(draft_token_ids)
for i in range(len(requests)):
running_req = scheduler.running[i]
@ -722,7 +711,6 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@ -851,7 +839,6 @@ def test_kv_connector_basic():
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@ -898,7 +885,6 @@ def test_kv_connector_basic():
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@ -966,7 +952,6 @@ def test_kv_connector_unable_to_allocate():
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@ -1048,7 +1033,6 @@ def test_kv_connector_handles_preemption():
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@ -1142,7 +1126,6 @@ def make_output(scheduler: Scheduler):
for i, req in enumerate(scheduler.running)
},
sampled_token_ids=[[1000]] * len(scheduler.running),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@ -1468,7 +1451,6 @@ def test_priority_scheduling_preemption():
for i, req in enumerate(low_priority_requests)
},
sampled_token_ids=[[100] for _ in low_priority_requests],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@ -1541,7 +1523,6 @@ def test_priority_scheduling_no_preemption_when_space_available():
for i, req in enumerate(low_priority_requests)
},
sampled_token_ids=[[100] for _ in low_priority_requests],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@ -1783,7 +1764,6 @@ def test_priority_scheduling_heap_property():
req_ids=[req.req_id],
req_id_to_index={req.req_id: 0},
sampled_token_ids=[[100]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],

View File

@ -200,7 +200,6 @@ def create_model_runner_output(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=None,

View File

@ -9,7 +9,7 @@ if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
@ -61,6 +61,14 @@ class SchedulerInterface(ABC):
"""
raise NotImplementedError
@abstractmethod
def update_draft_token_ids(
self,
draft_token_ids: "DraftTokenIds",
) -> None:
"""Update the draft token ids for the scheduled requests."""
raise NotImplementedError
@abstractmethod
def add_request(self, request: "Request") -> None:
"""Add a new request to the scheduler's internal queue.

View File

@ -30,7 +30,7 @@ from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
@ -141,7 +141,6 @@ class Scheduler(SchedulerInterface):
cache_size=encoder_cache_size)
speculative_config = vllm_config.speculative_config
self.use_eagle = False
self.num_spec_tokens = self.num_lookahead_tokens = 0
if speculative_config:
@ -760,7 +759,6 @@ class Scheduler(SchedulerInterface):
model_runner_output: ModelRunnerOutput,
) -> dict[int, EngineCoreOutputs]:
sampled_token_ids = model_runner_output.sampled_token_ids
spec_token_ids = model_runner_output.spec_token_ids
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
@ -845,20 +843,9 @@ class Scheduler(SchedulerInterface):
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
req_id, new_token_ids)
# spec_token_ids comes from the model runner output
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id]
# Add newly generated spec token ids to the request.
if spec_token_ids is not None:
if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted.
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
spec_token_ids[req_index])
else:
request.spec_token_ids = spec_token_ids[req_index]
# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or pooler_output is not None \
@ -963,6 +950,30 @@ class Scheduler(SchedulerInterface):
self.encoder_cache_manager.free_encoder_input(
request, input_id)
def update_draft_token_ids(
self,
draft_token_ids: DraftTokenIds,
) -> None:
for req_id, spec_token_ids in zip(
draft_token_ids.req_ids,
draft_token_ids.draft_token_ids,
):
request = self.requests.get(req_id)
if request is None or request.is_finished():
# The request may have been finished. Skip.
continue
# Add newly generated spec token ids to the request.
if not spec_token_ids:
# NOTE(woosuk): request.spec_token_ids should be updated.
request.spec_token_ids.clear()
elif self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
spec_token_ids)
else:
request.spec_token_ids = spec_token_ids
def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
return len(self.running), len(self.waiting)

View File

@ -126,6 +126,7 @@ class EngineCore:
> 1,
log_stats=self.log_stats,
)
self.use_spec_decode = vllm_config.speculative_config is not None
self.mm_input_cache_server = MultiModalInputCacheServer(
vllm_config.model_config, MULTIMODAL_REGISTRY)
@ -294,6 +295,13 @@ class EngineCore:
return (engine_core_outputs,
scheduler_output.total_num_scheduled_tokens > 0)
def post_step(self, model_executed: bool) -> None:
if self.use_spec_decode and model_executed:
# Take the draft token ids.
draft_token_ids = self.model_executor.take_draft_token_ids()
if draft_token_ids is not None:
self.scheduler.update_draft_token_ids(draft_token_ids)
def step_with_batch_queue(
self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
"""Schedule and execute batches with the batch queue.
@ -746,6 +754,8 @@ class EngineCoreProc(EngineCore):
# Put EngineCoreOutputs into the output queue.
for output in (outputs.items() if outputs else ()):
self.output_queue.put_nowait(output)
# Post-step hook.
self.post_step(model_executed)
return model_executed

View File

@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from concurrent.futures import Future
from typing import Callable, Union
from typing import Callable, Optional, Union
import torch
import torch.distributed as dist
@ -14,7 +14,7 @@ from vllm.executor.uniproc_executor import ( # noqa
from vllm.executor.uniproc_executor import ( # noqa
UniProcExecutor as UniProcExecutorV0)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
FailureCallback = Callable[[], None]
@ -88,6 +88,10 @@ class Executor(ExecutorBase):
args=(scheduler_output, ))
return output[0]
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
output = self.collective_rpc("take_draft_token_ids")
return output[0]
@property
def max_concurrent_batches(self) -> int:
return 1

View File

@ -33,7 +33,7 @@ from vllm.utils import (decorate_logs, get_distributed_init_method,
get_loopback_ip, get_mp_context, get_open_port,
set_process_title)
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase
logger = init_logger(__name__)
@ -191,6 +191,12 @@ class MultiprocExecutor(Executor):
outputs, self.output_rank)
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
# OPTIMIZATION: Get output only from a single worker (output_rank)
outputs = self.collective_rpc("take_draft_token_ids",
unique_reply_rank=self.output_rank)
return outputs[0]
def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,

View File

@ -94,9 +94,6 @@ class ModelRunnerOutput:
# each request due to speculative/jump decoding.
sampled_token_ids: list[list[int]]
# num_reqs x num_spec_tokens
spec_token_ids: Optional[list[list[int]]]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs]
@ -117,10 +114,18 @@ class ModelRunnerOutput:
num_nans_in_logits: Optional[dict[str, int]] = None
@dataclass
class DraftTokenIds:
# [num_reqs]
req_ids: list[str]
# num_reqs x num_draft_tokens
draft_token_ids: list[list[int]]
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],

View File

@ -38,12 +38,14 @@ class MedusaProposer:
self,
target_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
) -> list[list[int]]:
# Generate blocks and compute logits
blocks = self.model(target_hidden_states)
logits = self.model.compute_logits(blocks, None)
# Get draft tokens and transpose the result
# TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU
# synchronization.
draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
return [list(row) for row in zip(*draft_tokens)]

View File

@ -65,8 +65,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheSpec, MambaSpec,
SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
LogprobsTensors, ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata
@ -348,6 +348,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.reorder_batch_threshold: Optional[int] = None
# Cached outputs.
self._draft_token_ids: Optional[Union[list[list[int]],
torch.Tensor]] = None
def _init_model_kwargs(self, num_tokens: int):
model_kwargs = dict[str, Any]()
num_reqs = self.input_batch.num_reqs
@ -1493,7 +1497,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=pooler_output,
@ -1764,12 +1767,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)
if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
else:
if self.speculative_config:
assert spec_decode_common_attn_metadata is not None
spec_token_ids = self.propose_draft_token_ids(
self._draft_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
sampling_metadata,
@ -1786,7 +1786,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
@ -1794,6 +1793,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_nans_in_logits=num_nans_in_logits,
)
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
if self._draft_token_ids is None:
return None
req_ids = self.input_batch.req_ids
if isinstance(self._draft_token_ids, torch.Tensor):
draft_token_ids = self._draft_token_ids.tolist()
else:
draft_token_ids = self._draft_token_ids
self._draft_token_ids = None
return DraftTokenIds(req_ids, draft_token_ids)
def propose_draft_token_ids(
self,
scheduler_output: "SchedulerOutput",
@ -1804,11 +1814,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
aux_hidden_states: Optional[torch.Tensor],
spec_decode_metadata: Optional[SpecDecodeMetadata],
common_attn_metadata: CommonAttentionMetadata,
) -> list[list[int]]:
) -> Union[list[list[int]], torch.Tensor]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
spec_token_ids = self.propose_ngram_draft_token_ids(
draft_token_ids = self.propose_ngram_draft_token_ids(
sampled_token_ids)
elif self.speculative_config.method == "medusa":
assert isinstance(self.drafter, MedusaProposer)
@ -1826,7 +1836,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
indices = torch.tensor(indices, device=self.device)
hidden_states = sample_hidden_states[indices]
spec_token_ids = self.drafter.propose(
draft_token_ids = self.drafter.propose(
target_hidden_states=hidden_states,
sampling_metadata=sampling_metadata,
)
@ -1897,8 +1907,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
common_attn_metadata=common_attn_metadata,
mm_embeds=mm_embeds,
)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
return draft_token_ids
def propose_ngram_draft_token_ids(
self,

View File

@ -28,7 +28,8 @@ from vllm.tasks import SupportedTask
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
ModelRunnerOutput)
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
@ -386,6 +387,9 @@ class Worker(WorkerBase):
assert isinstance(output, ModelRunnerOutput)
return output
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
return self.model_runner.take_draft_token_ids()
def profile(self, is_start: bool = True):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")

View File

@ -1145,7 +1145,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=None,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],