From c9b38be8aafb02b69ccb704b33d2bb4329fbb0e6 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Mon, 18 Aug 2025 17:20:38 -0700 Subject: [PATCH] [Spec Decode] Make `propose_draft_token_ids` non-blocking for lower TTFT (#23041) Signed-off-by: Woosuk Kwon --- tests/v1/core/test_async_scheduler.py | 1 - tests/v1/core/test_scheduler.py | 26 ++--------------- tests/v1/kv_connector/unit/utils.py | 1 - vllm/v1/core/sched/interface.py | 10 ++++++- vllm/v1/core/sched/scheduler.py | 39 +++++++++++++++++--------- vllm/v1/engine/core.py | 10 +++++++ vllm/v1/executor/abstract.py | 8 ++++-- vllm/v1/executor/multiproc_executor.py | 8 +++++- vllm/v1/outputs.py | 13 ++++++--- vllm/v1/spec_decode/medusa.py | 4 ++- vllm/v1/worker/gpu_model_runner.py | 37 +++++++++++++++--------- vllm/v1/worker/gpu_worker.py | 6 +++- vllm/v1/worker/tpu_model_runner.py | 1 - 13 files changed, 100 insertions(+), 64 deletions(-) diff --git a/tests/v1/core/test_async_scheduler.py b/tests/v1/core/test_async_scheduler.py index 3a9492269f9c9..c153e38fe3df3 100644 --- a/tests/v1/core/test_async_scheduler.py +++ b/tests/v1/core/test_async_scheduler.py @@ -22,7 +22,6 @@ def _make_model_runner_output( for i, req_id in enumerate(req_ids) }, sampled_token_ids=[[i] for i in range(len(req_ids))], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], diff --git a/tests/v1/core/test_scheduler.py b/tests/v1/core/test_scheduler.py index 23762a0fb6223..070008fcbf59f 100644 --- a/tests/v1/core/test_scheduler.py +++ b/tests/v1/core/test_scheduler.py @@ -14,7 +14,7 @@ from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput from vllm.v1.core.sched.scheduler import Scheduler from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec) -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.structured_output import StructuredOutputManager from vllm.v1.structured_output.request import StructuredOutputRequest @@ -158,7 +158,6 @@ def test_schedule_partial_requests(): # Only the first request has a sampled token id because # the rest requests are still being prefilled. 
sampled_token_ids=[[0], [], []], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -209,7 +208,6 @@ def test_no_mm_input_chunking(): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[] for _ in range(len(requests))], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -273,7 +271,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[] for _ in range(len(requests))], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -298,7 +295,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool): req_ids=[request.request_id for request in requests], req_id_to_index=req_to_index, sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -355,7 +351,6 @@ def test_stop_via_update_from_output(): sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]], # First request hits EOS, second continues - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[]) @@ -409,7 +404,6 @@ def test_stop_via_update_from_output(): }, sampled_token_ids=[[10, 42, 12], [13, 14]], # First request hits stop token - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[]) @@ -462,7 +456,6 @@ def test_stop_via_update_from_output(): }, sampled_token_ids=[[10, 11, 12], [13]], # First request exceeds max_tokens - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[]) @@ -505,7 +498,6 @@ def test_stop_via_update_from_output(): req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[]) @@ -554,7 +546,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[0]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -572,7 +563,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool], req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, sampled_token_ids=[[0]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -608,7 +598,6 @@ def test_preempt_during_execution(): req_ids=[requests[0].request_id], req_id_to_index={requests[0].request_id: 0}, sampled_token_ids=[[0]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -626,7 +615,6 @@ def test_preempt_during_execution(): req_ids=[requests[1].request_id], req_id_to_index={requests[1].request_id: 0}, sampled_token_ids=[[42]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -682,13 +670,14 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[0] for _ in range(len(requests))], - spec_token_ids=spec_tokens, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], ) engine_core_outputs = scheduler.update_from_output(output, model_runner_output) + draft_token_ids = DraftTokenIds(req_ids, spec_tokens) + scheduler.update_draft_token_ids(draft_token_ids) for i in range(len(requests)): running_req = 
scheduler.running[i] @@ -722,7 +711,6 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=output_tokens, - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -851,7 +839,6 @@ def test_kv_connector_basic(): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -898,7 +885,6 @@ def test_kv_connector_basic(): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -966,7 +952,6 @@ def test_kv_connector_unable_to_allocate(): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1048,7 +1033,6 @@ def test_kv_connector_handles_preemption(): req_ids=req_ids, req_id_to_index=req_to_index, sampled_token_ids=[[1000]] * len(req_ids), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1142,7 +1126,6 @@ def make_output(scheduler: Scheduler): for i, req in enumerate(scheduler.running) }, sampled_token_ids=[[1000]] * len(scheduler.running), - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1468,7 +1451,6 @@ def test_priority_scheduling_preemption(): for i, req in enumerate(low_priority_requests) }, sampled_token_ids=[[100] for _ in low_priority_requests], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1541,7 +1523,6 @@ def test_priority_scheduling_no_preemption_when_space_available(): for i, req in enumerate(low_priority_requests) }, sampled_token_ids=[[100] for _ in low_priority_requests], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], @@ -1783,7 +1764,6 @@ def test_priority_scheduling_heap_property(): req_ids=[req.req_id], req_id_to_index={req.req_id: 0}, sampled_token_ids=[[100]], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], diff --git a/tests/v1/kv_connector/unit/utils.py b/tests/v1/kv_connector/unit/utils.py index 8c5d132c00ae4..a47f583b329e2 100644 --- a/tests/v1/kv_connector/unit/utils.py +++ b/tests/v1/kv_connector/unit/utils.py @@ -200,7 +200,6 @@ def create_model_runner_output( req_ids=req_ids, req_id_to_index=req_id_to_index, sampled_token_ids=sampled_token_ids, - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=None, diff --git a/vllm/v1/core/sched/interface.py b/vllm/v1/core/sched/interface.py index dd5052a3480b7..5b1de3a66ceb4 100644 --- a/vllm/v1/core/sched/interface.py +++ b/vllm/v1/core/sched/interface.py @@ -9,7 +9,7 @@ if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.engine import EngineCoreOutputs from vllm.v1.metrics.stats import SchedulerStats - from vllm.v1.outputs import ModelRunnerOutput + from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus @@ -61,6 +61,14 @@ class SchedulerInterface(ABC): """ raise NotImplementedError + @abstractmethod + def update_draft_token_ids( + self, + draft_token_ids: "DraftTokenIds", + ) -> None: + """Update the draft token ids for the scheduled requests.""" + raise NotImplementedError + @abstractmethod def add_request(self, request: "Request") -> None: """Add a new request to the 
scheduler's internal queue. diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 9810234090453..b3defa443186e 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -30,7 +30,7 @@ from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs) from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.metrics.stats import SchedulerStats -from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput from vllm.v1.request import Request, RequestStatus from vllm.v1.spec_decode.metrics import SpecDecodingStats from vllm.v1.structured_output import StructuredOutputManager @@ -141,7 +141,6 @@ class Scheduler(SchedulerInterface): cache_size=encoder_cache_size) speculative_config = vllm_config.speculative_config - self.use_eagle = False self.num_spec_tokens = self.num_lookahead_tokens = 0 if speculative_config: @@ -760,7 +759,6 @@ class Scheduler(SchedulerInterface): model_runner_output: ModelRunnerOutput, ) -> dict[int, EngineCoreOutputs]: sampled_token_ids = model_runner_output.sampled_token_ids - spec_token_ids = model_runner_output.spec_token_ids logprobs = model_runner_output.logprobs prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict num_scheduled_tokens = scheduler_output.num_scheduled_tokens @@ -845,20 +843,9 @@ class Scheduler(SchedulerInterface): request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr] req_id, new_token_ids) - # spec_token_ids comes from the model runner output if num_nans_in_logits is not None and req_id in num_nans_in_logits: request.num_nans_in_logits = num_nans_in_logits[req_id] - # Add newly generated spec token ids to the request. - if spec_token_ids is not None: - if self.structured_output_manager.should_advance(request): - metadata = request.structured_output_request - # Needs to happen after new_token_ids are accepted. - request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] - spec_token_ids[req_index]) - else: - request.spec_token_ids = spec_token_ids[req_index] - # Get prompt logprobs for this request. prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id) if new_token_ids or pooler_output is not None \ @@ -963,6 +950,30 @@ class Scheduler(SchedulerInterface): self.encoder_cache_manager.free_encoder_input( request, input_id) + def update_draft_token_ids( + self, + draft_token_ids: DraftTokenIds, + ) -> None: + for req_id, spec_token_ids in zip( + draft_token_ids.req_ids, + draft_token_ids.draft_token_ids, + ): + request = self.requests.get(req_id) + if request is None or request.is_finished(): + # The request may have been finished. Skip. + continue + + # Add newly generated spec token ids to the request. + if not spec_token_ids: + # NOTE(woosuk): request.spec_token_ids should be updated. 
+ request.spec_token_ids.clear() + elif self.structured_output_manager.should_advance(request): + metadata = request.structured_output_request + request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr] + spec_token_ids) + else: + request.spec_token_ids = spec_token_ids + def get_request_counts(self) -> tuple[int, int]: """Returns (num_running_reqs, num_waiting_reqs).""" return len(self.running), len(self.waiting) diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 1e52f93a581b3..32765cda6482f 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -126,6 +126,7 @@ class EngineCore: > 1, log_stats=self.log_stats, ) + self.use_spec_decode = vllm_config.speculative_config is not None self.mm_input_cache_server = MultiModalInputCacheServer( vllm_config.model_config, MULTIMODAL_REGISTRY) @@ -294,6 +295,13 @@ class EngineCore: return (engine_core_outputs, scheduler_output.total_num_scheduled_tokens > 0) + def post_step(self, model_executed: bool) -> None: + if self.use_spec_decode and model_executed: + # Take the draft token ids. + draft_token_ids = self.model_executor.take_draft_token_ids() + if draft_token_ids is not None: + self.scheduler.update_draft_token_ids(draft_token_ids) + def step_with_batch_queue( self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]: """Schedule and execute batches with the batch queue. @@ -746,6 +754,8 @@ class EngineCoreProc(EngineCore): # Put EngineCoreOutputs into the output queue. for output in (outputs.items() if outputs else ()): self.output_queue.put_nowait(output) + # Post-step hook. + self.post_step(model_executed) return model_executed diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py index 50b9634a49e1b..063a5f592e1a0 100644 --- a/vllm/v1/executor/abstract.py +++ b/vllm/v1/executor/abstract.py @@ -2,7 +2,7 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from concurrent.futures import Future -from typing import Callable, Union +from typing import Callable, Optional, Union import torch import torch.distributed as dist @@ -14,7 +14,7 @@ from vllm.executor.uniproc_executor import ( # noqa from vllm.executor.uniproc_executor import ( # noqa UniProcExecutor as UniProcExecutorV0) from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput FailureCallback = Callable[[], None] @@ -88,6 +88,10 @@ class Executor(ExecutorBase): args=(scheduler_output, )) return output[0] + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + output = self.collective_rpc("take_draft_token_ids") + return output[0] + @property def max_concurrent_batches(self) -> int: return 1 diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py index 0db3bcd7fb408..15b88a2128994 100644 --- a/vllm/v1/executor/multiproc_executor.py +++ b/vllm/v1/executor/multiproc_executor.py @@ -33,7 +33,7 @@ from vllm.utils import (decorate_logs, get_distributed_init_method, get_loopback_ip, get_mp_context, get_open_port, set_process_title) from vllm.v1.executor.abstract import Executor, FailureCallback -from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -191,6 +191,12 @@ class MultiprocExecutor(Executor): outputs, self.output_rank) return self.kv_output_aggregator.aggregate(outputs, self.output_rank) + def 
take_draft_token_ids(self) -> Optional[DraftTokenIds]: + # OPTIMIZATION: Get output only from a single worker (output_rank) + outputs = self.collective_rpc("take_draft_token_ids", + unique_reply_rank=self.output_rank) + return outputs[0] + def collective_rpc(self, method: Union[str, Callable], timeout: Optional[float] = None, diff --git a/vllm/v1/outputs.py b/vllm/v1/outputs.py index 7d7cd0c94dd04..f8d6b24702f3c 100644 --- a/vllm/v1/outputs.py +++ b/vllm/v1/outputs.py @@ -94,9 +94,6 @@ class ModelRunnerOutput: # each request due to speculative/jump decoding. sampled_token_ids: list[list[int]] - # num_reqs x num_spec_tokens - spec_token_ids: Optional[list[list[int]]] - # [num_reqs, max_num_logprobs + 1] # [num_reqs, max_num_logprobs + 1] # [num_reqs] @@ -117,10 +114,18 @@ class ModelRunnerOutput: num_nans_in_logits: Optional[dict[str, int]] = None +@dataclass +class DraftTokenIds: + + # [num_reqs] + req_ids: list[str] + # num_reqs x num_draft_tokens + draft_token_ids: list[list[int]] + + EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[], req_id_to_index={}, sampled_token_ids=[], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=[], diff --git a/vllm/v1/spec_decode/medusa.py b/vllm/v1/spec_decode/medusa.py index 309fd926aecd7..3e90179e78d99 100644 --- a/vllm/v1/spec_decode/medusa.py +++ b/vllm/v1/spec_decode/medusa.py @@ -38,12 +38,14 @@ class MedusaProposer: self, target_hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, - ) -> torch.Tensor: + ) -> list[list[int]]: # Generate blocks and compute logits blocks = self.model(target_hidden_states) logits = self.model.compute_logits(blocks, None) # Get draft tokens and transpose the result + # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU + # synchronization. draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits] return [list(row) for row in zip(*draft_tokens)] diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 43119fcad3fb3..9b0345a6aa3ad 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -65,8 +65,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheSpec, MambaSpec, SlidingWindowSpec) -from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors, - ModelRunnerOutput) +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, + LogprobsTensors, ModelRunnerOutput) from vllm.v1.pool.metadata import PoolingMetadata from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs from vllm.v1.sample.metadata import SamplingMetadata @@ -348,6 +348,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): self.reorder_batch_threshold: Optional[int] = None + # Cached outputs. 
+ self._draft_token_ids: Optional[Union[list[list[int]], + torch.Tensor]] = None + def _init_model_kwargs(self, num_tokens: int): model_kwargs = dict[str, Any]() num_reqs = self.input_batch.num_reqs @@ -1493,7 +1497,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=[], - spec_token_ids=None, logprobs=None, prompt_logprobs_dict={}, pooler_output=pooler_output, @@ -1764,12 +1767,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_state = self.requests[req_id] req_state.output_token_ids.extend(sampled_ids) - if not self.speculative_config: - # Speculative decoding is not enabled. - spec_token_ids = None - else: + if self.speculative_config: assert spec_decode_common_attn_metadata is not None - spec_token_ids = self.propose_draft_token_ids( + self._draft_token_ids = self.propose_draft_token_ids( scheduler_output, valid_sampled_token_ids, sampling_metadata, @@ -1786,7 +1786,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_ids=self.input_batch.req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=spec_token_ids, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, pooler_output=[], @@ -1794,6 +1793,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): num_nans_in_logits=num_nans_in_logits, ) + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + if self._draft_token_ids is None: + return None + req_ids = self.input_batch.req_ids + if isinstance(self._draft_token_ids, torch.Tensor): + draft_token_ids = self._draft_token_ids.tolist() + else: + draft_token_ids = self._draft_token_ids + self._draft_token_ids = None + return DraftTokenIds(req_ids, draft_token_ids) + def propose_draft_token_ids( self, scheduler_output: "SchedulerOutput", @@ -1804,11 +1814,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): aux_hidden_states: Optional[torch.Tensor], spec_decode_metadata: Optional[SpecDecodeMetadata], common_attn_metadata: CommonAttentionMetadata, - ) -> list[list[int]]: + ) -> Union[list[list[int]], torch.Tensor]: num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens if self.speculative_config.method == "ngram": assert isinstance(self.drafter, NgramProposer) - spec_token_ids = self.propose_ngram_draft_token_ids( + draft_token_ids = self.propose_ngram_draft_token_ids( sampled_token_ids) elif self.speculative_config.method == "medusa": assert isinstance(self.drafter, MedusaProposer) @@ -1826,7 +1836,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): indices = torch.tensor(indices, device=self.device) hidden_states = sample_hidden_states[indices] - spec_token_ids = self.drafter.propose( + draft_token_ids = self.drafter.propose( target_hidden_states=hidden_states, sampling_metadata=sampling_metadata, ) @@ -1897,8 +1907,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): common_attn_metadata=common_attn_metadata, mm_embeds=mm_embeds, ) - spec_token_ids = draft_token_ids.tolist() - return spec_token_ids + return draft_token_ids def propose_ngram_draft_token_ids( self, diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 04de8d36680a4..22e639b97d09c 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -28,7 +28,8 @@ from vllm.tasks import SupportedTask from 
vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec -from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput +from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds, + ModelRunnerOutput) from vllm.v1.utils import report_usage_stats from vllm.v1.worker.gpu_model_runner import GPUModelRunner from vllm.v1.worker.worker_base import WorkerBase @@ -386,6 +387,9 @@ class Worker(WorkerBase): assert isinstance(output, ModelRunnerOutput) return output + def take_draft_token_ids(self) -> Optional[DraftTokenIds]: + return self.model_runner.take_draft_token_ids() + def profile(self, is_start: bool = True): if self.profiler is None: raise RuntimeError("Profiler is not enabled.") diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index af837e4d946e6..9196c62377b91 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -1145,7 +1145,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin): req_ids=req_ids, req_id_to_index=self.input_batch.req_id_to_index, sampled_token_ids=valid_sampled_token_ids, - spec_token_ids=None, logprobs=logprobs_lists, prompt_logprobs_dict=prompt_logprobs_dict, pooler_output=[],
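
Note for reviewers (not part of the patch): the sketch below illustrates the control flow this change introduces. Draft token ids are no longer carried in `ModelRunnerOutput.spec_token_ids`; instead the model runner caches them during `execute_model`, and a post-step hook pulls them via `take_draft_token_ids()` and hands them to the scheduler through `update_draft_token_ids()`. The `FakeModelRunner`, `FakeScheduler`, and `step` names are hypothetical stand-ins used only for illustration; only `DraftTokenIds` mirrors the dataclass added in `vllm/v1/outputs.py`.

```python
# Minimal sketch of the decoupled draft-token flow, assuming simplified
# stand-ins for the real vLLM engine/executor/scheduler objects.
from dataclasses import dataclass
from typing import Optional


@dataclass
class DraftTokenIds:
    # [num_reqs]
    req_ids: list[str]
    # num_reqs x num_draft_tokens
    draft_token_ids: list[list[int]]


class FakeModelRunner:
    """Stand-in: caches drafts during execute_model, hands them over later."""

    def __init__(self) -> None:
        self._draft_token_ids: Optional[list[list[int]]] = None

    def execute_model(self) -> None:
        # Drafts are proposed and only cached here; they are no longer
        # returned as part of the model runner output.
        self._draft_token_ids = [[7, 8], [9]]

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        if self._draft_token_ids is None:
            return None
        drafts = DraftTokenIds(req_ids=["req-0", "req-1"],
                               draft_token_ids=self._draft_token_ids)
        self._draft_token_ids = None  # take() semantics: consumed once
        return drafts


class FakeScheduler:
    """Stand-in: stores drafts per request for the next scheduling step."""

    def __init__(self) -> None:
        self.spec_token_ids: dict[str, list[int]] = {}

    def update_draft_token_ids(self, drafts: DraftTokenIds) -> None:
        for req_id, tokens in zip(drafts.req_ids, drafts.draft_token_ids):
            self.spec_token_ids[req_id] = tokens


def step(runner: FakeModelRunner, scheduler: FakeScheduler) -> None:
    # 1. Execute the model; sampled tokens would be processed here via
    #    scheduler.update_from_output(), without waiting on drafts.
    runner.execute_model()
    # 2. Post-step: pull the cached drafts and hand them to the scheduler,
    #    off the critical path of producing engine outputs.
    drafts = runner.take_draft_token_ids()
    if drafts is not None:
        scheduler.update_draft_token_ids(drafts)


if __name__ == "__main__":
    runner, scheduler = FakeModelRunner(), FakeScheduler()
    step(runner, scheduler)
    print(scheduler.spec_token_ids)  # {'req-0': [7, 8], 'req-1': [9]}
```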