Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-24 02:24:29 +08:00)
[Spec Decode] Make propose_draft_token_ids non-blocking for lower TTFT (#23041)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent 0dd3f4f5ab
commit c9b38be8aa
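The diff below decouples draft-token proposal from the blocking model-runner output path: ModelRunnerOutput loses its spec_token_ids field, the GPU model runner caches its proposals, and the engine core fetches them in a new post_step hook and forwards them to the scheduler via update_draft_token_ids. The following is a minimal sketch of the resulting flow, not the vLLM implementation; the DraftTokenIds shape matches the new dataclass in vllm/v1/outputs.py, while the surrounding class and attribute names are simplified stand-ins.

from dataclasses import dataclass
from typing import Optional

@dataclass
class DraftTokenIds:
    # Same two fields as the new dataclass in vllm/v1/outputs.py.
    req_ids: list[str]                # [num_reqs]
    draft_token_ids: list[list[int]]  # num_reqs x num_draft_tokens

class SketchEngineCore:
    # Illustrative stand-in: only the ordering of the calls matters here.
    def __init__(self, scheduler, model_executor, use_spec_decode: bool):
        self.scheduler = scheduler
        self.model_executor = model_executor
        self.use_spec_decode = use_spec_decode

    def step_and_post_step(self) -> None:
        scheduler_output = self.scheduler.schedule()
        model_output = self.model_executor.execute_model(scheduler_output)
        # Sampled tokens are returned to clients right away...
        self.scheduler.update_from_output(scheduler_output, model_output)
        # ...and the drafts for the *next* step are picked up afterwards,
        # so proposing them no longer sits on the critical TTFT path.
        if self.use_spec_decode:
            drafts: Optional[DraftTokenIds] = \
                self.model_executor.take_draft_token_ids()
            if drafts is not None:
                self.scheduler.update_draft_token_ids(drafts)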
@@ -22,7 +22,6 @@ def _make_model_runner_output(
for i, req_id in enumerate(req_ids)
},
sampled_token_ids=[[i] for i in range(len(req_ids))],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -14,7 +14,7 @@ from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
KVCacheGroupSpec)
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.request import StructuredOutputRequest
@@ -158,7 +158,6 @@ def test_schedule_partial_requests():
# Only the first request has a sampled token id because
# the rest requests are still being prefilled.
sampled_token_ids=[[0], [], []],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -209,7 +208,6 @@ def test_no_mm_input_chunking():
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[] for _ in range(len(requests))],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -273,7 +271,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[] for _ in range(len(requests))],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -298,7 +295,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -355,7 +351,6 @@ def test_stop_via_update_from_output():
sampled_token_ids=[[EOS_TOKEN_ID],
[10,
11]], # First request hits EOS, second continues
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
@@ -409,7 +404,6 @@ def test_stop_via_update_from_output():
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
@@ -462,7 +456,6 @@ def test_stop_via_update_from_output():
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
@@ -505,7 +498,6 @@ def test_stop_via_update_from_output():
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[])
@@ -554,7 +546,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -572,7 +563,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -608,7 +598,6 @@ def test_preempt_during_execution():
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[0]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -626,7 +615,6 @@ def test_preempt_during_execution():
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[[42]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -682,13 +670,14 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[0] for _ in range(len(requests))],
spec_token_ids=spec_tokens,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
)
engine_core_outputs = scheduler.update_from_output(output,
model_runner_output)
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
scheduler.update_draft_token_ids(draft_token_ids)

for i in range(len(requests)):
running_req = scheduler.running[i]
@@ -722,7 +711,6 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=output_tokens,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -851,7 +839,6 @@ def test_kv_connector_basic():
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -898,7 +885,6 @@ def test_kv_connector_basic():
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -966,7 +952,6 @@ def test_kv_connector_unable_to_allocate():
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -1048,7 +1033,6 @@ def test_kv_connector_handles_preemption():
req_ids=req_ids,
req_id_to_index=req_to_index,
sampled_token_ids=[[1000]] * len(req_ids),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -1142,7 +1126,6 @@ def make_output(scheduler: Scheduler):
for i, req in enumerate(scheduler.running)
},
sampled_token_ids=[[1000]] * len(scheduler.running),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -1468,7 +1451,6 @@ def test_priority_scheduling_preemption():
for i, req in enumerate(low_priority_requests)
},
sampled_token_ids=[[100] for _ in low_priority_requests],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -1541,7 +1523,6 @@ def test_priority_scheduling_no_preemption_when_space_available():
for i, req in enumerate(low_priority_requests)
},
sampled_token_ids=[[100] for _ in low_priority_requests],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -1783,7 +1764,6 @@ def test_priority_scheduling_heap_property():
req_ids=[req.req_id],
req_id_to_index={req.req_id: 0},
sampled_token_ids=[[100]],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
@@ -200,7 +200,6 @@ def create_model_runner_output(
req_ids=req_ids,
req_id_to_index=req_id_to_index,
sampled_token_ids=sampled_token_ids,
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=None,
@@ -9,7 +9,7 @@ if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus


@@ -61,6 +61,14 @@ class SchedulerInterface(ABC):
"""
raise NotImplementedError

@abstractmethod
def update_draft_token_ids(
self,
draft_token_ids: "DraftTokenIds",
) -> None:
"""Update the draft token ids for the scheduled requests."""
raise NotImplementedError

@abstractmethod
def add_request(self, request: "Request") -> None:
"""Add a new request to the scheduler's internal queue.
@@ -30,7 +30,7 @@ from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
EngineCoreOutputs)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
@@ -141,7 +141,6 @@ class Scheduler(SchedulerInterface):
cache_size=encoder_cache_size)

speculative_config = vllm_config.speculative_config

self.use_eagle = False
self.num_spec_tokens = self.num_lookahead_tokens = 0
if speculative_config:
@@ -760,7 +759,6 @@
model_runner_output: ModelRunnerOutput,
) -> dict[int, EngineCoreOutputs]:
sampled_token_ids = model_runner_output.sampled_token_ids
spec_token_ids = model_runner_output.spec_token_ids
logprobs = model_runner_output.logprobs
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
@@ -845,20 +843,9 @@
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
req_id, new_token_ids)

# spec_token_ids comes from the model runner output
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id]

# Add newly generated spec token ids to the request.
if spec_token_ids is not None:
if self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
# Needs to happen after new_token_ids are accepted.
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
spec_token_ids[req_index])
else:
request.spec_token_ids = spec_token_ids[req_index]

# Get prompt logprobs for this request.
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
if new_token_ids or pooler_output is not None \
@@ -963,6 +950,30 @@
self.encoder_cache_manager.free_encoder_input(
request, input_id)

def update_draft_token_ids(
self,
draft_token_ids: DraftTokenIds,
) -> None:
for req_id, spec_token_ids in zip(
draft_token_ids.req_ids,
draft_token_ids.draft_token_ids,
):
request = self.requests.get(req_id)
if request is None or request.is_finished():
# The request may have been finished. Skip.
continue

# Add newly generated spec token ids to the request.
if not spec_token_ids:
# NOTE(woosuk): request.spec_token_ids should be updated.
request.spec_token_ids.clear()
elif self.structured_output_manager.should_advance(request):
metadata = request.structured_output_request
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
spec_token_ids)
else:
request.spec_token_ids = spec_token_ids

def get_request_counts(self) -> tuple[int, int]:
"""Returns (num_running_reqs, num_waiting_reqs)."""
return len(self.running), len(self.waiting)
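The semantics of update_draft_token_ids above: finished or unknown requests are skipped, an empty draft list clears any stale drafts, and structured-output requests get their drafts filtered through the grammar before being attached. Below is a toy illustration of the clear-vs-assign behavior only; SimpleRequest and apply_draft_token_ids are stand-ins for this sketch, not vllm.v1.request.Request or the scheduler method, and the grammar-validation branch is omitted.

from dataclasses import dataclass, field

@dataclass
class SimpleRequest:
    # Stand-in for the fields of the real Request object used here.
    req_id: str
    spec_token_ids: list[int] = field(default_factory=list)
    finished: bool = False

def apply_draft_token_ids(requests: dict[str, SimpleRequest],
                          req_ids: list[str],
                          draft_token_ids: list[list[int]]) -> None:
    for req_id, drafts in zip(req_ids, draft_token_ids):
        request = requests.get(req_id)
        if request is None or request.finished:
            continue  # The request may have finished in the meantime.
        if not drafts:
            request.spec_token_ids.clear()  # Drop stale drafts.
        else:
            request.spec_token_ids = drafts

reqs = {"a": SimpleRequest("a", spec_token_ids=[7, 8]),
        "b": SimpleRequest("b")}
apply_draft_token_ids(reqs, ["a", "b"], [[], [3, 4, 5]])
assert reqs["a"].spec_token_ids == []          # cleared
assert reqs["b"].spec_token_ids == [3, 4, 5]   # replaced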
@@ -126,6 +126,7 @@ class EngineCore:
> 1,
log_stats=self.log_stats,
)
self.use_spec_decode = vllm_config.speculative_config is not None

self.mm_input_cache_server = MultiModalInputCacheServer(
vllm_config.model_config, MULTIMODAL_REGISTRY)
@@ -294,6 +295,13 @@ class EngineCore:
return (engine_core_outputs,
scheduler_output.total_num_scheduled_tokens > 0)

def post_step(self, model_executed: bool) -> None:
if self.use_spec_decode and model_executed:
# Take the draft token ids.
draft_token_ids = self.model_executor.take_draft_token_ids()
if draft_token_ids is not None:
self.scheduler.update_draft_token_ids(draft_token_ids)

def step_with_batch_queue(
self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
"""Schedule and execute batches with the batch queue.
@@ -746,6 +754,8 @@ class EngineCoreProc(EngineCore):
# Put EngineCoreOutputs into the output queue.
for output in (outputs.items() if outputs else ()):
self.output_queue.put_nowait(output)
# Post-step hook.
self.post_step(model_executed)

return model_executed
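The ordering in EngineCoreProc above is the heart of the TTFT win: each step's EngineCoreOutputs are pushed to the output queue before post_step collects drafts, so clients start receiving tokens while the draft tokens for the next step are still being gathered. A condensed sketch of that loop body follows; the attribute names step_fn and output_queue are simplified assumptions, only post_step comes from the diff.

def process_one_step(engine) -> bool:
    # 1) Run one scheduling/execution step.
    outputs, model_executed = engine.step_fn()
    # 2) Publish this step's outputs to clients immediately.
    for item in (outputs.items() if outputs else ()):
        engine.output_queue.put_nowait(item)
    # 3) Only afterwards gather draft tokens for the next step.
    engine.post_step(model_executed)
    return model_executed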
@@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from concurrent.futures import Future
from typing import Callable, Union
from typing import Callable, Optional, Union

import torch
import torch.distributed as dist
@@ -14,7 +14,7 @@ from vllm.executor.uniproc_executor import ( # noqa
from vllm.executor.uniproc_executor import ( # noqa
UniProcExecutor as UniProcExecutorV0)
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput

FailureCallback = Callable[[], None]

@@ -88,6 +88,10 @@ class Executor(ExecutorBase):
args=(scheduler_output, ))
return output[0]

def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
output = self.collective_rpc("take_draft_token_ids")
return output[0]

@property
def max_concurrent_batches(self) -> int:
return 1
@@ -33,7 +33,7 @@ from vllm.utils import (decorate_logs, get_distributed_init_method,
get_loopback_ip, get_mp_context, get_open_port,
set_process_title)
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.worker.worker_base import WorkerWrapperBase

logger = init_logger(__name__)
@@ -191,6 +191,12 @@ class MultiprocExecutor(Executor):
outputs, self.output_rank)
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)

def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
# OPTIMIZATION: Get output only from a single worker (output_rank)
outputs = self.collective_rpc("take_draft_token_ids",
unique_reply_rank=self.output_rank)
return outputs[0]

def collective_rpc(self,
method: Union[str, Callable],
timeout: Optional[float] = None,
@@ -94,9 +94,6 @@ class ModelRunnerOutput:
# each request due to speculative/jump decoding.
sampled_token_ids: list[list[int]]

# num_reqs x num_spec_tokens
spec_token_ids: Optional[list[list[int]]]

# [num_reqs, max_num_logprobs + 1]
# [num_reqs, max_num_logprobs + 1]
# [num_reqs]
@@ -117,10 +114,18 @@
num_nans_in_logits: Optional[dict[str, int]] = None


@dataclass
class DraftTokenIds:

# [num_reqs]
req_ids: list[str]
# num_reqs x num_draft_tokens
draft_token_ids: list[list[int]]


EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
req_id_to_index={},
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=[],
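DraftTokenIds mirrors the list layout that spec_token_ids previously had inside ModelRunnerOutput: req_ids is [num_reqs] and draft_token_ids is num_reqs x num_draft_tokens, with an empty inner list meaning "no drafts for this request". A small construction example follows; the request ids and token values are made up for illustration.

from vllm.v1.outputs import DraftTokenIds

drafts = DraftTokenIds(
    req_ids=["req-0", "req-1", "req-2"],
    draft_token_ids=[
        [101, 102, 103],  # three draft tokens proposed for req-0
        [],               # no drafts for req-1; the scheduler clears stale ones
        [205, 206],       # two draft tokens proposed for req-2
    ],
)
# The two lists are indexed in parallel, one entry per request.
assert len(drafts.req_ids) == len(drafts.draft_token_ids)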
@@ -38,12 +38,14 @@ class MedusaProposer:
self,
target_hidden_states: torch.Tensor,
sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
) -> list[list[int]]:
# Generate blocks and compute logits
blocks = self.model(target_hidden_states)
logits = self.model.compute_logits(blocks, None)

# Get draft tokens and transpose the result
# TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU
# synchronization.
draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
return [list(row) for row in zip(*draft_tokens)]
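The Medusa proposer now returns plain Python lists instead of a tensor: each Medusa head yields one logits tensor, so the per-head argmax gives num_heads lists of length num_reqs, and the zip(*...) transpose turns that into num_reqs lists of num_heads draft tokens. A worked example of just that transpose, with arbitrary token values:

# Per-head argmax results: num_heads x num_reqs (2 heads, 3 requests here).
draft_tokens = [
    [11, 21, 31],  # head 0's top token for requests 0, 1, 2
    [12, 22, 32],  # head 1's top token for requests 0, 1, 2
]
# Transpose to num_reqs x num_heads, the shape propose() now returns.
per_request = [list(row) for row in zip(*draft_tokens)]
assert per_request == [[11, 12], [21, 22], [31, 32]]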
@@ -65,8 +65,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
FullAttentionSpec, KVCacheConfig,
KVCacheSpec, MambaSpec,
SlidingWindowSpec)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
ModelRunnerOutput)
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
LogprobsTensors, ModelRunnerOutput)
from vllm.v1.pool.metadata import PoolingMetadata
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
from vllm.v1.sample.metadata import SamplingMetadata
@@ -348,6 +348,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

self.reorder_batch_threshold: Optional[int] = None

# Cached outputs.
self._draft_token_ids: Optional[Union[list[list[int]],
torch.Tensor]] = None

def _init_model_kwargs(self, num_tokens: int):
model_kwargs = dict[str, Any]()
num_reqs = self.input_batch.num_reqs
@@ -1493,7 +1497,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=[],
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
pooler_output=pooler_output,
@@ -1764,12 +1767,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_state = self.requests[req_id]
req_state.output_token_ids.extend(sampled_ids)

if not self.speculative_config:
# Speculative decoding is not enabled.
spec_token_ids = None
else:
if self.speculative_config:
assert spec_decode_common_attn_metadata is not None
spec_token_ids = self.propose_draft_token_ids(
self._draft_token_ids = self.propose_draft_token_ids(
scheduler_output,
valid_sampled_token_ids,
sampling_metadata,
@@ -1786,7 +1786,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_ids=self.input_batch.req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=spec_token_ids,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],
@@ -1794,6 +1793,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_nans_in_logits=num_nans_in_logits,
)

def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
if self._draft_token_ids is None:
return None
req_ids = self.input_batch.req_ids
if isinstance(self._draft_token_ids, torch.Tensor):
draft_token_ids = self._draft_token_ids.tolist()
else:
draft_token_ids = self._draft_token_ids
self._draft_token_ids = None
return DraftTokenIds(req_ids, draft_token_ids)

def propose_draft_token_ids(
self,
scheduler_output: "SchedulerOutput",
@@ -1804,11 +1814,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
aux_hidden_states: Optional[torch.Tensor],
spec_decode_metadata: Optional[SpecDecodeMetadata],
common_attn_metadata: CommonAttentionMetadata,
) -> list[list[int]]:
) -> Union[list[list[int]], torch.Tensor]:
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
if self.speculative_config.method == "ngram":
assert isinstance(self.drafter, NgramProposer)
spec_token_ids = self.propose_ngram_draft_token_ids(
draft_token_ids = self.propose_ngram_draft_token_ids(
sampled_token_ids)
elif self.speculative_config.method == "medusa":
assert isinstance(self.drafter, MedusaProposer)
@@ -1826,7 +1836,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
indices = torch.tensor(indices, device=self.device)
hidden_states = sample_hidden_states[indices]

spec_token_ids = self.drafter.propose(
draft_token_ids = self.drafter.propose(
target_hidden_states=hidden_states,
sampling_metadata=sampling_metadata,
)
@@ -1897,8 +1907,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
common_attn_metadata=common_attn_metadata,
mm_embeds=mm_embeds,
)
spec_token_ids = draft_token_ids.tolist()
return spec_token_ids
return draft_token_ids

def propose_ngram_draft_token_ids(
self,
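take_draft_token_ids is a take-and-clear accessor: the runner stashes whatever propose_draft_token_ids produced (already a CPU list, or still a GPU tensor), the caller after a step converts and consumes it, and the cache is reset to None so drafts are handed out exactly once. A simplified standalone version of that pattern is sketched below; DraftCacheSketch is an illustrative class, not GPUModelRunner.

from typing import Optional, Union
import torch

class DraftCacheSketch:
    # Minimal stand-in for the draft caching added to the model runner.
    def __init__(self) -> None:
        self._draft_token_ids: Optional[Union[list[list[int]],
                                              torch.Tensor]] = None

    def store(self, drafts: Union[list[list[int]], torch.Tensor]) -> None:
        self._draft_token_ids = drafts

    def take(self, req_ids: list[str]):
        if self._draft_token_ids is None:
            return None  # Nothing proposed since the last take.
        if isinstance(self._draft_token_ids, torch.Tensor):
            # The GPU->CPU copy is deferred until the drafts are taken.
            draft_token_ids = self._draft_token_ids.tolist()
        else:
            draft_token_ids = self._draft_token_ids
        self._draft_token_ids = None  # Consume exactly once.
        return req_ids, draft_token_ids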
@@ -28,7 +28,8 @@ from vllm.tasks import SupportedTask
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
ModelRunnerOutput)
from vllm.v1.utils import report_usage_stats
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
from vllm.v1.worker.worker_base import WorkerBase
@@ -386,6 +387,9 @@ class Worker(WorkerBase):
assert isinstance(output, ModelRunnerOutput)
return output

def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
return self.model_runner.take_draft_token_ids()

def profile(self, is_start: bool = True):
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
@@ -1145,7 +1145,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
req_ids=req_ids,
req_id_to_index=self.input_batch.req_id_to_index,
sampled_token_ids=valid_sampled_token_ids,
spec_token_ids=None,
logprobs=logprobs_lists,
prompt_logprobs_dict=prompt_logprobs_dict,
pooler_output=[],