[Spec Decode] Make propose_draft_token_ids non-blocking for lower TTFT (#23041)

Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Woosuk Kwon, 2025-08-18 17:20:38 -07:00 (committed by GitHub)
parent 0dd3f4f5ab
commit c9b38be8aa
13 changed files with 100 additions and 64 deletions
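In short: draft token ids no longer travel inside ModelRunnerOutput (the spec_token_ids field is removed everywhere below). The model runner now caches its proposal, the executor exposes a pull-style take_draft_token_ids() RPC, and the engine core forwards the result to Scheduler.update_draft_token_ids() from a post_step() hook that runs after the step's outputs are already queued, so returning sampled tokens never waits on draft proposal. The sketch below condenses that flow; the names match the diff, but the classes are stubs and the real implementations in vllm/v1 carry far more state.

# Illustrative sketch only: condensed from the diff below, not vLLM's
# actual classes. SchedulerStub/RunnerStub stand in for the real ones.
from dataclasses import dataclass
from typing import Optional


@dataclass
class DraftTokenIds:
    # [num_reqs]
    req_ids: list[str]
    # num_reqs x num_draft_tokens
    draft_token_ids: list[list[int]]


class SchedulerStub:

    def __init__(self) -> None:
        self.spec_token_ids: dict[str, list[int]] = {}

    def update_draft_token_ids(self, drafts: DraftTokenIds) -> None:
        # The real scheduler also skips finished requests and validates
        # drafts against a structured-output grammar when one is active.
        for req_id, tokens in zip(drafts.req_ids, drafts.draft_token_ids):
            self.spec_token_ids[req_id] = tokens


class RunnerStub:

    def __init__(self) -> None:
        # Cached by execute_model() instead of being returned inside
        # ModelRunnerOutput, so outputs are not blocked on drafting.
        self._draft_token_ids: Optional[list[list[int]]] = None
        self.req_ids: list[str] = []

    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
        # One-shot "take": hand back the cached drafts and clear the slot.
        if self._draft_token_ids is None:
            return None
        drafts, self._draft_token_ids = self._draft_token_ids, None
        return DraftTokenIds(self.req_ids, drafts)


def post_step(runner: RunnerStub, scheduler: SchedulerStub,
              model_executed: bool) -> None:
    # Runs after EngineCoreOutputs are already in the output queue, so
    # draft retrieval is off the TTFT-critical path.
    if model_executed:
        drafts = runner.take_draft_token_ids()
        if drafts is not None:
            scheduler.update_draft_token_ids(drafts)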

View File

@@ -22,7 +22,6 @@ def _make_model_runner_output(
            for i, req_id in enumerate(req_ids)
        },
        sampled_token_ids=[[i] for i in range(len(req_ids))],
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],

View File

@@ -14,7 +14,7 @@ from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheGroupSpec)
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
 from vllm.v1.structured_output.request import StructuredOutputRequest
@@ -158,7 +158,6 @@ def test_schedule_partial_requests():
        # Only the first request has a sampled token id because
        # the rest requests are still being prefilled.
        sampled_token_ids=[[0], [], []],
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -209,7 +208,6 @@ def test_no_mm_input_chunking():
        req_ids=[request.request_id for request in requests],
        req_id_to_index=req_to_index,
        sampled_token_ids=[[] for _ in range(len(requests))],
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -273,7 +271,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
        req_ids=[request.request_id for request in requests],
        req_id_to_index=req_to_index,
        sampled_token_ids=[[] for _ in range(len(requests))],
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -298,7 +295,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
        req_ids=[request.request_id for request in requests],
        req_id_to_index=req_to_index,
        sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -355,7 +351,6 @@ def test_stop_via_update_from_output():
        sampled_token_ids=[[EOS_TOKEN_ID],
                           [10,
                            11]],  # First request hits EOS, second continues
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[])
@@ -409,7 +404,6 @@ def test_stop_via_update_from_output():
        },
        sampled_token_ids=[[10, 42, 12],
                           [13, 14]],  # First request hits stop token
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[])
@@ -462,7 +456,6 @@ def test_stop_via_update_from_output():
        },
        sampled_token_ids=[[10, 11, 12],
                           [13]],  # First request exceeds max_tokens
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[])
@@ -505,7 +498,6 @@ def test_stop_via_update_from_output():
        req_ids=[requests[0].request_id],
        req_id_to_index={requests[0].request_id: 0},
        sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[])
@@ -554,7 +546,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
        req_ids=[requests[0].request_id],
        req_id_to_index={requests[0].request_id: 0},
        sampled_token_ids=[[0]],
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -572,7 +563,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
        req_ids=[requests[1].request_id],
        req_id_to_index={requests[1].request_id: 0},
        sampled_token_ids=[[0]],
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -608,7 +598,6 @@ def test_preempt_during_execution():
        req_ids=[requests[0].request_id],
        req_id_to_index={requests[0].request_id: 0},
        sampled_token_ids=[[0]],
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -626,7 +615,6 @@ def test_preempt_during_execution():
        req_ids=[requests[1].request_id],
        req_id_to_index={requests[1].request_id: 0},
        sampled_token_ids=[[42]],
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -682,13 +670,14 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[0] for _ in range(len(requests))],
-       spec_token_ids=spec_tokens,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
    )
    engine_core_outputs = scheduler.update_from_output(output,
                                                       model_runner_output)
+   draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
+   scheduler.update_draft_token_ids(draft_token_ids)
 
    for i in range(len(requests)):
        running_req = scheduler.running[i]
@@ -722,7 +711,6 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=output_tokens,
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -851,7 +839,6 @@ def test_kv_connector_basic():
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -898,7 +885,6 @@ def test_kv_connector_basic():
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -966,7 +952,6 @@ def test_kv_connector_unable_to_allocate():
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -1048,7 +1033,6 @@ def test_kv_connector_handles_preemption():
        req_ids=req_ids,
        req_id_to_index=req_to_index,
        sampled_token_ids=[[1000]] * len(req_ids),
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -1142,7 +1126,6 @@ def make_output(scheduler: Scheduler):
            for i, req in enumerate(scheduler.running)
        },
        sampled_token_ids=[[1000]] * len(scheduler.running),
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -1468,7 +1451,6 @@ def test_priority_scheduling_preemption():
            for i, req in enumerate(low_priority_requests)
        },
        sampled_token_ids=[[100] for _ in low_priority_requests],
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -1541,7 +1523,6 @@ def test_priority_scheduling_no_preemption_when_space_available():
            for i, req in enumerate(low_priority_requests)
        },
        sampled_token_ids=[[100] for _ in low_priority_requests],
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],
@@ -1783,7 +1764,6 @@ def test_priority_scheduling_heap_property():
        req_ids=[req.req_id],
        req_id_to_index={req.req_id: 0},
        sampled_token_ids=[[100]],
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=[],

View File

@@ -200,7 +200,6 @@ def create_model_runner_output(
        req_ids=req_ids,
        req_id_to_index=req_id_to_index,
        sampled_token_ids=sampled_token_ids,
-       spec_token_ids=None,
        logprobs=None,
        prompt_logprobs_dict={},
        pooler_output=None,

View File

@@ -9,7 +9,7 @@ if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
     from vllm.v1.engine import EngineCoreOutputs
     from vllm.v1.metrics.stats import SchedulerStats
-    from vllm.v1.outputs import ModelRunnerOutput
+    from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
     from vllm.v1.request import Request, RequestStatus
@@ -61,6 +61,14 @@ class SchedulerInterface(ABC):
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def update_draft_token_ids(
+        self,
+        draft_token_ids: "DraftTokenIds",
+    ) -> None:
+        """Update the draft token ids for the scheduled requests."""
+        raise NotImplementedError
+
     @abstractmethod
     def add_request(self, request: "Request") -> None:
         """Add a new request to the scheduler's internal queue.

View File

@@ -30,7 +30,7 @@ from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
                             EngineCoreOutputs)
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.metrics.stats import SchedulerStats
-from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
+from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.spec_decode.metrics import SpecDecodingStats
 from vllm.v1.structured_output import StructuredOutputManager
@@ -141,7 +141,6 @@ class Scheduler(SchedulerInterface):
                 cache_size=encoder_cache_size)
 
         speculative_config = vllm_config.speculative_config
         self.use_eagle = False
         self.num_spec_tokens = self.num_lookahead_tokens = 0
         if speculative_config:
@@ -760,7 +759,6 @@ class Scheduler(SchedulerInterface):
         model_runner_output: ModelRunnerOutput,
     ) -> dict[int, EngineCoreOutputs]:
         sampled_token_ids = model_runner_output.sampled_token_ids
-        spec_token_ids = model_runner_output.spec_token_ids
         logprobs = model_runner_output.logprobs
         prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
         num_scheduled_tokens = scheduler_output.num_scheduled_tokens
@@ -845,20 +843,9 @@ class Scheduler(SchedulerInterface):
                 request.structured_output_request.grammar.accept_tokens(  # type: ignore[union-attr]
                     req_id, new_token_ids)
 
-            # spec_token_ids comes from the model runner output
             if num_nans_in_logits is not None and req_id in num_nans_in_logits:
                 request.num_nans_in_logits = num_nans_in_logits[req_id]
 
-            # Add newly generated spec token ids to the request.
-            if spec_token_ids is not None:
-                if self.structured_output_manager.should_advance(request):
-                    metadata = request.structured_output_request
-                    # Needs to happen after new_token_ids are accepted.
-                    request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
-                        spec_token_ids[req_index])
-                else:
-                    request.spec_token_ids = spec_token_ids[req_index]
-
             # Get prompt logprobs for this request.
             prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
             if new_token_ids or pooler_output is not None \
@@ -963,6 +950,30 @@ class Scheduler(SchedulerInterface):
                 self.encoder_cache_manager.free_encoder_input(
                     request, input_id)
 
+    def update_draft_token_ids(
+        self,
+        draft_token_ids: DraftTokenIds,
+    ) -> None:
+        for req_id, spec_token_ids in zip(
+                draft_token_ids.req_ids,
+                draft_token_ids.draft_token_ids,
+        ):
+            request = self.requests.get(req_id)
+            if request is None or request.is_finished():
+                # The request may have been finished. Skip.
+                continue
+
+            # Add newly generated spec token ids to the request.
+            if not spec_token_ids:
+                # NOTE(woosuk): request.spec_token_ids should be updated.
+                request.spec_token_ids.clear()
+            elif self.structured_output_manager.should_advance(request):
+                metadata = request.structured_output_request
+                request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
+                    spec_token_ids)
+            else:
+                request.spec_token_ids = spec_token_ids
+
     def get_request_counts(self) -> tuple[int, int]:
         """Returns (num_running_reqs, num_waiting_reqs)."""
         return len(self.running), len(self.waiting)

View File

@@ -126,6 +126,7 @@ class EngineCore:
             > 1,
             log_stats=self.log_stats,
         )
+        self.use_spec_decode = vllm_config.speculative_config is not None
 
         self.mm_input_cache_server = MultiModalInputCacheServer(
             vllm_config.model_config, MULTIMODAL_REGISTRY)
@@ -294,6 +295,13 @@ class EngineCore:
         return (engine_core_outputs,
                 scheduler_output.total_num_scheduled_tokens > 0)
 
+    def post_step(self, model_executed: bool) -> None:
+        if self.use_spec_decode and model_executed:
+            # Take the draft token ids.
+            draft_token_ids = self.model_executor.take_draft_token_ids()
+            if draft_token_ids is not None:
+                self.scheduler.update_draft_token_ids(draft_token_ids)
+
     def step_with_batch_queue(
             self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
         """Schedule and execute batches with the batch queue.
@@ -746,6 +754,8 @@ class EngineCoreProc(EngineCore):
         # Put EngineCoreOutputs into the output queue.
         for output in (outputs.items() if outputs else ()):
             self.output_queue.put_nowait(output)
+        # Post-step hook.
+        self.post_step(model_executed)
 
         return model_executed
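The placement of this hook is the point of the change: EngineCoreProc pushes the step's EngineCoreOutputs to the output queue first and only then calls post_step() to pull drafts from the workers, so draft proposal overlaps with output delivery instead of delaying it, which is what lowers TTFT. Drafts taken here are consumed by the scheduler on the next scheduling iteration.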

View File

@@ -2,7 +2,7 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from concurrent.futures import Future
-from typing import Callable, Union
+from typing import Callable, Optional, Union
 
 import torch
 import torch.distributed as dist
@@ -14,7 +14,7 @@ from vllm.executor.uniproc_executor import (  # noqa
 from vllm.executor.uniproc_executor import (  # noqa
     UniProcExecutor as UniProcExecutorV0)
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 
 FailureCallback = Callable[[], None]
@@ -88,6 +88,10 @@ class Executor(ExecutorBase):
                                      args=(scheduler_output, ))
         return output[0]
 
+    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
+        output = self.collective_rpc("take_draft_token_ids")
+        return output[0]
+
     @property
     def max_concurrent_batches(self) -> int:
         return 1

View File

@@ -33,7 +33,7 @@ from vllm.utils import (decorate_logs, get_distributed_init_method,
                         get_loopback_ip, get_mp_context, get_open_port,
                         set_process_title)
 from vllm.v1.executor.abstract import Executor, FailureCallback
-from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
 from vllm.worker.worker_base import WorkerWrapperBase
 
 logger = init_logger(__name__)
@@ -191,6 +191,12 @@ class MultiprocExecutor(Executor):
                 outputs, self.output_rank)
         return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
 
+    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
+        # OPTIMIZATION: Get output only from a single worker (output_rank)
+        outputs = self.collective_rpc("take_draft_token_ids",
+                                      unique_reply_rank=self.output_rank)
+        return outputs[0]
+
     def collective_rpc(self,
                        method: Union[str, Callable],
                        timeout: Optional[float] = None,

View File

@@ -94,9 +94,6 @@ class ModelRunnerOutput:
     # each request due to speculative/jump decoding.
     sampled_token_ids: list[list[int]]
 
-    # num_reqs x num_spec_tokens
-    spec_token_ids: Optional[list[list[int]]]
-
     # [num_reqs, max_num_logprobs + 1]
     # [num_reqs, max_num_logprobs + 1]
     # [num_reqs]
@@ -117,10 +114,18 @@ class ModelRunnerOutput:
     num_nans_in_logits: Optional[dict[str, int]] = None
 
 
+@dataclass
+class DraftTokenIds:
+
+    # [num_reqs]
+    req_ids: list[str]
+    # num_reqs x num_draft_tokens
+    draft_token_ids: list[list[int]]
+
+
 EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
                                               req_id_to_index={},
                                               sampled_token_ids=[],
-                                              spec_token_ids=None,
                                               logprobs=None,
                                               prompt_logprobs_dict={},
                                               pooler_output=[],

View File

@@ -38,12 +38,14 @@ class MedusaProposer:
         self,
         target_hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
-    ) -> torch.Tensor:
+    ) -> list[list[int]]:
         # Generate blocks and compute logits
         blocks = self.model(target_hidden_states)
         logits = self.model.compute_logits(blocks, None)
 
         # Get draft tokens and transpose the result
+        # TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU
+        # synchronization.
         draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
         return [list(row) for row in zip(*draft_tokens)]

View File

@@ -65,8 +65,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
                                         FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec, MambaSpec,
                                         SlidingWindowSpec)
-from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
-                             ModelRunnerOutput)
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
+                             LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.pool.metadata import PoolingMetadata
 from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
 from vllm.v1.sample.metadata import SamplingMetadata
@@ -348,6 +348,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         self.reorder_batch_threshold: Optional[int] = None
 
+        # Cached outputs.
+        self._draft_token_ids: Optional[Union[list[list[int]],
+                                              torch.Tensor]] = None
+
     def _init_model_kwargs(self, num_tokens: int):
         model_kwargs = dict[str, Any]()
         num_reqs = self.input_batch.num_reqs
@@ -1493,7 +1497,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=[],
-            spec_token_ids=None,
             logprobs=None,
             prompt_logprobs_dict={},
             pooler_output=pooler_output,
@@ -1764,12 +1767,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 req_state = self.requests[req_id]
                 req_state.output_token_ids.extend(sampled_ids)
 
-        if not self.speculative_config:
-            # Speculative decoding is not enabled.
-            spec_token_ids = None
-        else:
+        if self.speculative_config:
             assert spec_decode_common_attn_metadata is not None
-            spec_token_ids = self.propose_draft_token_ids(
+            self._draft_token_ids = self.propose_draft_token_ids(
                 scheduler_output,
                 valid_sampled_token_ids,
                 sampling_metadata,
@@ -1786,7 +1786,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_ids=self.input_batch.req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=valid_sampled_token_ids,
-            spec_token_ids=spec_token_ids,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],
@@ -1794,6 +1793,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             num_nans_in_logits=num_nans_in_logits,
         )
 
+    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
+        if self._draft_token_ids is None:
+            return None
+        req_ids = self.input_batch.req_ids
+        if isinstance(self._draft_token_ids, torch.Tensor):
+            draft_token_ids = self._draft_token_ids.tolist()
+        else:
+            draft_token_ids = self._draft_token_ids
+        self._draft_token_ids = None
+        return DraftTokenIds(req_ids, draft_token_ids)
+
     def propose_draft_token_ids(
         self,
         scheduler_output: "SchedulerOutput",
@@ -1804,11 +1814,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         aux_hidden_states: Optional[torch.Tensor],
         spec_decode_metadata: Optional[SpecDecodeMetadata],
         common_attn_metadata: CommonAttentionMetadata,
-    ) -> list[list[int]]:
+    ) -> Union[list[list[int]], torch.Tensor]:
         num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
         if self.speculative_config.method == "ngram":
             assert isinstance(self.drafter, NgramProposer)
-            spec_token_ids = self.propose_ngram_draft_token_ids(
+            draft_token_ids = self.propose_ngram_draft_token_ids(
                 sampled_token_ids)
         elif self.speculative_config.method == "medusa":
             assert isinstance(self.drafter, MedusaProposer)
@@ -1826,7 +1836,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 indices = torch.tensor(indices, device=self.device)
                 hidden_states = sample_hidden_states[indices]
 
-            spec_token_ids = self.drafter.propose(
+            draft_token_ids = self.drafter.propose(
                 target_hidden_states=hidden_states,
                 sampling_metadata=sampling_metadata,
             )
@@ -1897,8 +1907,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 common_attn_metadata=common_attn_metadata,
                 mm_embeds=mm_embeds,
             )
-            spec_token_ids = draft_token_ids.tolist()
-        return spec_token_ids
+        return draft_token_ids
 
     def propose_ngram_draft_token_ids(
         self,

View File

@@ -28,7 +28,8 @@ from vllm.tasks import SupportedTask
 from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
 from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
-from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
+from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
+                             ModelRunnerOutput)
 from vllm.v1.utils import report_usage_stats
 from vllm.v1.worker.gpu_model_runner import GPUModelRunner
 from vllm.v1.worker.worker_base import WorkerBase
@@ -386,6 +387,9 @@ class Worker(WorkerBase):
         assert isinstance(output, ModelRunnerOutput)
         return output
 
+    def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
+        return self.model_runner.take_draft_token_ids()
+
     def profile(self, is_start: bool = True):
         if self.profiler is None:
             raise RuntimeError("Profiler is not enabled.")

View File

@@ -1145,7 +1145,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
             req_ids=req_ids,
             req_id_to_index=self.input_batch.req_id_to_index,
             sampled_token_ids=valid_sampled_token_ids,
-            spec_token_ids=None,
             logprobs=logprobs_lists,
             prompt_logprobs_dict=prompt_logprobs_dict,
             pooler_output=[],