mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-02 11:37:59 +08:00
[Spec Decode] Make propose_draft_token_ids non-blocking for lower TTFT (#23041)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent
0dd3f4f5ab
commit
c9b38be8aa
@ -22,7 +22,6 @@ def _make_model_runner_output(
|
|||||||
for i, req_id in enumerate(req_ids)
|
for i, req_id in enumerate(req_ids)
|
||||||
},
|
},
|
||||||
sampled_token_ids=[[i] for i in range(len(req_ids))],
|
sampled_token_ids=[[i] for i in range(len(req_ids))],
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
|
|||||||
@ -14,7 +14,7 @@ from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
|
|||||||
from vllm.v1.core.sched.scheduler import Scheduler
|
from vllm.v1.core.sched.scheduler import Scheduler
|
||||||
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheGroupSpec)
|
KVCacheGroupSpec)
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||||
from vllm.v1.request import Request, RequestStatus
|
from vllm.v1.request import Request, RequestStatus
|
||||||
from vllm.v1.structured_output import StructuredOutputManager
|
from vllm.v1.structured_output import StructuredOutputManager
|
||||||
from vllm.v1.structured_output.request import StructuredOutputRequest
|
from vllm.v1.structured_output.request import StructuredOutputRequest
|
||||||
@ -158,7 +158,6 @@ def test_schedule_partial_requests():
|
|||||||
# Only the first request has a sampled token id because
|
# Only the first request has a sampled token id because
|
||||||
# the rest requests are still being prefilled.
|
# the rest requests are still being prefilled.
|
||||||
sampled_token_ids=[[0], [], []],
|
sampled_token_ids=[[0], [], []],
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -209,7 +208,6 @@ def test_no_mm_input_chunking():
|
|||||||
req_ids=[request.request_id for request in requests],
|
req_ids=[request.request_id for request in requests],
|
||||||
req_id_to_index=req_to_index,
|
req_id_to_index=req_to_index,
|
||||||
sampled_token_ids=[[] for _ in range(len(requests))],
|
sampled_token_ids=[[] for _ in range(len(requests))],
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -273,7 +271,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
|
|||||||
req_ids=[request.request_id for request in requests],
|
req_ids=[request.request_id for request in requests],
|
||||||
req_id_to_index=req_to_index,
|
req_id_to_index=req_to_index,
|
||||||
sampled_token_ids=[[] for _ in range(len(requests))],
|
sampled_token_ids=[[] for _ in range(len(requests))],
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -298,7 +295,6 @@ def test_schedule_concurrent_partial_requests(enable_prefix_caching: bool):
|
|||||||
req_ids=[request.request_id for request in requests],
|
req_ids=[request.request_id for request in requests],
|
||||||
req_id_to_index=req_to_index,
|
req_id_to_index=req_to_index,
|
||||||
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
|
sampled_token_ids=[[0], [0]] + [[] for _ in range(len(requests) - 2)],
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -355,7 +351,6 @@ def test_stop_via_update_from_output():
|
|||||||
sampled_token_ids=[[EOS_TOKEN_ID],
|
sampled_token_ids=[[EOS_TOKEN_ID],
|
||||||
[10,
|
[10,
|
||||||
11]], # First request hits EOS, second continues
|
11]], # First request hits EOS, second continues
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[])
|
pooler_output=[])
|
||||||
@ -409,7 +404,6 @@ def test_stop_via_update_from_output():
|
|||||||
},
|
},
|
||||||
sampled_token_ids=[[10, 42, 12],
|
sampled_token_ids=[[10, 42, 12],
|
||||||
[13, 14]], # First request hits stop token
|
[13, 14]], # First request hits stop token
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[])
|
pooler_output=[])
|
||||||
@ -462,7 +456,6 @@ def test_stop_via_update_from_output():
|
|||||||
},
|
},
|
||||||
sampled_token_ids=[[10, 11, 12],
|
sampled_token_ids=[[10, 11, 12],
|
||||||
[13]], # First request exceeds max_tokens
|
[13]], # First request exceeds max_tokens
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[])
|
pooler_output=[])
|
||||||
@ -505,7 +498,6 @@ def test_stop_via_update_from_output():
|
|||||||
req_ids=[requests[0].request_id],
|
req_ids=[requests[0].request_id],
|
||||||
req_id_to_index={requests[0].request_id: 0},
|
req_id_to_index={requests[0].request_id: 0},
|
||||||
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[])
|
pooler_output=[])
|
||||||
@ -554,7 +546,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
|
|||||||
req_ids=[requests[0].request_id],
|
req_ids=[requests[0].request_id],
|
||||||
req_id_to_index={requests[0].request_id: 0},
|
req_id_to_index={requests[0].request_id: 0},
|
||||||
sampled_token_ids=[[0]],
|
sampled_token_ids=[[0]],
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -572,7 +563,6 @@ def test_schedule_concurrent_batches(enable_prefix_caching: Optional[bool],
|
|||||||
req_ids=[requests[1].request_id],
|
req_ids=[requests[1].request_id],
|
||||||
req_id_to_index={requests[1].request_id: 0},
|
req_id_to_index={requests[1].request_id: 0},
|
||||||
sampled_token_ids=[[0]],
|
sampled_token_ids=[[0]],
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -608,7 +598,6 @@ def test_preempt_during_execution():
|
|||||||
req_ids=[requests[0].request_id],
|
req_ids=[requests[0].request_id],
|
||||||
req_id_to_index={requests[0].request_id: 0},
|
req_id_to_index={requests[0].request_id: 0},
|
||||||
sampled_token_ids=[[0]],
|
sampled_token_ids=[[0]],
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -626,7 +615,6 @@ def test_preempt_during_execution():
|
|||||||
req_ids=[requests[1].request_id],
|
req_ids=[requests[1].request_id],
|
||||||
req_id_to_index={requests[1].request_id: 0},
|
req_id_to_index={requests[1].request_id: 0},
|
||||||
sampled_token_ids=[[42]],
|
sampled_token_ids=[[42]],
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -682,13 +670,14 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
|
|||||||
req_ids=req_ids,
|
req_ids=req_ids,
|
||||||
req_id_to_index=req_to_index,
|
req_id_to_index=req_to_index,
|
||||||
sampled_token_ids=[[0] for _ in range(len(requests))],
|
sampled_token_ids=[[0] for _ in range(len(requests))],
|
||||||
spec_token_ids=spec_tokens,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
)
|
)
|
||||||
engine_core_outputs = scheduler.update_from_output(output,
|
engine_core_outputs = scheduler.update_from_output(output,
|
||||||
model_runner_output)
|
model_runner_output)
|
||||||
|
draft_token_ids = DraftTokenIds(req_ids, spec_tokens)
|
||||||
|
scheduler.update_draft_token_ids(draft_token_ids)
|
||||||
|
|
||||||
for i in range(len(requests)):
|
for i in range(len(requests)):
|
||||||
running_req = scheduler.running[i]
|
running_req = scheduler.running[i]
|
||||||
@ -722,7 +711,6 @@ def test_schedule_spec_decoding_stats(spec_tokens, output_tokens, expected):
|
|||||||
req_ids=req_ids,
|
req_ids=req_ids,
|
||||||
req_id_to_index=req_to_index,
|
req_id_to_index=req_to_index,
|
||||||
sampled_token_ids=output_tokens,
|
sampled_token_ids=output_tokens,
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -851,7 +839,6 @@ def test_kv_connector_basic():
|
|||||||
req_ids=req_ids,
|
req_ids=req_ids,
|
||||||
req_id_to_index=req_to_index,
|
req_id_to_index=req_to_index,
|
||||||
sampled_token_ids=[[1000]] * len(req_ids),
|
sampled_token_ids=[[1000]] * len(req_ids),
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -898,7 +885,6 @@ def test_kv_connector_basic():
|
|||||||
req_ids=req_ids,
|
req_ids=req_ids,
|
||||||
req_id_to_index=req_to_index,
|
req_id_to_index=req_to_index,
|
||||||
sampled_token_ids=[[1000]] * len(req_ids),
|
sampled_token_ids=[[1000]] * len(req_ids),
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -966,7 +952,6 @@ def test_kv_connector_unable_to_allocate():
|
|||||||
req_ids=req_ids,
|
req_ids=req_ids,
|
||||||
req_id_to_index=req_to_index,
|
req_id_to_index=req_to_index,
|
||||||
sampled_token_ids=[[1000]] * len(req_ids),
|
sampled_token_ids=[[1000]] * len(req_ids),
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -1048,7 +1033,6 @@ def test_kv_connector_handles_preemption():
|
|||||||
req_ids=req_ids,
|
req_ids=req_ids,
|
||||||
req_id_to_index=req_to_index,
|
req_id_to_index=req_to_index,
|
||||||
sampled_token_ids=[[1000]] * len(req_ids),
|
sampled_token_ids=[[1000]] * len(req_ids),
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -1142,7 +1126,6 @@ def make_output(scheduler: Scheduler):
|
|||||||
for i, req in enumerate(scheduler.running)
|
for i, req in enumerate(scheduler.running)
|
||||||
},
|
},
|
||||||
sampled_token_ids=[[1000]] * len(scheduler.running),
|
sampled_token_ids=[[1000]] * len(scheduler.running),
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -1468,7 +1451,6 @@ def test_priority_scheduling_preemption():
|
|||||||
for i, req in enumerate(low_priority_requests)
|
for i, req in enumerate(low_priority_requests)
|
||||||
},
|
},
|
||||||
sampled_token_ids=[[100] for _ in low_priority_requests],
|
sampled_token_ids=[[100] for _ in low_priority_requests],
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -1541,7 +1523,6 @@ def test_priority_scheduling_no_preemption_when_space_available():
|
|||||||
for i, req in enumerate(low_priority_requests)
|
for i, req in enumerate(low_priority_requests)
|
||||||
},
|
},
|
||||||
sampled_token_ids=[[100] for _ in low_priority_requests],
|
sampled_token_ids=[[100] for _ in low_priority_requests],
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -1783,7 +1764,6 @@ def test_priority_scheduling_heap_property():
|
|||||||
req_ids=[req.req_id],
|
req_ids=[req.req_id],
|
||||||
req_id_to_index={req.req_id: 0},
|
req_id_to_index={req.req_id: 0},
|
||||||
sampled_token_ids=[[100]],
|
sampled_token_ids=[[100]],
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
|
|||||||
@ -200,7 +200,6 @@ def create_model_runner_output(
|
|||||||
req_ids=req_ids,
|
req_ids=req_ids,
|
||||||
req_id_to_index=req_id_to_index,
|
req_id_to_index=req_id_to_index,
|
||||||
sampled_token_ids=sampled_token_ids,
|
sampled_token_ids=sampled_token_ids,
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=None,
|
pooler_output=None,
|
||||||
|
|||||||
@ -9,7 +9,7 @@ if TYPE_CHECKING:
|
|||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.engine import EngineCoreOutputs
|
from vllm.v1.engine import EngineCoreOutputs
|
||||||
from vllm.v1.metrics.stats import SchedulerStats
|
from vllm.v1.metrics.stats import SchedulerStats
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||||
from vllm.v1.request import Request, RequestStatus
|
from vllm.v1.request import Request, RequestStatus
|
||||||
|
|
||||||
|
|
||||||
@ -61,6 +61,14 @@ class SchedulerInterface(ABC):
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
def update_draft_token_ids(
    self,
    draft_token_ids: "DraftTokenIds",
) -> None:
    """Update the draft token ids for the scheduled requests.

    Args:
        draft_token_ids: The draft tokens to attach to the scheduler's
            requests.

    Raises:
        NotImplementedError: Always. This is the abstract declaration;
            concrete schedulers must override it.
    """
    raise NotImplementedError
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_request(self, request: "Request") -> None:
|
def add_request(self, request: "Request") -> None:
|
||||||
"""Add a new request to the scheduler's internal queue.
|
"""Add a new request to the scheduler's internal queue.
|
||||||
|
|||||||
@ -30,7 +30,7 @@ from vllm.v1.engine import (EngineCoreEventType, EngineCoreOutput,
|
|||||||
EngineCoreOutputs)
|
EngineCoreOutputs)
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig
|
from vllm.v1.kv_cache_interface import KVCacheConfig
|
||||||
from vllm.v1.metrics.stats import SchedulerStats
|
from vllm.v1.metrics.stats import SchedulerStats
|
||||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
from vllm.v1.outputs import DraftTokenIds, KVConnectorOutput, ModelRunnerOutput
|
||||||
from vllm.v1.request import Request, RequestStatus
|
from vllm.v1.request import Request, RequestStatus
|
||||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||||
from vllm.v1.structured_output import StructuredOutputManager
|
from vllm.v1.structured_output import StructuredOutputManager
|
||||||
@ -141,7 +141,6 @@ class Scheduler(SchedulerInterface):
|
|||||||
cache_size=encoder_cache_size)
|
cache_size=encoder_cache_size)
|
||||||
|
|
||||||
speculative_config = vllm_config.speculative_config
|
speculative_config = vllm_config.speculative_config
|
||||||
|
|
||||||
self.use_eagle = False
|
self.use_eagle = False
|
||||||
self.num_spec_tokens = self.num_lookahead_tokens = 0
|
self.num_spec_tokens = self.num_lookahead_tokens = 0
|
||||||
if speculative_config:
|
if speculative_config:
|
||||||
@ -760,7 +759,6 @@ class Scheduler(SchedulerInterface):
|
|||||||
model_runner_output: ModelRunnerOutput,
|
model_runner_output: ModelRunnerOutput,
|
||||||
) -> dict[int, EngineCoreOutputs]:
|
) -> dict[int, EngineCoreOutputs]:
|
||||||
sampled_token_ids = model_runner_output.sampled_token_ids
|
sampled_token_ids = model_runner_output.sampled_token_ids
|
||||||
spec_token_ids = model_runner_output.spec_token_ids
|
|
||||||
logprobs = model_runner_output.logprobs
|
logprobs = model_runner_output.logprobs
|
||||||
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
prompt_logprobs_dict = model_runner_output.prompt_logprobs_dict
|
||||||
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
num_scheduled_tokens = scheduler_output.num_scheduled_tokens
|
||||||
@ -845,20 +843,9 @@ class Scheduler(SchedulerInterface):
|
|||||||
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
|
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
|
||||||
req_id, new_token_ids)
|
req_id, new_token_ids)
|
||||||
|
|
||||||
# spec_token_ids comes from the model runner output
|
|
||||||
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
|
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
|
||||||
request.num_nans_in_logits = num_nans_in_logits[req_id]
|
request.num_nans_in_logits = num_nans_in_logits[req_id]
|
||||||
|
|
||||||
# Add newly generated spec token ids to the request.
|
|
||||||
if spec_token_ids is not None:
|
|
||||||
if self.structured_output_manager.should_advance(request):
|
|
||||||
metadata = request.structured_output_request
|
|
||||||
# Needs to happen after new_token_ids are accepted.
|
|
||||||
request.spec_token_ids = metadata.grammar.validate_tokens( # type: ignore[union-attr]
|
|
||||||
spec_token_ids[req_index])
|
|
||||||
else:
|
|
||||||
request.spec_token_ids = spec_token_ids[req_index]
|
|
||||||
|
|
||||||
# Get prompt logprobs for this request.
|
# Get prompt logprobs for this request.
|
||||||
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
|
prompt_logprobs_tensors = prompt_logprobs_dict.get(req_id)
|
||||||
if new_token_ids or pooler_output is not None \
|
if new_token_ids or pooler_output is not None \
|
||||||
@ -963,6 +950,30 @@ class Scheduler(SchedulerInterface):
|
|||||||
self.encoder_cache_manager.free_encoder_input(
|
self.encoder_cache_manager.free_encoder_input(
|
||||||
request, input_id)
|
request, input_id)
|
||||||
|
|
||||||
|
def update_draft_token_ids(
    self,
    draft_token_ids: DraftTokenIds,
) -> None:
    """Record newly proposed draft tokens on the requests they belong to.

    Walks ``draft_token_ids.req_ids`` and ``draft_token_ids.draft_token_ids``
    in lockstep. For each pair:

    * requests missing from ``self.requests`` or already finished are
      skipped;
    * an empty proposal clears the request's existing ``spec_token_ids``
      list in place;
    * when the structured output manager says the request should advance,
      the proposal is passed through the request grammar's
      ``validate_tokens`` and the result is stored;
    * otherwise the proposal is stored as given.

    Args:
        draft_token_ids: Request ids and their per-request draft tokens.
    """
    for req_id, spec_token_ids in zip(
            draft_token_ids.req_ids,
            draft_token_ids.draft_token_ids,
    ):
        request = self.requests.get(req_id)
        if request is None or request.is_finished():
            # The request may have been finished. Skip.
            continue

        # Add newly generated spec token ids to the request.
        if not spec_token_ids:
            # NOTE(woosuk): request.spec_token_ids should be updated.
            request.spec_token_ids.clear()
        elif self.structured_output_manager.should_advance(request):
            metadata = request.structured_output_request
            request.spec_token_ids = metadata.grammar.validate_tokens(  # type: ignore[union-attr]
                spec_token_ids)
        else:
            request.spec_token_ids = spec_token_ids
|
||||||
|
|
||||||
def get_request_counts(self) -> tuple[int, int]:
|
def get_request_counts(self) -> tuple[int, int]:
|
||||||
"""Returns (num_running_reqs, num_waiting_reqs)."""
|
"""Returns (num_running_reqs, num_waiting_reqs)."""
|
||||||
return len(self.running), len(self.waiting)
|
return len(self.running), len(self.waiting)
|
||||||
|
|||||||
@ -126,6 +126,7 @@ class EngineCore:
|
|||||||
> 1,
|
> 1,
|
||||||
log_stats=self.log_stats,
|
log_stats=self.log_stats,
|
||||||
)
|
)
|
||||||
|
self.use_spec_decode = vllm_config.speculative_config is not None
|
||||||
|
|
||||||
self.mm_input_cache_server = MultiModalInputCacheServer(
|
self.mm_input_cache_server = MultiModalInputCacheServer(
|
||||||
vllm_config.model_config, MULTIMODAL_REGISTRY)
|
vllm_config.model_config, MULTIMODAL_REGISTRY)
|
||||||
@ -294,6 +295,13 @@ class EngineCore:
|
|||||||
return (engine_core_outputs,
|
return (engine_core_outputs,
|
||||||
scheduler_output.total_num_scheduled_tokens > 0)
|
scheduler_output.total_num_scheduled_tokens > 0)
|
||||||
|
|
||||||
|
def post_step(self, model_executed: bool) -> None:
    """Hand the executor's pending draft tokens to the scheduler.

    Does nothing unless speculative decoding is enabled
    (``self.use_spec_decode``) and the model was executed in this step.

    Args:
        model_executed: Whether the step that just finished ran the model.
    """
    if not (self.use_spec_decode and model_executed):
        return

    # Take the draft token ids.
    draft_token_ids = self.model_executor.take_draft_token_ids()
    if draft_token_ids is None:
        # The executor had nothing to hand over.
        return
    self.scheduler.update_draft_token_ids(draft_token_ids)
|
||||||
|
|
||||||
def step_with_batch_queue(
|
def step_with_batch_queue(
|
||||||
self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
|
self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
|
||||||
"""Schedule and execute batches with the batch queue.
|
"""Schedule and execute batches with the batch queue.
|
||||||
@ -746,6 +754,8 @@ class EngineCoreProc(EngineCore):
|
|||||||
# Put EngineCoreOutputs into the output queue.
|
# Put EngineCoreOutputs into the output queue.
|
||||||
for output in (outputs.items() if outputs else ()):
|
for output in (outputs.items() if outputs else ()):
|
||||||
self.output_queue.put_nowait(output)
|
self.output_queue.put_nowait(output)
|
||||||
|
# Post-step hook.
|
||||||
|
self.post_step(model_executed)
|
||||||
|
|
||||||
return model_executed
|
return model_executed
|
||||||
|
|
||||||
|
|||||||
@ -2,7 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
from concurrent.futures import Future
|
from concurrent.futures import Future
|
||||||
from typing import Callable, Union
|
from typing import Callable, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@ -14,7 +14,7 @@ from vllm.executor.uniproc_executor import ( # noqa
|
|||||||
from vllm.executor.uniproc_executor import ( # noqa
|
from vllm.executor.uniproc_executor import ( # noqa
|
||||||
UniProcExecutor as UniProcExecutorV0)
|
UniProcExecutor as UniProcExecutorV0)
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||||
|
|
||||||
FailureCallback = Callable[[], None]
|
FailureCallback = Callable[[], None]
|
||||||
|
|
||||||
@ -88,6 +88,10 @@ class Executor(ExecutorBase):
|
|||||||
args=(scheduler_output, ))
|
args=(scheduler_output, ))
|
||||||
return output[0]
|
return output[0]
|
||||||
|
|
||||||
|
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
    """Fetch the draft tokens via the ``take_draft_token_ids`` RPC.

    Returns:
        The first element of the ``collective_rpc`` reply list.
    """
    output = self.collective_rpc("take_draft_token_ids")
    return output[0]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def max_concurrent_batches(self) -> int:
|
def max_concurrent_batches(self) -> int:
|
||||||
return 1
|
return 1
|
||||||
|
|||||||
@ -33,7 +33,7 @@ from vllm.utils import (decorate_logs, get_distributed_init_method,
|
|||||||
get_loopback_ip, get_mp_context, get_open_port,
|
get_loopback_ip, get_mp_context, get_open_port,
|
||||||
set_process_title)
|
set_process_title)
|
||||||
from vllm.v1.executor.abstract import Executor, FailureCallback
|
from vllm.v1.executor.abstract import Executor, FailureCallback
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||||
from vllm.worker.worker_base import WorkerWrapperBase
|
from vllm.worker.worker_base import WorkerWrapperBase
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
@ -191,6 +191,12 @@ class MultiprocExecutor(Executor):
|
|||||||
outputs, self.output_rank)
|
outputs, self.output_rank)
|
||||||
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
|
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
|
||||||
|
|
||||||
|
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
    """Fetch the draft tokens from the output-rank worker.

    Returns:
        The first element of the ``collective_rpc`` reply list.
    """
    # OPTIMIZATION: Get output only from a single worker (output_rank)
    outputs = self.collective_rpc("take_draft_token_ids",
                                  unique_reply_rank=self.output_rank)
    return outputs[0]
|
||||||
|
|
||||||
def collective_rpc(self,
|
def collective_rpc(self,
|
||||||
method: Union[str, Callable],
|
method: Union[str, Callable],
|
||||||
timeout: Optional[float] = None,
|
timeout: Optional[float] = None,
|
||||||
|
|||||||
@ -94,9 +94,6 @@ class ModelRunnerOutput:
|
|||||||
# each request due to speculative/jump decoding.
|
# each request due to speculative/jump decoding.
|
||||||
sampled_token_ids: list[list[int]]
|
sampled_token_ids: list[list[int]]
|
||||||
|
|
||||||
# num_reqs x num_spec_tokens
|
|
||||||
spec_token_ids: Optional[list[list[int]]]
|
|
||||||
|
|
||||||
# [num_reqs, max_num_logprobs + 1]
|
# [num_reqs, max_num_logprobs + 1]
|
||||||
# [num_reqs, max_num_logprobs + 1]
|
# [num_reqs, max_num_logprobs + 1]
|
||||||
# [num_reqs]
|
# [num_reqs]
|
||||||
@ -117,10 +114,18 @@ class ModelRunnerOutput:
|
|||||||
num_nans_in_logits: Optional[dict[str, int]] = None
|
num_nans_in_logits: Optional[dict[str, int]] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DraftTokenIds:
    """Draft token ids proposed for a batch of requests.

    ``req_ids`` and ``draft_token_ids`` are parallel lists with one entry
    per request.
    """

    # [num_reqs]
    req_ids: list[str]
    # num_reqs x num_draft_tokens
    draft_token_ids: list[list[int]]
|
||||||
|
|
||||||
|
|
||||||
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
|
EMPTY_MODEL_RUNNER_OUTPUT = ModelRunnerOutput(req_ids=[],
|
||||||
req_id_to_index={},
|
req_id_to_index={},
|
||||||
sampled_token_ids=[],
|
sampled_token_ids=[],
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
|
|||||||
@ -38,12 +38,14 @@ class MedusaProposer:
|
|||||||
self,
|
self,
|
||||||
target_hidden_states: torch.Tensor,
|
target_hidden_states: torch.Tensor,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
) -> torch.Tensor:
|
) -> list[list[int]]:
|
||||||
# Generate blocks and compute logits
|
# Generate blocks and compute logits
|
||||||
blocks = self.model(target_hidden_states)
|
blocks = self.model(target_hidden_states)
|
||||||
logits = self.model.compute_logits(blocks, None)
|
logits = self.model.compute_logits(blocks, None)
|
||||||
|
|
||||||
# Get draft tokens and transpose the result
|
# Get draft tokens and transpose the result
|
||||||
|
# TODO(woosuk): OPTIMIZATION: Return GPU tensor without GPU-CPU
|
||||||
|
# synchronization.
|
||||||
draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
|
draft_tokens = [logit.argmax(dim=-1).tolist() for logit in logits]
|
||||||
return [list(row) for row in zip(*draft_tokens)]
|
return [list(row) for row in zip(*draft_tokens)]
|
||||||
|
|
||||||
|
|||||||
@ -65,8 +65,8 @@ from vllm.v1.kv_cache_interface import (AttentionSpec,
|
|||||||
FullAttentionSpec, KVCacheConfig,
|
FullAttentionSpec, KVCacheConfig,
|
||||||
KVCacheSpec, MambaSpec,
|
KVCacheSpec, MambaSpec,
|
||||||
SlidingWindowSpec)
|
SlidingWindowSpec)
|
||||||
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, LogprobsTensors,
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
|
||||||
ModelRunnerOutput)
|
LogprobsTensors, ModelRunnerOutput)
|
||||||
from vllm.v1.pool.metadata import PoolingMetadata
|
from vllm.v1.pool.metadata import PoolingMetadata
|
||||||
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
|
from vllm.v1.sample.logits_processor import LogitsProcessors, build_logitsprocs
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
@ -348,6 +348,10 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
|
|
||||||
self.reorder_batch_threshold: Optional[int] = None
|
self.reorder_batch_threshold: Optional[int] = None
|
||||||
|
|
||||||
|
# Cached outputs.
|
||||||
|
self._draft_token_ids: Optional[Union[list[list[int]],
|
||||||
|
torch.Tensor]] = None
|
||||||
|
|
||||||
def _init_model_kwargs(self, num_tokens: int):
|
def _init_model_kwargs(self, num_tokens: int):
|
||||||
model_kwargs = dict[str, Any]()
|
model_kwargs = dict[str, Any]()
|
||||||
num_reqs = self.input_batch.num_reqs
|
num_reqs = self.input_batch.num_reqs
|
||||||
@ -1493,7 +1497,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
req_ids=self.input_batch.req_ids,
|
req_ids=self.input_batch.req_ids,
|
||||||
req_id_to_index=self.input_batch.req_id_to_index,
|
req_id_to_index=self.input_batch.req_id_to_index,
|
||||||
sampled_token_ids=[],
|
sampled_token_ids=[],
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=None,
|
logprobs=None,
|
||||||
prompt_logprobs_dict={},
|
prompt_logprobs_dict={},
|
||||||
pooler_output=pooler_output,
|
pooler_output=pooler_output,
|
||||||
@ -1764,12 +1767,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
req_state = self.requests[req_id]
|
req_state = self.requests[req_id]
|
||||||
req_state.output_token_ids.extend(sampled_ids)
|
req_state.output_token_ids.extend(sampled_ids)
|
||||||
|
|
||||||
if not self.speculative_config:
|
if self.speculative_config:
|
||||||
# Speculative decoding is not enabled.
|
|
||||||
spec_token_ids = None
|
|
||||||
else:
|
|
||||||
assert spec_decode_common_attn_metadata is not None
|
assert spec_decode_common_attn_metadata is not None
|
||||||
spec_token_ids = self.propose_draft_token_ids(
|
self._draft_token_ids = self.propose_draft_token_ids(
|
||||||
scheduler_output,
|
scheduler_output,
|
||||||
valid_sampled_token_ids,
|
valid_sampled_token_ids,
|
||||||
sampling_metadata,
|
sampling_metadata,
|
||||||
@ -1786,7 +1786,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
req_ids=self.input_batch.req_ids,
|
req_ids=self.input_batch.req_ids,
|
||||||
req_id_to_index=self.input_batch.req_id_to_index,
|
req_id_to_index=self.input_batch.req_id_to_index,
|
||||||
sampled_token_ids=valid_sampled_token_ids,
|
sampled_token_ids=valid_sampled_token_ids,
|
||||||
spec_token_ids=spec_token_ids,
|
|
||||||
logprobs=logprobs_lists,
|
logprobs=logprobs_lists,
|
||||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
@ -1794,6 +1793,17 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
num_nans_in_logits=num_nans_in_logits,
|
num_nans_in_logits=num_nans_in_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
    """Return the cached draft tokens and reset the cache.

    Returns:
        ``None`` when nothing is cached. Otherwise a ``DraftTokenIds``
        built from ``self.input_batch.req_ids`` and the cached tokens;
        the cache is set back to ``None`` before returning.
    """
    if self._draft_token_ids is None:
        return None
    req_ids = self.input_batch.req_ids
    if isinstance(self._draft_token_ids, torch.Tensor):
        # A tensor cache is converted to nested Python lists here.
        # NOTE(review): presumably this is where the device-to-host copy
        # is deferred to — confirm against the proposer implementations.
        draft_token_ids = self._draft_token_ids.tolist()
    else:
        draft_token_ids = self._draft_token_ids
    self._draft_token_ids = None
    return DraftTokenIds(req_ids, draft_token_ids)
|
||||||
|
|
||||||
def propose_draft_token_ids(
|
def propose_draft_token_ids(
|
||||||
self,
|
self,
|
||||||
scheduler_output: "SchedulerOutput",
|
scheduler_output: "SchedulerOutput",
|
||||||
@ -1804,11 +1814,11 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
aux_hidden_states: Optional[torch.Tensor],
|
aux_hidden_states: Optional[torch.Tensor],
|
||||||
spec_decode_metadata: Optional[SpecDecodeMetadata],
|
spec_decode_metadata: Optional[SpecDecodeMetadata],
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
) -> list[list[int]]:
|
) -> Union[list[list[int]], torch.Tensor]:
|
||||||
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
|
||||||
if self.speculative_config.method == "ngram":
|
if self.speculative_config.method == "ngram":
|
||||||
assert isinstance(self.drafter, NgramProposer)
|
assert isinstance(self.drafter, NgramProposer)
|
||||||
spec_token_ids = self.propose_ngram_draft_token_ids(
|
draft_token_ids = self.propose_ngram_draft_token_ids(
|
||||||
sampled_token_ids)
|
sampled_token_ids)
|
||||||
elif self.speculative_config.method == "medusa":
|
elif self.speculative_config.method == "medusa":
|
||||||
assert isinstance(self.drafter, MedusaProposer)
|
assert isinstance(self.drafter, MedusaProposer)
|
||||||
@ -1826,7 +1836,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
indices = torch.tensor(indices, device=self.device)
|
indices = torch.tensor(indices, device=self.device)
|
||||||
hidden_states = sample_hidden_states[indices]
|
hidden_states = sample_hidden_states[indices]
|
||||||
|
|
||||||
spec_token_ids = self.drafter.propose(
|
draft_token_ids = self.drafter.propose(
|
||||||
target_hidden_states=hidden_states,
|
target_hidden_states=hidden_states,
|
||||||
sampling_metadata=sampling_metadata,
|
sampling_metadata=sampling_metadata,
|
||||||
)
|
)
|
||||||
@ -1897,8 +1907,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
common_attn_metadata=common_attn_metadata,
|
common_attn_metadata=common_attn_metadata,
|
||||||
mm_embeds=mm_embeds,
|
mm_embeds=mm_embeds,
|
||||||
)
|
)
|
||||||
spec_token_ids = draft_token_ids.tolist()
|
return draft_token_ids
|
||||||
return spec_token_ids
|
|
||||||
|
|
||||||
def propose_ngram_draft_token_ids(
|
def propose_ngram_draft_token_ids(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@ -28,7 +28,8 @@ from vllm.tasks import SupportedTask
|
|||||||
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
|
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
|
||||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||||
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, ModelRunnerOutput
|
from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, DraftTokenIds,
|
||||||
|
ModelRunnerOutput)
|
||||||
from vllm.v1.utils import report_usage_stats
|
from vllm.v1.utils import report_usage_stats
|
||||||
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
|
||||||
from vllm.v1.worker.worker_base import WorkerBase
|
from vllm.v1.worker.worker_base import WorkerBase
|
||||||
@ -386,6 +387,9 @@ class Worker(WorkerBase):
|
|||||||
assert isinstance(output, ModelRunnerOutput)
|
assert isinstance(output, ModelRunnerOutput)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
|
||||||
|
return self.model_runner.take_draft_token_ids()
|
||||||
|
|
||||||
def profile(self, is_start: bool = True):
|
def profile(self, is_start: bool = True):
|
||||||
if self.profiler is None:
|
if self.profiler is None:
|
||||||
raise RuntimeError("Profiler is not enabled.")
|
raise RuntimeError("Profiler is not enabled.")
|
||||||
|
|||||||
@ -1145,7 +1145,6 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
|||||||
req_ids=req_ids,
|
req_ids=req_ids,
|
||||||
req_id_to_index=self.input_batch.req_id_to_index,
|
req_id_to_index=self.input_batch.req_id_to_index,
|
||||||
sampled_token_ids=valid_sampled_token_ids,
|
sampled_token_ids=valid_sampled_token_ids,
|
||||||
spec_token_ids=None,
|
|
||||||
logprobs=logprobs_lists,
|
logprobs=logprobs_lists,
|
||||||
prompt_logprobs_dict=prompt_logprobs_dict,
|
prompt_logprobs_dict=prompt_logprobs_dict,
|
||||||
pooler_output=[],
|
pooler_output=[],
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user