[Core] Async scheduling + structured outputs compatibility (#26866)

Signed-off-by: Nick Hill <nhill@redhat.com>
Authored by Nick Hill on 2025-10-31 17:35:04 -07:00; committed by GitHub
parent df334868ca
commit 0cdbe7b744
25 changed files with 419 additions and 191 deletions

View File

@ -6,6 +6,9 @@ from copy import deepcopy
from tblib import pickling_support
# Import fixture
from tests.v1.entrypoints.conftest import sample_json_schema # noqa
# ruff: noqa
# Install support for pickling exceptions so that we can nicely propagate

View File

@ -337,8 +337,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
@ -385,8 +383,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
@ -431,8 +427,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
@ -472,8 +466,6 @@ def test_stop_via_update_from_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_output = ModelRunnerOutput(
@ -1988,7 +1980,6 @@ def test_schedule_skip_tokenizer_init():
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.grammar_bitmask is None
def test_schedule_skip_tokenizer_init_structured_output_request():

View File

@ -7,6 +7,7 @@ import torch._dynamo.config as dynamo_config
from vllm import SamplingParams
from vllm.logprobs import Logprob
from vllm.sampling_params import StructuredOutputsParams
from ...conftest import VllmRunner
from ...models.utils import check_outputs_equal
@ -15,9 +16,12 @@ MODEL = "Qwen/Qwen3-0.6B"
@dynamo_config.patch(cache_size_limit=16)
def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
def test_preempt_and_async_scheduling_e2e(
sample_json_schema, monkeypatch: pytest.MonkeyPatch
):
"""Test consistency of combos of async scheduling, preemption,
uni/multiproc executor, and various sampling parameters."""
uni/multiproc executor, and various sampling parameters
including structured outputs."""
first_prompt = (
"The following numbers of the sequence "
@ -35,6 +39,12 @@ def test_preempt_and_async_scheduling_e2e(monkeypatch: pytest.MonkeyPatch):
dict(bad_words=["the", " the"]),
dict(logprobs=2),
dict(logprobs=2, presence_penalty=-1.0),
dict(structured_outputs=StructuredOutputsParams(json=sample_json_schema)),
dict(
structured_outputs=StructuredOutputsParams(json=sample_json_schema),
logprobs=2,
presence_penalty=-1.0,
),
]
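For orientation, the following hedged sketch expands one of the new parameter combinations above into a standalone SamplingParams, with an inline schema standing in for the sample_json_schema fixture (illustrative only, not part of the test):

```python
from vllm import SamplingParams
from vllm.sampling_params import StructuredOutputsParams

# Hypothetical schema in place of the sample_json_schema fixture.
json_schema = {"type": "object", "properties": {"name": {"type": "string"}}}

params = SamplingParams(
    structured_outputs=StructuredOutputsParams(json=json_schema),
    logprobs=2,
    presence_penalty=-1.0,
)
```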
default_params = dict(

View File

@ -248,7 +248,7 @@ def test_engine_core_concurrent_batches():
self,
scheduler_output,
non_block=False,
) -> Future[ModelRunnerOutput]:
) -> Future[ModelRunnerOutput | None]:
"""Make execute_model non-blocking."""
# DummyExecutor used only for testing async case.
@ -263,6 +263,23 @@ def test_engine_core_concurrent_batches():
# Use the thread pool instead of creating a new thread
return self.thread_pool.submit(_execute)
def sample_tokens(
self, grammar_output, non_block=False
) -> Future[ModelRunnerOutput]:
"""Make sample_tokens non-blocking."""
# DummyExecutor used only for testing async case.
assert non_block
def _execute():
output = self.collective_rpc("sample_tokens", args=(grammar_output,))
# Make a copy because output[0] may be reused
# by the next batch.
return copy.deepcopy(output[0])
# Use the thread pool instead of creating a new thread
return self.thread_pool.submit(_execute)
@property
def max_concurrent_batches(self) -> int:
return 2

View File

@ -31,7 +31,9 @@ class CustomMultiprocExecutor(MultiprocExecutor):
# Drop marker to show that this was run
with open(".marker", "w"):
...
return super().collective_rpc(method, timeout, args, kwargs)
return super().collective_rpc(
method, timeout, args, kwargs, non_block, unique_reply_rank
)
CustomMultiprocExecutorAsync = CustomMultiprocExecutor

View File

@ -26,8 +26,6 @@ def _make_empty_scheduler_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
kv_connector_metadata=SharedStorageConnectorMetadata(),
)

View File

@ -981,9 +981,7 @@ def test_scheduler_kv_connector_stats_aggregation():
scheduled_encoder_inputs={},
num_common_prefix_blocks=[0],
finished_req_ids=set(),
free_encoder_mm_hashes=set(),
structured_output_request_ids={},
grammar_bitmask=None,
free_encoder_mm_hashes=[],
)
engine_core_outputs = scheduler.update_from_output(scheduler_output, model_output)

View File

@ -92,8 +92,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -171,8 +169,6 @@ def test_update_states_request_finished(model_runner):
num_common_prefix_blocks=[],
finished_req_ids={req_id},
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
@ -201,8 +197,6 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
@ -230,8 +224,6 @@ def test_update_states_request_resumed(model_runner):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
@ -261,8 +253,6 @@ def test_update_states_no_changes(model_runner):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
@ -296,8 +286,6 @@ def test_update_states_request_unscheduled(model_runner):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)

View File

@ -152,8 +152,6 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -269,8 +267,6 @@ def test_update_states_request_finished(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids={req_id},
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
@ -301,8 +297,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
model_runner._update_states(scheduler_output)
@ -330,8 +324,6 @@ def test_update_states_request_resumed(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
@ -423,8 +415,6 @@ def test_update_states_no_changes(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
metadata_before = model_runner.input_batch.sampling_metadata
@ -460,8 +450,6 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids=[],
grammar_bitmask=None,
)
metadata_before = model_runner._update_states(scheduler_output)

View File

@ -6,7 +6,7 @@ KV cache helper for store.
from collections.abc import Sequence
from concurrent.futures import CancelledError, Future
from typing import TYPE_CHECKING, Literal, cast
from typing import TYPE_CHECKING, Literal
import torch
@ -138,8 +138,11 @@ class KVOutputAggregator:
return cls(connector.get_finished_count() or world_size)
def aggregate(
self, outputs: list[ModelRunnerOutput], output_rank: int = 0
) -> ModelRunnerOutput:
self, outputs: list[ModelRunnerOutput | None], output_rank: int = 0
) -> ModelRunnerOutput | None:
if not outputs[output_rank]:
return None
# Aggregate kv_connector_output from all workers
def update_finished_set(
@ -161,6 +164,7 @@ class KVOutputAggregator:
aggregated_kv_connector_stats = None
invalid_block_ids = set[int]()
for model_runner_output in outputs:
assert model_runner_output is not None
kv_output = model_runner_output.kv_connector_output
if not kv_output:
continue
@ -204,6 +208,7 @@ class KVOutputAggregator:
# select output of the worker specified by output_rank
output = outputs[output_rank]
assert output is not None
output.kv_connector_output = KVConnectorOutput(
finished_sending=finished_sending or None,
finished_recving=finished_recving or None,
@ -215,13 +220,16 @@ class KVOutputAggregator:
return output
def async_aggregate(
self, output_futures: Sequence[Future[ModelRunnerOutput]], output_rank: int = 0
) -> Future[ModelRunnerOutput]:
self,
output_futures: Sequence[Future[ModelRunnerOutput | None]],
output_rank: int = 0,
) -> Future[ModelRunnerOutput | None]:
"""Takes a list of futures and returns a single future which resolves
to the respective list of outputs."""
result_future: Future[ModelRunnerOutput] = Future()
result_future: Future[ModelRunnerOutput | None] = Future()
outputs: list[ModelRunnerOutput | None] = [None] * len(output_futures)
remaining = len(output_futures)
def make_callback(idx):
def callback(fut):
@ -236,12 +244,10 @@ class KVOutputAggregator:
result_future.set_exception(e)
# this check assumes io_thread_pool uses a single thread
if all(outputs):
result_future.set_result(
self.aggregate(
cast(list[ModelRunnerOutput], outputs), output_rank
)
)
nonlocal remaining
remaining -= 1
if not remaining:
result_future.set_result(self.aggregate(outputs, output_rank))
return callback
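The countdown above replaces the previous `if all(outputs)` check, which would misfire if a worker's result were None or otherwise falsy. A minimal standalone sketch of the same pattern (generic names, not vLLM code):

```python
from concurrent.futures import Future, ThreadPoolExecutor

def async_aggregate(futures: list[Future], combine) -> Future:
    """Resolve one future with combine(results) once every input future finishes."""
    result: Future = Future()
    results = [None] * len(futures)
    remaining = len(futures)

    def make_callback(idx: int):
        def callback(fut: Future) -> None:
            nonlocal remaining
            try:
                results[idx] = fut.result()
            except Exception as e:
                if not result.done():
                    result.set_exception(e)
                return
            # Like the code above, this assumes callbacks run on a single
            # IO thread, so the countdown needs no lock.
            remaining -= 1
            if not remaining and not result.done():
                result.set_result(combine(results))
        return callback

    for i, fut in enumerate(futures):
        fut.add_done_callback(make_callback(i))
    return result

pool = ThreadPoolExecutor(max_workers=1)
futs = [pool.submit(lambda x=x: x * x) for x in range(4)]
print(async_aggregate(futs, sum).result())  # 0 + 1 + 4 + 9 = 14
```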

View File

@ -15,8 +15,12 @@ class AsyncScheduler(Scheduler):
scheduler_output: SchedulerOutput,
) -> None:
super()._update_after_schedule(scheduler_output)
pending_structured_output_tokens = False
for req_id in scheduler_output.num_scheduled_tokens:
request = self.requests[req_id]
pending_structured_output_tokens |= (
request.use_structured_output and request.num_output_placeholders > 0
)
if (
request.num_computed_tokens
== request.num_tokens + request.num_output_placeholders
@ -25,6 +29,10 @@ class AsyncScheduler(Scheduler):
# TODO(woosuk): Support speculative decoding.
request.num_output_placeholders += 1
scheduler_output.pending_structured_output_tokens = (
pending_structured_output_tokens
)
def _update_request_with_output(
self,
request: Request,

View File

@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorBase_V1
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import EngineCoreOutputs
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@ -40,6 +40,12 @@ class SchedulerInterface(ABC):
"""
raise NotImplementedError
@abstractmethod
def get_grammar_bitmask(
self, scheduler_output: "SchedulerOutput"
) -> "GrammarOutput | None":
raise NotImplementedError
@abstractmethod
def update_from_output(
self,

View File

@ -181,12 +181,17 @@ class SchedulerOutput:
# freed from the encoder cache.
free_encoder_mm_hashes: list[str]
# ids of structured outputs requests included in the bitmask, in the
# same order as the corresponding stacked rows of the bitmask.
# There may be more than one row per request in the case of speculative decoding.
structured_output_request_ids: list[str]
# the bitmask for the whole batch
grammar_bitmask: "npt.NDArray[np.int32] | None"
# Whether the scheduled requests have all the output tokens they
# need to perform grammar bitmask computation.
pending_structured_output_tokens: bool = False
# KV Cache Connector metadata.
kv_connector_metadata: KVConnectorMetadata | None = None
@dataclass
class GrammarOutput:
# ids of structured output requests.
structured_output_request_ids: list[str]
# Bitmask ordered as structured_output_request_ids.
grammar_bitmask: "npt.NDArray[np.int32]"
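A hedged sketch of what the new split means in practice, using a local stand-in for the dataclass above and hypothetical sizes: the bitmask no longer travels on SchedulerOutput, and each structured-output request contributes one row (more under speculative decoding), packed 32 vocabulary entries per int32 word in the xgrammar convention.

```python
from dataclasses import dataclass

import numpy as np
import numpy.typing as npt

@dataclass
class GrammarOutput:  # local mirror of the dataclass introduced above
    structured_output_request_ids: list[str]
    grammar_bitmask: npt.NDArray[np.int32]

vocab_size = 1024                          # hypothetical vocabulary size
num_rows = 2                               # one row per request here (no spec decoding)
words_per_row = (vocab_size + 31) // 32    # 32 token bits per int32 word
bitmask = np.full((num_rows, words_per_row), -1, dtype=np.int32)  # -1 = all tokens allowed
grammar_output = GrammarOutput(["req-0", "req-1"], bitmask)
assert bitmask.shape[0] == len(grammar_output.structured_output_request_ids)
```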

View File

@ -5,7 +5,7 @@ import itertools
import time
from collections import defaultdict
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any
from typing import Any
from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
@ -24,7 +24,12 @@ from vllm.v1.core.encoder_cache_manager import (
)
from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput
from vllm.v1.core.sched.output import (
CachedRequestData,
GrammarOutput,
NewRequestData,
SchedulerOutput,
)
from vllm.v1.core.sched.request_queue import SchedulingPolicy, create_request_queue
from vllm.v1.core.sched.utils import check_stop, remove_all
from vllm.v1.engine import EngineCoreEventType, EngineCoreOutput, EngineCoreOutputs
@ -35,10 +40,6 @@ from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
logger = init_logger(__name__)
@ -619,9 +620,6 @@ class Scheduler(SchedulerInterface):
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
)
# Record the request ids that were scheduled in this step.
self.prev_step_scheduled_req_ids.clear()
@ -641,8 +639,6 @@ class Scheduler(SchedulerInterface):
# the previous and the current steps.
finished_req_ids=self.finished_req_ids,
free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(),
structured_output_request_ids=structured_output_request_ids,
grammar_bitmask=grammar_bitmask,
)
# NOTE(Kuntai): this function is designed for multiple purposes:
@ -872,9 +868,8 @@ class Scheduler(SchedulerInterface):
def get_grammar_bitmask(
self,
scheduled_request_ids: Iterable[str],
scheduled_spec_decode_tokens: dict[str, list[int]],
) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
scheduler_output: SchedulerOutput,
) -> GrammarOutput | None:
# Collect list of scheduled request ids that use structured output.
# The corresponding rows of the bitmask will be in this order.
# PERF: in case of chunked prefill,
@ -883,18 +878,18 @@ class Scheduler(SchedulerInterface):
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids = [
req_id
for req_id in scheduled_request_ids
for req_id in scheduler_output.num_scheduled_tokens
if (req := self.requests.get(req_id)) and req.use_structured_output
]
if not structured_output_request_ids:
return structured_output_request_ids, None
return None
bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
scheduled_spec_decode_tokens,
scheduler_output.scheduled_spec_decode_tokens,
)
return structured_output_request_ids, bitmask
return GrammarOutput(structured_output_request_ids, bitmask)
def update_from_output(
self,

View File

@ -12,7 +12,7 @@ from concurrent.futures import Future
from contextlib import ExitStack, contextmanager
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, TypeVar
from typing import Any, TypeVar, cast
import msgspec
import zmq
@ -334,9 +334,12 @@ class EngineCore:
if not self.scheduler.has_requests():
return {}, False
scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
with self.log_error_detail(scheduler_output):
model_output = self.model_executor.execute_model(scheduler_output)
model_output = future.result()
if model_output is None:
model_output = self.model_executor.sample_tokens(grammar_output)
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
@ -376,20 +379,47 @@ class EngineCore:
assert len(batch_queue) < self.batch_queue_size
model_executed = False
deferred_scheduler_output = None
if self.scheduler.has_requests():
scheduler_output = self.scheduler.schedule()
future = self.model_executor.execute_model(scheduler_output, non_block=True)
batch_queue.appendleft((future, scheduler_output))
exec_future = self.model_executor.execute_model(
scheduler_output, non_block=True
)
model_executed = scheduler_output.total_num_scheduled_tokens > 0
if (
model_executed
and len(batch_queue) < self.batch_queue_size
and not batch_queue[-1][0].done()
):
# Don't block on next worker response unless the queue is full
# or there are no more requests to schedule.
return None, True
if scheduler_output.pending_structured_output_tokens:
# We need to defer sampling until we have processed the model output
# from the prior step.
deferred_scheduler_output = scheduler_output
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
assert exec_result is None
else:
# We aren't waiting for any tokens, get any grammar output immediately.
grammar_output = self.scheduler.get_grammar_bitmask(scheduler_output)
# Block-wait for execute to return (continues running async on the GPU).
with self.log_error_detail(scheduler_output):
exec_result = exec_future.result()
if exec_result is None:
# Call sample tokens.
future = self.model_executor.sample_tokens(
grammar_output, non_block=True
)
else:
# No sampling required (e.g. all requests finished).
future = cast(Future[ModelRunnerOutput], exec_future)
# Add this step's future to the queue.
batch_queue.appendleft((future, scheduler_output))
if (
model_executed
and len(batch_queue) < self.batch_queue_size
and not batch_queue[-1][0].done()
):
# Don't block on next worker response unless the queue is full
# or there are no more requests to schedule.
return None, True
elif not batch_queue:
# Queue is empty. We should not reach here since this method should
@ -405,6 +435,19 @@ class EngineCore:
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output
)
# NOTE(nick): We can either handle the deferred tasks here or save
# in a field and do it immediately once step_with_batch_queue is
# re-called. The latter slightly favors TTFT over TPOT/throughput.
if deferred_scheduler_output:
# We now have the tokens needed to compute the bitmask for the
# deferred request. Get the bitmask and call sample tokens.
grammar_output = self.scheduler.get_grammar_bitmask(
deferred_scheduler_output
)
future = self.model_executor.sample_tokens(grammar_output, non_block=True)
batch_queue.appendleft((future, deferred_scheduler_output))
return engine_core_outputs, model_executed
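As a plain-Python illustration of the deferral rule in the comments above (a toy, not vLLM code): with a batch queue, step N+1 can be scheduled while step N is still in flight, but if its grammar bitmask depends on tokens step N has not produced yet, its sample_tokens call is only enqueued after step N's output is processed.

```python
from collections import deque

class ToyBatchQueue:
    """Toy model of the deferred-sampling decision; names are illustrative."""

    def __init__(self) -> None:
        self.batch_queue: deque[str] = deque()
        self.deferred: str | None = None

    def schedule_step(self, label: str, pending_structured_output_tokens: bool) -> None:
        if pending_structured_output_tokens:
            # Bitmask can't be built yet; hold the step back.
            self.deferred = label
        else:
            # Bitmask computable now; sample_tokens enqueued immediately.
            self.batch_queue.appendleft(label)

    def process_oldest_output(self) -> None:
        print("processed output of", self.batch_queue.pop())
        if self.deferred is not None:
            # Prior tokens are now known; enqueue the deferred step.
            self.batch_queue.appendleft(self.deferred)
            self.deferred = None

queue = ToyBatchQueue()
queue.schedule_step("step N", pending_structured_output_tokens=False)
queue.schedule_step("step N+1", pending_structured_output_tokens=True)
queue.process_oldest_output()          # frees step N, then enqueues step N+1
print(list(queue.batch_queue))         # ['step N+1']
```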
def shutdown(self):

View File

@ -16,7 +16,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
from vllm.utils.import_utils import resolve_obj_by_qualname
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
@ -187,28 +187,44 @@ class Executor(ABC):
@overload
def execute_model(
self,
scheduler_output: SchedulerOutput,
non_block: Literal[False] = False,
) -> ModelRunnerOutput:
self, scheduler_output: SchedulerOutput, non_block: Literal[False] = False
) -> ModelRunnerOutput | None:
pass
@overload
def execute_model(
self,
scheduler_output: SchedulerOutput,
non_block: Literal[True] = True,
) -> Future[ModelRunnerOutput]:
self, scheduler_output: SchedulerOutput, non_block: Literal[True] = True
) -> Future[ModelRunnerOutput | None]:
pass
def execute_model(
self, scheduler_output: SchedulerOutput, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
output = self.collective_rpc( # type: ignore[call-overload]
"execute_model", args=(scheduler_output,), non_block=non_block
)
return output[0]
@overload
def sample_tokens(
self, grammar_output: GrammarOutput | None, non_block: Literal[False] = False
) -> ModelRunnerOutput:
pass
@overload
def sample_tokens(
self, grammar_output: GrammarOutput | None, non_block: Literal[True] = True
) -> Future[ModelRunnerOutput]:
pass
def sample_tokens(
self, grammar_output: GrammarOutput | None, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
output = self.collective_rpc( # type: ignore[call-overload]
"sample_tokens", args=(grammar_output,), non_block=non_block
)
return output[0]
def execute_dummy_batch(self) -> None:
self.collective_rpc("execute_dummy_batch")
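The overload pairs above key the return type on a Literal non_block flag, so callers get either the result or a Future from a single implementation. A generic, self-contained sketch of that typing pattern (illustrative names, not the vLLM API):

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Literal, overload

_pool = ThreadPoolExecutor(max_workers=1)

@overload
def run(job: str, non_block: Literal[False] = False) -> str: ...
@overload
def run(job: str, non_block: Literal[True] = True) -> Future[str]: ...

def run(job: str, non_block: bool = False) -> str | Future[str]:
    # Blocking callers get the value; non-blocking callers get a Future.
    if non_block:
        return _pool.submit(str.upper, job)
    return job.upper()

print(run("tokens"))                           # TOKENS
print(run("tokens", non_block=True).result())  # TOKENS, via a Future
```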

View File

@ -46,7 +46,7 @@ from vllm.utils.system_utils import (
get_mp_context,
set_process_title,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.outputs import AsyncModelRunnerOutput, DraftTokenIds, ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase
@ -132,15 +132,12 @@ class MultiprocExecutor(Executor):
uw.death_writer.close()
self._ensure_worker_termination([uw.proc for uw in unready_workers])
# For pipeline parallel, we use a thread pool for asynchronous
# execute_model.
if self.max_concurrent_batches > 1:
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue
# _async_aggregate_workers_output also assumes a single IO thread
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io"
)
# Note: must use only 1 IO thread to keep dequeue sequence
# from the response queue.
# _async_aggregate_workers_output also assumes a single IO thread.
self.io_thread_pool = ThreadPoolExecutor(
max_workers=1, thread_name_prefix="mp_exec_io"
)
self.output_rank = self._get_output_rank()
self.has_connector = self.vllm_config.kv_transfer_config is not None
@ -180,15 +177,27 @@ class MultiprocExecutor(Executor):
self.failure_callback = callback
def execute_model( # type: ignore[override]
self,
scheduler_output: SchedulerOutput,
non_block: bool = False,
self, scheduler_output: SchedulerOutput, non_block: bool = False
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
return self._execute_with_aggregation(
"execute_model", scheduler_output, non_block=non_block
)
def sample_tokens( # type: ignore[override]
self, grammar_output: GrammarOutput | None, non_block: bool = False
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
return self._execute_with_aggregation( # type: ignore[return-value]
"sample_tokens", grammar_output, non_block=non_block
)
def _execute_with_aggregation(
self, method: str, *args, non_block: bool = False
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
if not self.has_connector:
# get output only from a single worker (output_rank)
(output,) = self.collective_rpc(
"execute_model",
args=(scheduler_output,),
method,
args=args,
unique_reply_rank=self.output_rank,
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
@ -197,8 +206,8 @@ class MultiprocExecutor(Executor):
# get output from all workers
outputs = self.collective_rpc(
"execute_model",
args=(scheduler_output,),
method,
args=args,
non_block=non_block,
timeout=envs.VLLM_EXECUTE_MODEL_TIMEOUT_SECONDS,
)

View File

@ -19,7 +19,7 @@ from vllm.utils.network_utils import (
get_ip,
get_open_port,
)
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.executor.abstract import Executor
from vllm.v1.executor.ray_utils import (
@ -41,6 +41,9 @@ if TYPE_CHECKING:
logger = init_logger(__name__)
COMPLETED_NONE_FUTURE: Future[ModelRunnerOutput | None] = Future()
COMPLETED_NONE_FUTURE.set_result(None)
@dataclass
class RayWorkerMetaData:
@ -96,6 +99,8 @@ class RayDistributedExecutor(Executor):
# KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None
self.scheduler_output: SchedulerOutput | None = None
@property
def max_concurrent_batches(self) -> int:
"""Ray distributed executor supports pipeline parallelism,
@ -381,22 +386,46 @@ class RayDistributedExecutor(Executor):
self.shutdown()
def execute_model( # type: ignore[override]
self, scheduler_output: SchedulerOutput, non_block: bool = False
self,
scheduler_output: SchedulerOutput,
non_block: bool = False,
) -> ModelRunnerOutput | None | Future[ModelRunnerOutput | None]:
if self.scheduler_output is not None:
raise RuntimeError(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
self.scheduler_output = scheduler_output
return COMPLETED_NONE_FUTURE if non_block else None
def sample_tokens( # type: ignore[override]
self,
grammar_output: "GrammarOutput | None",
non_block: bool = False,
) -> ModelRunnerOutput | Future[ModelRunnerOutput]:
"""Execute the model on the Ray workers.
The scheduler output to use should have been provided in
a prior call to execute_model().
Args:
scheduler_output: The scheduler output to execute.
grammar_output: The structured outputs grammar bitmask, if applicable.
non_block: If True, the method will return a Future.
Returns:
The model runner output.
"""
scheduler_output = self.scheduler_output
if scheduler_output is None:
return None # noqa
self.scheduler_output = None
# Build the compiled DAG for the first time.
if self.forward_dag is None: # type: ignore
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
refs = self.forward_dag.execute(scheduler_output) # type: ignore
refs = self.forward_dag.execute((scheduler_output, grammar_output)) # type: ignore
if not self.has_connector:
# Get output only from a single worker (output_rank)

View File

@ -19,7 +19,7 @@ from vllm.v1.outputs import AsyncModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerWrapperBase
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
logger = init_logger(__name__)
@ -82,36 +82,41 @@ try:
def execute_model_ray(
self,
scheduler_output: Union[
"SchedulerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
],
execute_model_input: tuple["SchedulerOutput", "GrammarOutput"]
| tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
) -> Union[
"ModelRunnerOutput", tuple["SchedulerOutput", "IntermediateTensors"]
"ModelRunnerOutput",
tuple["SchedulerOutput", "GrammarOutput", "IntermediateTensors"],
]:
# This method is used by Ray Compiled Graph to execute the model,
# and it needs a special logic of self.setup_device_if_necessary()
self.setup_device_if_necessary()
assert self.worker is not None, "Worker is not initialized"
if isinstance(scheduler_output, tuple):
scheduler_output, intermediate_tensors = scheduler_output
if len(execute_model_input) == 3:
scheduler_output, grammar_output, intermediate_tensors = (
execute_model_input
)
else:
scheduler_output, intermediate_tensors = scheduler_output, None
scheduler_output, grammar_output = execute_model_input
intermediate_tensors = None
assert self.worker.model_runner is not None
output = self.worker.model_runner.execute_model(
scheduler_output, intermediate_tensors
)
if isinstance(output, IntermediateTensors):
output = scheduler_output, output
output = scheduler_output, grammar_output, output
elif not get_pp_group().is_last_rank:
# Case where there are no scheduled requests
# but may still be finished requests.
assert not output or not output.req_ids
output = scheduler_output, None
# Ensure outputs crossing Ray compiled DAG are serializable.
# AsyncModelRunnerOutput holds CUDA events and cannot be
# pickled.
if isinstance(output, AsyncModelRunnerOutput):
output = output.get_output()
output = scheduler_output, grammar_output, None
elif output is None:
output = self.worker.model_runner.sample_tokens(grammar_output)
# Ensure outputs crossing Ray compiled DAG are serializable.
# AsyncModelRunnerOutput holds CUDA events and cannot be
# pickled.
if isinstance(output, AsyncModelRunnerOutput):
output = output.get_output()
return output
def override_env_vars(self, vars: dict[str, str]):

View File

@ -16,6 +16,7 @@ from diskcache import Cache
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils.import_utils import LazyLoader
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
if TYPE_CHECKING:
import outlines_core as oc
@ -24,7 +25,6 @@ if TYPE_CHECKING:
import xgrammar as xgr
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
@ -47,6 +47,7 @@ CACHE = None
def apply_grammar_bitmask(
scheduler_output: SchedulerOutput,
grammar_output: GrammarOutput,
input_batch: InputBatch,
logits: torch.Tensor,
) -> None:
@ -58,9 +59,9 @@ def apply_grammar_bitmask(
input_batch (InputBatch): The input of model runner.
logits (torch.Tensor): The output logits of model forward.
"""
grammar_bitmask = scheduler_output.grammar_bitmask
if grammar_bitmask is None:
return
# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = grammar_output.grammar_bitmask
# We receive the structured output bitmask from the scheduler,
# compacted to contain bitmasks only for structured output requests.
@ -79,7 +80,7 @@ def apply_grammar_bitmask(
cumulative_offset += len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
)
if req_id in scheduler_output.structured_output_request_ids:
if req_id in grammar_output.structured_output_request_ids:
struct_out_req_batch_indices[req_id] = logit_index
out_indices = []
@ -91,7 +92,7 @@ def apply_grammar_bitmask(
dtype=grammar_bitmask.dtype,
)
cumulative_index = 0
for req_id in scheduler_output.structured_output_request_ids:
for req_id in grammar_output.structured_output_request_ids:
num_spec_tokens = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
)
@ -101,22 +102,28 @@ def apply_grammar_bitmask(
sorted_bitmask[logit_index + i] = grammar_bitmask[cumulative_index + i]
out_indices.append(logit_index + i)
cumulative_index += 1 + num_spec_tokens
grammar_bitmask = sorted_bitmask
# Copy async to device as tensor.
grammar_bitmask = torch.from_numpy(sorted_bitmask).to(
logits.device, non_blocking=True
)
# If the length of out indices and the logits have the same shape
# we don't need to pass indices to the kernel,
# since the bitmask is already aligned with the logits.
skip_out_indices = len(out_indices) == logits.shape[0]
# Serialization of np.ndarray is much more efficient than a tensor,
# so we receive it in that format.
grammar_bitmask = torch.from_numpy(grammar_bitmask).contiguous()
index_tensor = None
if not skip_out_indices:
# xgrammar expects a python list of indices but it will actually work with
# a tensor. If we copy the tensor ourselves here we can do it in a non_blocking
# manner and there should be no cpu sync within xgrammar.
index_tensor = torch.tensor(
out_indices, dtype=torch.int32, device="cpu", pin_memory=True
)
index_tensor = index_tensor.to(logits.device, non_blocking=True)
xgr.apply_token_bitmask_inplace(
logits,
grammar_bitmask.to(logits.device, non_blocking=True),
indices=out_indices if not skip_out_indices else None,
)
xgr.apply_token_bitmask_inplace(logits, grammar_bitmask, indices=index_tensor)
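The copies above are arranged so nothing forces a CPU sync before the masking kernel: the bitmask stays a NumPy array over the wire, and both it and the index tensor are moved to the device with non_blocking=True (the index tensor from pinned host memory). A standalone sketch of that copy pattern with hypothetical shapes, guarded so it only runs when CUDA is available:

```python
import numpy as np
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    grammar_bitmask = np.zeros((4, 8), dtype=np.int32)  # hypothetical packed bitmask rows
    out_indices = [0, 2, 3]                             # hypothetical logit row indices

    bitmask_gpu = torch.from_numpy(grammar_bitmask).to(device, non_blocking=True)
    index_cpu = torch.tensor(out_indices, dtype=torch.int32, device="cpu", pin_memory=True)
    index_gpu = index_cpu.to(device, non_blocking=True)
    # The real code then calls:
    #   xgr.apply_token_bitmask_inplace(logits, bitmask_gpu, indices=index_gpu)
```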
class OutlinesVocabulary:

View File

@ -109,6 +109,7 @@ from vllm.v1.outputs import (
EMPTY_MODEL_RUNNER_OUTPUT,
AsyncModelRunnerOutput,
DraftTokenIds,
KVConnectorOutput,
LogprobsLists,
LogprobsTensors,
ModelRunnerOutput,
@ -150,7 +151,7 @@ from .utils import (
if TYPE_CHECKING:
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
logger = init_logger(__name__)
@ -218,6 +219,20 @@ class AsyncGPUModelRunnerOutput(AsyncModelRunnerOutput):
return output
class ExecuteModelState(NamedTuple):
"""Ephemeral cached state transferred between execute_model() and
sample_tokens(), after execute_model() returns None."""
scheduler_output: "SchedulerOutput"
logits: torch.Tensor
spec_decode_metadata: SpecDecodeMetadata | None
spec_decode_common_attn_metadata: CommonAttentionMetadata | None
hidden_states: torch.Tensor
sample_hidden_states: torch.Tensor
aux_hidden_states: list[torch.Tensor] | None
kv_connector_output: KVConnectorOutput | None
class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
def __init__(
self,
@ -509,6 +524,9 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
pin_memory=self.pin_memory,
)
# Ephemeral state transferred between execute_model() and sample_tokens().
self.execute_model_state: ExecuteModelState | None = None
def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
@ -2113,7 +2131,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
num_input_tokens: int, # Padded
intermediate_tensors: IntermediateTensors | None = None,
) -> tuple[
int,
torch.Tensor | None,
torch.Tensor | None,
torch.Tensor,
@ -2207,7 +2224,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
model_kwargs.update(encoder_inputs)
return (
num_scheduled_tokens,
input_ids,
inputs_embeds,
positions,
@ -2425,13 +2441,19 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: IntermediateTensors | None = None,
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
) -> ModelRunnerOutput | IntermediateTensors | None:
if self.execute_model_state is not None:
raise RuntimeError(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
with record_function_or_nullcontext("Preprocess"):
with self.synchronize_input_prep():
# Update persistent batch states.
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
if not num_scheduled_tokens:
if not has_kv_transfer_group():
# Return empty ModelRunnerOutput if no work to do.
return EMPTY_MODEL_RUNNER_OUTPUT
@ -2471,7 +2493,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
(
num_scheduled_tokens,
input_ids,
inputs_embeds,
positions,
@ -2559,6 +2580,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
# Rare case.
assert not self.is_pooling_model
sample_hidden_states = hidden_states[logits_indices]
if not get_pp_group().is_last_rank:
all_gather_tensors = {
"residual": not is_residual_scattered_for_sp(
@ -2572,7 +2594,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
)
logits = None
else:
sample_hidden_states = hidden_states[logits_indices]
logits = self.model.compute_logits(sample_hidden_states)
model_output_broadcast_data = {}
@ -2585,9 +2606,45 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
assert model_output_broadcast_data is not None
logits = model_output_broadcast_data["logits"]
# Apply structured output bitmasks if present
if scheduler_output.structured_output_request_ids:
apply_grammar_bitmask(scheduler_output, self.input_batch, logits)
self.execute_model_state = ExecuteModelState(
scheduler_output,
logits,
spec_decode_metadata,
spec_decode_common_attn_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
kv_connector_output,
)
return None
@torch.inference_mode
def sample_tokens(
self, grammar_output: "GrammarOutput | None"
) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors:
if self.execute_model_state is None:
# Nothing to do (PP non-final rank case), output isn't used.
return None # noqa
# Unpack ephemeral state.
(
scheduler_output,
logits,
spec_decode_metadata,
spec_decode_common_attn_metadata,
hidden_states,
sample_hidden_states,
aux_hidden_states,
kv_connector_output,
) = self.execute_model_state
# Clear ephemeral state.
self.execute_model_state = None
# Apply structured output bitmasks if present.
if grammar_output is not None:
apply_grammar_bitmask(
scheduler_output, grammar_output, self.input_batch, logits
)
with record_function_or_nullcontext("Sample"):
sampler_output = self._sample(logits, spec_decode_metadata)
@ -2646,7 +2703,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
sampler_output,
logits,
hidden_states,
num_scheduled_tokens,
scheduler_output.total_num_scheduled_tokens,
spec_decode_metadata,
)

View File

@ -6,6 +6,7 @@ import copy
import gc
import os
from contextlib import AbstractContextManager, nullcontext
from types import NoneType
from typing import TYPE_CHECKING, Any
import torch
@ -37,6 +38,7 @@ from vllm.sequence import IntermediateTensors
from vllm.tasks import SupportedTask
from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
from vllm.v1.core.sched.output import GrammarOutput
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import (
@ -508,11 +510,16 @@ class Worker(WorkerBase):
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
return self.model_runner.get_supported_tasks()
@torch.inference_mode()
def sample_tokens(
self, grammar_output: "GrammarOutput"
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
return self.model_runner.sample_tokens(grammar_output)
@torch.inference_mode()
def execute_model(
self,
scheduler_output: "SchedulerOutput",
) -> ModelRunnerOutput | AsyncModelRunnerOutput | None:
self, scheduler_output: "SchedulerOutput"
) -> ModelRunnerOutput | None:
intermediate_tensors = None
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
@ -531,13 +538,13 @@ class Worker(WorkerBase):
)
output = self.model_runner.execute_model(scheduler_output, intermediate_tensors)
if isinstance(output, (ModelRunnerOutput, AsyncModelRunnerOutput)):
if isinstance(output, (ModelRunnerOutput, NoneType)):
return output
assert isinstance(output, IntermediateTensors)
parallel_config = self.vllm_config.parallel_config
assert (
parallel_config.distributed_executor_backend != ("external_launcher")
parallel_config.distributed_executor_backend != "external_launcher"
and not get_pp_group().is_last_rank
)

View File

@ -92,7 +92,7 @@ from .utils import (
)
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
logger = init_logger(__name__)
@ -372,6 +372,11 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
else:
self.sample_from_logits_func = self.sample_from_logits
# For passing scheduler_output between successive
# execute_model() and sample_tokens() calls.
self.scheduler_output: SchedulerOutput | None = None
self.mm_embed_inputs: tuple[list[torch.Tensor], torch.Tensor] | None = None
def reset_mm_cache(self) -> None:
if self.mm_budget:
self.mm_budget.reset_cache()
@ -1078,7 +1083,12 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self,
scheduler_output: "SchedulerOutput",
intermediate_tensors: IntermediateTensors | None = None,
) -> ModelRunnerOutput:
) -> ModelRunnerOutput | None:
if self.scheduler_output is not None:
raise RuntimeError(
"State error: sample_tokens() must be called "
"after execute_model() returns None."
)
# Update cached state
self._update_states(scheduler_output)
if not scheduler_output.total_num_scheduled_tokens:
@ -1088,14 +1098,30 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return self.kv_connector_no_forward(scheduler_output, self.vllm_config)
mm_embed_inputs = None
if self.supports_mm_inputs:
# Run the multimodal encoder if any.
self._execute_mm_encoder(scheduler_output)
mm_embed_inputs = self._gather_mm_embeddings(scheduler_output)
else:
mm_embed_inputs = None
torch_xla.sync(wait=False)
self.scheduler_output = scheduler_output
self.mm_embed_inputs = mm_embed_inputs
return None
@torch.no_grad()
def sample_tokens(
self, grammar_output: "GrammarOutput | None"
) -> ModelRunnerOutput:
if self.scheduler_output is None:
# Nothing to do (PP non-final rank case), output isn't used.
return None # noqa
scheduler_output = self.scheduler_output
mm_embed_inputs = self.mm_embed_inputs
self.scheduler_output = None
self.mm_embed_inputs = None
# Prepare inputs, the requests might be split into multiple
# executions, combine the result of each execution.
start_index = 0
@ -1131,9 +1157,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
tpu_sampling_metadata = TPUSupportedSamplingMetadata.from_input_batch(
self.input_batch, padded_num_reqs, self.device
)
if scheduler_output.grammar_bitmask is not None:
if grammar_output is not None:
require_struct_decoding, grammar_bitmask_padded, arange = (
self.prepare_structured_decoding_input(logits, scheduler_output)
self.prepare_structured_decoding_input(logits, grammar_output)
)
logits = self.structured_decode(
require_struct_decoding, grammar_bitmask_padded, logits, arange
@ -1954,10 +1980,9 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
return self.model.get_input_embeddings(*args, **kwargs)
def prepare_structured_decoding_input(
self, logits: torch.Tensor, scheduler_output: "SchedulerOutput"
self, logits: torch.Tensor, grammar_output: "GrammarOutput"
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
grammar_bitmask = scheduler_output.grammar_bitmask
assert grammar_bitmask is not None
grammar_bitmask = grammar_output.grammar_bitmask
num_reqs, _ = logits.shape
# Reset pre-allocated tensors
@ -1965,7 +1990,7 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.require_structured_out_cpu.zero_()
cumulative_mask_idx = 0
for req_id in scheduler_output.structured_output_request_ids:
for req_id in grammar_output.structured_output_request_ids:
if req_id not in self.input_batch.req_id_to_index:
continue
batch_index = self.input_batch.req_id_to_index[req_id]

View File

@ -17,7 +17,6 @@ from vllm.distributed import (
)
from vllm.distributed.kv_transfer import (
ensure_kv_transfer_initialized,
has_kv_transfer_group,
)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -27,7 +26,7 @@ from vllm.platforms.tpu import USE_TPU_INFERENCE
from vllm.tasks import SupportedTask
from vllm.utils.math_utils import cdiv
from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import AttentionSpec, KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.utils import report_usage_stats
@ -255,13 +254,13 @@ class TPUWorker:
tpu_kv_cache_bytes = tpu_kv_cache_bytes * head_size // padded_head_size
return int(tpu_kv_cache_bytes)
def sample_tokens(self, grammar_output: "GrammarOutput") -> ModelRunnerOutput:
return self.model_runner.sample_tokens(grammar_output)
def execute_model(
self,
scheduler_output: "SchedulerOutput",
self, scheduler_output: "SchedulerOutput"
) -> ModelRunnerOutput | None:
output = self.model_runner.execute_model(scheduler_output)
# every worker's output is needed when kv_transfer_group is set up
return output if self.is_driver_worker or has_kv_transfer_group() else None
return self.model_runner.execute_model(scheduler_output)
def profile(self, is_start: bool = True):
if self.rank < 1:

View File

@ -20,10 +20,12 @@ from vllm.v1.kv_cache_interface import KVCacheSpec
from vllm.v1.serial_utils import run_method
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.outputs import AsyncModelRunnerOutput, ModelRunnerOutput
else:
SchedulerOutput = object
GrammarOutput = object
AsyncModelRunnerOutput = object
ModelRunnerOutput = object
logger = init_logger(__name__)
@ -122,7 +124,21 @@ class WorkerBase:
"""Load model onto target device."""
raise NotImplementedError
def execute_model(self, scheduler_output: SchedulerOutput) -> ModelRunnerOutput:
def execute_model(
self, scheduler_output: SchedulerOutput
) -> ModelRunnerOutput | None:
"""If this method returns None, sample_tokens should be called immediately after
to obtain the ModelRunnerOutput.
Note that this design may be changed in future if/when structured outputs
parallelism is re-architected.
"""
raise NotImplementedError
def sample_tokens(
self, grammar_output: GrammarOutput
) -> ModelRunnerOutput | AsyncModelRunnerOutput:
"""Should be called immediately after execute_model iff it returned None."""
raise NotImplementedError
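A hedged caller-side sketch of the contract these docstrings describe (stub worker, toy types): execute_model() returning None signals that the forward pass has been kicked off and sample_tokens() must follow, now able to take the grammar bitmask into account.

```python
class StubWorker:
    """Toy stand-in for a worker implementing the two-call protocol."""

    def __init__(self) -> None:
        self._pending: dict | None = None

    def execute_model(self, scheduler_output: dict) -> dict | None:
        self._pending = scheduler_output   # forward pass kicked off
        return None                        # sampling deferred to sample_tokens()

    def sample_tokens(self, grammar_output: dict | None) -> dict:
        scheduler_output, self._pending = self._pending, None
        return {"req_ids": list(scheduler_output), "masked": grammar_output is not None}

worker = StubWorker()
output = worker.execute_model({"req-0": 16})
if output is None:
    output = worker.sample_tokens(grammar_output=None)
print(output)  # {'req_ids': ['req-0'], 'masked': False}
```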
def get_cache_block_size_bytes(self) -> int:
@ -344,7 +360,7 @@ class WorkerWrapperBase:
scheduler_output: SchedulerOutput,
*args,
**kwargs,
) -> ModelRunnerOutput:
) -> ModelRunnerOutput | None:
self._apply_mm_cache(scheduler_output)
assert self.worker is not None