[Core] Streamline some structured output related code (#26737)

Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nick Hill 2025-10-14 16:27:44 -07:00 committed by GitHub
parent a86b4c58e8
commit 4aed506b65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 121 additions and 138 deletions

View File

@ -30,7 +30,6 @@ from vllm.v1.kv_cache_interface import (
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.request import StructuredOutputRequest
from .utils import EOS_TOKEN_ID, create_requests, create_scheduler
@ -335,10 +334,10 @@ def test_stop_via_update_from_output():
requests[0].request_id: [],
requests[1].request_id: [10],
},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -383,10 +382,10 @@ def test_stop_via_update_from_output():
requests[0].request_id: [10, 42],
requests[1].request_id: [13],
},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -429,10 +428,10 @@ def test_stop_via_update_from_output():
requests[0].request_id: [10, 11],
requests[1].request_id: [],
},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -470,10 +469,10 @@ def test_stop_via_update_from_output():
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -1941,7 +1940,6 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
structured_output_request=StructuredOutputRequest(sampling_params),
)
scheduler.add_request(request)
output = scheduler.schedule()

View File

@ -26,7 +26,7 @@ def _make_empty_scheduler_output():
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
kv_connector_metadata=SharedStorageConnectorMetadata(),
)

View File

@ -89,10 +89,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -168,10 +168,10 @@ def test_update_states_request_finished(model_runner):
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids={req_id},
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -198,10 +198,10 @@ def test_update_states_request_resumed(model_runner):
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -225,10 +225,10 @@ def test_update_states_request_resumed(model_runner):
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -256,10 +256,10 @@ def test_update_states_no_changes(model_runner):
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -291,10 +291,10 @@ def test_update_states_request_unscheduled(model_runner):
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

View File

@ -146,10 +146,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
total_num_scheduled_tokens=total_num_scheduled_tokens,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -212,10 +212,10 @@ def test_update_states_request_finished(model_runner, dist_init):
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids={req_id},
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -244,10 +244,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
total_num_scheduled_tokens=0,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -273,10 +273,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -366,10 +366,10 @@ def test_update_states_no_changes(model_runner, dist_init):
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)
@ -403,10 +403,10 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
total_num_scheduled_tokens=1,
scheduled_spec_decode_tokens={},
scheduled_encoder_inputs={},
num_common_prefix_blocks=0,
num_common_prefix_blocks=[],
finished_req_ids=set(),
free_encoder_mm_hashes=[],
structured_output_request_ids={},
structured_output_request_ids=[],
grammar_bitmask=None,
)

View File

@ -165,9 +165,8 @@ class SchedulerOutput:
# freed from the encoder cache.
free_encoder_mm_hashes: list[str]
# Dict of request ids to their index within the batch
# for filling the next token bitmask
structured_output_request_ids: dict[str, int]
# ids of structured outputs requests included in the bitmask, in order.
structured_output_request_ids: list[str]
# the bitmask for the whole batch
grammar_bitmask: "npt.NDArray[np.int32] | None"

View File

@ -5,7 +5,7 @@ import itertools
import time
from collections import defaultdict
from collections.abc import Iterable
from typing import Any
from typing import TYPE_CHECKING, Any
from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
@ -34,6 +34,10 @@ from vllm.v1.request import Request, RequestStatus
from vllm.v1.spec_decode.metrics import SpecDecodingStats
from vllm.v1.structured_output import StructuredOutputManager
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
logger = init_logger(__name__)
@ -608,11 +612,8 @@ class Scheduler(SchedulerInterface):
scheduled_spec_decode_tokens,
req_to_new_blocks,
)
scheduled_requests = (
scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs
)
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
scheduled_requests, scheduled_spec_decode_tokens
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
)
scheduler_output = SchedulerOutput(
scheduled_new_reqs=new_reqs_data,
@ -876,32 +877,28 @@ class Scheduler(SchedulerInterface):
def get_grammar_bitmask(
self,
requests: list[Request],
scheduled_request_ids: Iterable[str],
scheduled_spec_decode_tokens: dict[str, list[int]],
):
# NOTE: structured_output_request_ids maps
# a request's (request that uses structured output)
# request_id to its index in the batch.
# This will help us determine to slice the grammar bitmask
# and only applies valid mask for requests that
# uses structured decoding.
structured_output_request_ids: dict[str, int] = {}
for i, req in enumerate(requests):
if req.use_structured_output:
# PERF: in case of chunked prefill,
# request might not include any new tokens.
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids[req.request_id] = i
) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
# Collect list of scheduled request ids that use structured output.
# The corresponding rows of the bitmask will be in this order.
# PERF: in case of chunked prefill,
# request might not include any new tokens.
# Therefore, we might introduce some additional
# cycle to fill in the bitmask, which could be a big no-op.
structured_output_request_ids = [
req_id
for req_id in scheduled_request_ids
if (req := self.requests.get(req_id)) and req.use_structured_output
]
if not structured_output_request_ids:
bitmask = None
else:
bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
scheduled_spec_decode_tokens,
)
return structured_output_request_ids, None
bitmask = self.structured_output_manager.grammar_bitmask(
self.requests,
structured_output_request_ids,
scheduled_spec_decode_tokens,
)
return structured_output_request_ids, bitmask
def update_from_output(
@ -1013,12 +1010,10 @@ class Scheduler(SchedulerInterface):
new_logprobs = logprobs.slice(req_index, req_index + 1)
if new_token_ids and self.structured_output_manager.should_advance(request):
# NOTE: structured_output_request
# should not be None if use_structured_output, we have
# checked above, so safe to ignore type warning
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
req_id, new_token_ids
)
struct_output_request = request.structured_output_request
assert struct_output_request is not None
assert struct_output_request.grammar is not None
struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
request.num_nans_in_logits = num_nans_in_logits[req_id]

View File

@ -40,7 +40,6 @@ class Request:
prompt_embeds: torch.Tensor | None = None,
mm_features: list[MultiModalFeatureSpec] | None = None,
lora_request: Optional["LoRARequest"] = None,
structured_output_request: Optional["StructuredOutputRequest"] = None,
cache_salt: str | None = None,
priority: int = 0,
trace_headers: Mapping[str, str] | None = None,
@ -54,11 +53,12 @@ class Request:
# Because of LoRA, the eos token id can be different for each request.
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.structured_output_request = structured_output_request
self.structured_output_request = StructuredOutputRequest.from_sampling_params(
sampling_params
)
self.arrival_time = arrival_time if arrival_time is not None else time.time()
self.status = RequestStatus.WAITING
self.use_structured_output = False
self.events: list[EngineCoreEvent] = []
self.stop_reason: int | str | None = None
@ -72,9 +72,8 @@ class Request:
# Generative models.
assert sampling_params.max_tokens is not None
self.max_tokens = sampling_params.max_tokens
if sampling_params.structured_outputs is not None:
if self.structured_output_request is not None:
self.status = RequestStatus.WAITING_FOR_FSM
self.use_structured_output = True
if sampling_params.extra_args is not None:
self.kv_transfer_params = sampling_params.extra_args.get(
@ -145,11 +144,6 @@ class Request:
eos_token_id=request.eos_token_id,
arrival_time=request.arrival_time,
lora_request=request.lora_request,
structured_output_request=StructuredOutputRequest(
sampling_params=request.sampling_params
)
if request.sampling_params
else None,
cache_salt=request.cache_salt,
priority=request.priority,
trace_headers=request.trace_headers,
@ -170,6 +164,10 @@ class Request:
if self.get_hash_new_full_blocks is not None:
self.block_hashes.extend(self.get_hash_new_full_blocks())
@property
def use_structured_output(self) -> bool:
return self.structured_output_request is not None
@property
def is_output_corrupted(self) -> bool:
return self.num_nans_in_logits > 0

View File

@ -167,7 +167,7 @@ class StructuredOutputManager:
def grammar_bitmask(
self,
requests: dict[str, Request],
structured_output_request_ids: dict[str, int],
structured_output_request_ids: list[str],
scheduled_spec_decode_tokens: dict[str, list[int]],
) -> "npt.NDArray[np.int32] | None":
# Prepare the structured output bitmask for this batch.
@ -196,17 +196,16 @@ class StructuredOutputManager:
# masks for each request, one for each possible bonus token position.
# These are stored inline in the tensor and unpacked by the gpu runner.
cumulative_index = 0
ordered_seq = sorted(structured_output_request_ids.items(), key=lambda x: x[1])
# Optimized parallel filling of bitmasks for
# non-spec, large-batch-size cases
if (
len(ordered_seq) > self.fill_bitmask_parallel_threshold
len(structured_output_request_ids) > self.fill_bitmask_parallel_threshold
and max_num_spec_tokens == 0
):
promises = []
batch = []
for req_id, _ in ordered_seq:
for req_id in structured_output_request_ids:
request = requests[req_id]
structured_output_request = request.structured_output_request
if TYPE_CHECKING:
@ -230,7 +229,7 @@ class StructuredOutputManager:
promise.result()
else:
# Fallback to serial filling of bitmasks for small-batch-size cases
for req_id, _ in ordered_seq:
for req_id in structured_output_request_ids:
request = requests[req_id]
structured_output_request = request.structured_output_request
@ -295,22 +294,21 @@ class StructuredOutputManager:
assert request.structured_output_request.grammar is not None
# by default, we should always advance
# for cases that don't use thinking mode.
if self.reasoner is not None:
structured_req = request.structured_output_request
if structured_req.reasoning_ended:
return True
# Check if reasoning ends in *this* step
if self.reasoner.is_reasoning_end(request.all_token_ids):
# Reasoning just ended, so we shouldn't advance til
# next pass
structured_req.reasoning_ended = True
return False
else:
if self.reasoner is None:
return True
structured_req = request.structured_output_request
if structured_req.reasoning_ended:
return True
# Check if reasoning ends in *this* step
if self.reasoner.is_reasoning_end(request.all_token_ids):
# Reasoning just ended, so we shouldn't advance til
# next pass
structured_req.reasoning_ended = True
return False
def clear_backend(self) -> None:
if self.backend is not None:
self.backend.destroy()

View File

@ -252,7 +252,7 @@ def serialize_guidance_grammar(
def validate_guidance_grammar(
sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
) -> None:
tp, grm = get_structured_output_key(sampling_params)
tp, grm = get_structured_output_key(sampling_params.structured_outputs)
guidance_grm = serialize_guidance_grammar(tp, grm)
err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)
if err:

View File

@ -7,7 +7,7 @@ from concurrent.futures import Future
from concurrent.futures._base import TimeoutError
from typing import cast
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
from vllm.v1.structured_output.backend_types import (
StructuredOutputGrammar,
StructuredOutputKey,
@ -17,10 +17,19 @@ from vllm.v1.structured_output.backend_types import (
@dataclasses.dataclass
class StructuredOutputRequest:
sampling_params: SamplingParams
params: StructuredOutputsParams
_grammar: Future[StructuredOutputGrammar] | StructuredOutputGrammar | None = None
reasoning_ended: bool | None = None
@staticmethod
def from_sampling_params(
sampling_params: SamplingParams | None,
) -> "StructuredOutputRequest | None":
if sampling_params is None:
return None
params = sampling_params.structured_outputs
return StructuredOutputRequest(params=params) if params else None
def _check_grammar_completion(self) -> bool:
# NOTE: We have to lazy import to gate circular imports
from vllm.v1.request import RequestStatus
@ -53,31 +62,28 @@ class StructuredOutputRequest:
@functools.cached_property
def structured_output_key(self) -> StructuredOutputKey:
return get_structured_output_key(self.sampling_params)
return get_structured_output_key(self.params)
def get_structured_output_key(sampling_params: SamplingParams) -> StructuredOutputKey:
params = sampling_params.structured_outputs
assert params is not None, "params can't be None."
def get_structured_output_key(params: StructuredOutputsParams) -> StructuredOutputKey:
if params.json is not None:
if not isinstance(params.json, str):
json_str = json.dumps(params.json)
else:
json_str = params.json
return (StructuredOutputOptions.JSON, json_str)
elif params.json_object:
return (StructuredOutputOptions.JSON_OBJECT, "")
elif params.regex is not None:
return (StructuredOutputOptions.REGEX, params.regex)
elif params.choice is not None:
return StructuredOutputOptions.JSON, json_str
if params.json_object:
return StructuredOutputOptions.JSON_OBJECT, ""
if params.regex is not None:
return StructuredOutputOptions.REGEX, params.regex
if params.choice is not None:
if not isinstance(params.choice, str):
json_str = json.dumps(params.choice)
else:
json_str = params.choice
return (StructuredOutputOptions.CHOICE, json_str)
elif params.grammar is not None:
return (StructuredOutputOptions.GRAMMAR, params.grammar)
elif params.structural_tag is not None:
return (StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag)
else:
raise ValueError("No valid structured output parameter found")
return StructuredOutputOptions.CHOICE, json_str
if params.grammar is not None:
return StructuredOutputOptions.GRAMMAR, params.grammar
if params.structural_tag is not None:
return StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag
raise ValueError("No valid structured output parameter found")

View File

@ -47,7 +47,6 @@ def apply_grammar_bitmask(
scheduler_output: SchedulerOutput,
input_batch: InputBatch,
logits: torch.Tensor,
device: torch.device,
) -> None:
"""
Apply grammar bitmask to output logits of the model with xgrammar function.
@ -56,7 +55,6 @@ def apply_grammar_bitmask(
scheduler_output (SchedulerOutput): The result of engine scheduling.
input_batch (InputBatch): The input of model runner.
logits (torch.Tensor): The output logits of model forward.
device (torch.device): The device that model runner running on.
"""
grammar_bitmask = scheduler_output.grammar_bitmask
if grammar_bitmask is None:
@ -91,10 +89,7 @@ def apply_grammar_bitmask(
dtype=grammar_bitmask.dtype,
)
cumulative_index = 0
seq = sorted(
scheduler_output.structured_output_request_ids.items(), key=lambda x: x[1]
)
for req_id, _ in seq:
for req_id in scheduler_output.structured_output_request_ids:
num_spec_tokens = len(
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
)
@ -117,7 +112,7 @@ def apply_grammar_bitmask(
xgr.apply_token_bitmask_inplace(
logits,
grammar_bitmask.to(device, non_blocking=True),
grammar_bitmask.to(logits.device, non_blocking=True),
indices=out_indices if not skip_out_indices else None,
)

View File

@ -2568,10 +2568,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
logits = model_output_broadcast_data["logits"]
# Apply structured output bitmasks if present
if scheduler_output.grammar_bitmask is not None:
apply_grammar_bitmask(
scheduler_output, self.input_batch, logits, self.device
)
if scheduler_output.structured_output_request_ids:
apply_grammar_bitmask(scheduler_output, self.input_batch, logits)
with record_function_or_nullcontext("Sample"):
sampler_output = self._sample(logits, spec_decode_metadata)

View File

@ -1963,12 +1963,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
self.grammar_bitmask_cpu.zero_()
self.require_structured_out_cpu.zero_()
sorted_struct_requests = sorted(
scheduler_output.structured_output_request_ids.items(),
key=lambda item: item[1],
)
cumulative_mask_idx = 0
for req_id, _ in sorted_struct_requests:
for req_id in scheduler_output.structured_output_request_ids:
if req_id not in self.input_batch.req_id_to_index:
continue
batch_index = self.input_batch.req_id_to_index[req_id]