mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-26 12:24:28 +08:00
[Core] Streamline some structured output related code (#26737)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
parent
a86b4c58e8
commit
4aed506b65
@ -30,7 +30,6 @@ from vllm.v1.kv_cache_interface import (
|
||||
from vllm.v1.outputs import DraftTokenIds, ModelRunnerOutput
|
||||
from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
from vllm.v1.structured_output.request import StructuredOutputRequest
|
||||
|
||||
from .utils import EOS_TOKEN_ID, create_requests, create_scheduler
|
||||
|
||||
@ -335,10 +334,10 @@ def test_stop_via_update_from_output():
|
||||
requests[0].request_id: [],
|
||||
requests[1].request_id: [10],
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
@ -383,10 +382,10 @@ def test_stop_via_update_from_output():
|
||||
requests[0].request_id: [10, 42],
|
||||
requests[1].request_id: [13],
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
@ -429,10 +428,10 @@ def test_stop_via_update_from_output():
|
||||
requests[0].request_id: [10, 11],
|
||||
requests[1].request_id: [],
|
||||
},
|
||||
num_common_prefix_blocks=0,
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
@ -470,10 +469,10 @@ def test_stop_via_update_from_output():
|
||||
total_num_scheduled_tokens=3,
|
||||
scheduled_encoder_inputs={},
|
||||
scheduled_spec_decode_tokens={requests[0].request_id: [EOS_TOKEN_ID, 10]},
|
||||
num_common_prefix_blocks=0,
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
@ -1941,7 +1940,6 @@ def test_schedule_skip_tokenizer_init_structured_output_request():
|
||||
sampling_params=sampling_params,
|
||||
pooling_params=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
structured_output_request=StructuredOutputRequest(sampling_params),
|
||||
)
|
||||
scheduler.add_request(request)
|
||||
output = scheduler.schedule()
|
||||
|
||||
@ -26,7 +26,7 @@ def _make_empty_scheduler_output():
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
kv_connector_metadata=SharedStorageConnectorMetadata(),
|
||||
)
|
||||
|
||||
@ -89,10 +89,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
@ -168,10 +168,10 @@ def test_update_states_request_finished(model_runner):
|
||||
total_num_scheduled_tokens=0,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids={req_id},
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
@ -198,10 +198,10 @@ def test_update_states_request_resumed(model_runner):
|
||||
total_num_scheduled_tokens=0,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
@ -225,10 +225,10 @@ def test_update_states_request_resumed(model_runner):
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
@ -256,10 +256,10 @@ def test_update_states_no_changes(model_runner):
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
@ -291,10 +291,10 @@ def test_update_states_request_unscheduled(model_runner):
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
|
||||
@ -146,10 +146,10 @@ def _schedule_new_request(*req_ids: str) -> SchedulerOutput:
|
||||
total_num_scheduled_tokens=total_num_scheduled_tokens,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
@ -212,10 +212,10 @@ def test_update_states_request_finished(model_runner, dist_init):
|
||||
total_num_scheduled_tokens=0,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids={req_id},
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
@ -244,10 +244,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
|
||||
total_num_scheduled_tokens=0,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
@ -273,10 +273,10 @@ def test_update_states_request_resumed(model_runner, dist_init):
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
@ -366,10 +366,10 @@ def test_update_states_no_changes(model_runner, dist_init):
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
@ -403,10 +403,10 @@ def test_update_states_request_unscheduled(model_runner, dist_init):
|
||||
total_num_scheduled_tokens=1,
|
||||
scheduled_spec_decode_tokens={},
|
||||
scheduled_encoder_inputs={},
|
||||
num_common_prefix_blocks=0,
|
||||
num_common_prefix_blocks=[],
|
||||
finished_req_ids=set(),
|
||||
free_encoder_mm_hashes=[],
|
||||
structured_output_request_ids={},
|
||||
structured_output_request_ids=[],
|
||||
grammar_bitmask=None,
|
||||
)
|
||||
|
||||
|
||||
@ -165,9 +165,8 @@ class SchedulerOutput:
|
||||
# freed from the encoder cache.
|
||||
free_encoder_mm_hashes: list[str]
|
||||
|
||||
# Dict of request ids to their index within the batch
|
||||
# for filling the next token bitmask
|
||||
structured_output_request_ids: dict[str, int]
|
||||
# ids of structured outputs requests included in the bitmask, in order.
|
||||
structured_output_request_ids: list[str]
|
||||
# the bitmask for the whole batch
|
||||
grammar_bitmask: "npt.NDArray[np.int32] | None"
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ import itertools
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
|
||||
@ -34,6 +34,10 @@ from vllm.v1.request import Request, RequestStatus
|
||||
from vllm.v1.spec_decode.metrics import SpecDecodingStats
|
||||
from vllm.v1.structured_output import StructuredOutputManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -608,11 +612,8 @@ class Scheduler(SchedulerInterface):
|
||||
scheduled_spec_decode_tokens,
|
||||
req_to_new_blocks,
|
||||
)
|
||||
scheduled_requests = (
|
||||
scheduled_new_reqs + scheduled_running_reqs + scheduled_resumed_reqs
|
||||
)
|
||||
structured_output_request_ids, grammar_bitmask = self.get_grammar_bitmask(
|
||||
scheduled_requests, scheduled_spec_decode_tokens
|
||||
num_scheduled_tokens.keys(), scheduled_spec_decode_tokens
|
||||
)
|
||||
scheduler_output = SchedulerOutput(
|
||||
scheduled_new_reqs=new_reqs_data,
|
||||
@ -876,32 +877,28 @@ class Scheduler(SchedulerInterface):
|
||||
|
||||
def get_grammar_bitmask(
|
||||
self,
|
||||
requests: list[Request],
|
||||
scheduled_request_ids: Iterable[str],
|
||||
scheduled_spec_decode_tokens: dict[str, list[int]],
|
||||
):
|
||||
# NOTE: structured_output_request_ids maps
|
||||
# a request's (request that uses structured output)
|
||||
# request_id to its index in the batch.
|
||||
# This will help us determine to slice the grammar bitmask
|
||||
# and only applies valid mask for requests that
|
||||
# uses structured decoding.
|
||||
structured_output_request_ids: dict[str, int] = {}
|
||||
for i, req in enumerate(requests):
|
||||
if req.use_structured_output:
|
||||
# PERF: in case of chunked prefill,
|
||||
# request might not include any new tokens.
|
||||
# Therefore, we might introduce some additional
|
||||
# cycle to fill in the bitmask, which could be a big no-op.
|
||||
structured_output_request_ids[req.request_id] = i
|
||||
|
||||
) -> tuple[list[str], "npt.NDArray[np.int32] | None"]:
|
||||
# Collect list of scheduled request ids that use structured output.
|
||||
# The corresponding rows of the bitmask will be in this order.
|
||||
# PERF: in case of chunked prefill,
|
||||
# request might not include any new tokens.
|
||||
# Therefore, we might introduce some additional
|
||||
# cycle to fill in the bitmask, which could be a big no-op.
|
||||
structured_output_request_ids = [
|
||||
req_id
|
||||
for req_id in scheduled_request_ids
|
||||
if (req := self.requests.get(req_id)) and req.use_structured_output
|
||||
]
|
||||
if not structured_output_request_ids:
|
||||
bitmask = None
|
||||
else:
|
||||
bitmask = self.structured_output_manager.grammar_bitmask(
|
||||
self.requests,
|
||||
structured_output_request_ids,
|
||||
scheduled_spec_decode_tokens,
|
||||
)
|
||||
return structured_output_request_ids, None
|
||||
|
||||
bitmask = self.structured_output_manager.grammar_bitmask(
|
||||
self.requests,
|
||||
structured_output_request_ids,
|
||||
scheduled_spec_decode_tokens,
|
||||
)
|
||||
return structured_output_request_ids, bitmask
|
||||
|
||||
def update_from_output(
|
||||
@ -1013,12 +1010,10 @@ class Scheduler(SchedulerInterface):
|
||||
new_logprobs = logprobs.slice(req_index, req_index + 1)
|
||||
|
||||
if new_token_ids and self.structured_output_manager.should_advance(request):
|
||||
# NOTE: structured_output_request
|
||||
# should not be None if use_structured_output, we have
|
||||
# checked above, so safe to ignore type warning
|
||||
request.structured_output_request.grammar.accept_tokens( # type: ignore[union-attr]
|
||||
req_id, new_token_ids
|
||||
)
|
||||
struct_output_request = request.structured_output_request
|
||||
assert struct_output_request is not None
|
||||
assert struct_output_request.grammar is not None
|
||||
struct_output_request.grammar.accept_tokens(req_id, new_token_ids)
|
||||
|
||||
if num_nans_in_logits is not None and req_id in num_nans_in_logits:
|
||||
request.num_nans_in_logits = num_nans_in_logits[req_id]
|
||||
|
||||
@ -40,7 +40,6 @@ class Request:
|
||||
prompt_embeds: torch.Tensor | None = None,
|
||||
mm_features: list[MultiModalFeatureSpec] | None = None,
|
||||
lora_request: Optional["LoRARequest"] = None,
|
||||
structured_output_request: Optional["StructuredOutputRequest"] = None,
|
||||
cache_salt: str | None = None,
|
||||
priority: int = 0,
|
||||
trace_headers: Mapping[str, str] | None = None,
|
||||
@ -54,11 +53,12 @@ class Request:
|
||||
# Because of LoRA, the eos token id can be different for each request.
|
||||
self.eos_token_id = eos_token_id
|
||||
self.lora_request = lora_request
|
||||
self.structured_output_request = structured_output_request
|
||||
self.structured_output_request = StructuredOutputRequest.from_sampling_params(
|
||||
sampling_params
|
||||
)
|
||||
self.arrival_time = arrival_time if arrival_time is not None else time.time()
|
||||
|
||||
self.status = RequestStatus.WAITING
|
||||
self.use_structured_output = False
|
||||
self.events: list[EngineCoreEvent] = []
|
||||
self.stop_reason: int | str | None = None
|
||||
|
||||
@ -72,9 +72,8 @@ class Request:
|
||||
# Generative models.
|
||||
assert sampling_params.max_tokens is not None
|
||||
self.max_tokens = sampling_params.max_tokens
|
||||
if sampling_params.structured_outputs is not None:
|
||||
if self.structured_output_request is not None:
|
||||
self.status = RequestStatus.WAITING_FOR_FSM
|
||||
self.use_structured_output = True
|
||||
|
||||
if sampling_params.extra_args is not None:
|
||||
self.kv_transfer_params = sampling_params.extra_args.get(
|
||||
@ -145,11 +144,6 @@ class Request:
|
||||
eos_token_id=request.eos_token_id,
|
||||
arrival_time=request.arrival_time,
|
||||
lora_request=request.lora_request,
|
||||
structured_output_request=StructuredOutputRequest(
|
||||
sampling_params=request.sampling_params
|
||||
)
|
||||
if request.sampling_params
|
||||
else None,
|
||||
cache_salt=request.cache_salt,
|
||||
priority=request.priority,
|
||||
trace_headers=request.trace_headers,
|
||||
@ -170,6 +164,10 @@ class Request:
|
||||
if self.get_hash_new_full_blocks is not None:
|
||||
self.block_hashes.extend(self.get_hash_new_full_blocks())
|
||||
|
||||
@property
|
||||
def use_structured_output(self) -> bool:
|
||||
return self.structured_output_request is not None
|
||||
|
||||
@property
|
||||
def is_output_corrupted(self) -> bool:
|
||||
return self.num_nans_in_logits > 0
|
||||
|
||||
@ -167,7 +167,7 @@ class StructuredOutputManager:
|
||||
def grammar_bitmask(
|
||||
self,
|
||||
requests: dict[str, Request],
|
||||
structured_output_request_ids: dict[str, int],
|
||||
structured_output_request_ids: list[str],
|
||||
scheduled_spec_decode_tokens: dict[str, list[int]],
|
||||
) -> "npt.NDArray[np.int32] | None":
|
||||
# Prepare the structured output bitmask for this batch.
|
||||
@ -196,17 +196,16 @@ class StructuredOutputManager:
|
||||
# masks for each request, one for each possible bonus token position.
|
||||
# These are stored inline in the tensor and unpacked by the gpu runner.
|
||||
cumulative_index = 0
|
||||
ordered_seq = sorted(structured_output_request_ids.items(), key=lambda x: x[1])
|
||||
|
||||
# Optimized parallel filling of bitmasks for
|
||||
# non-spec, large-batch-size cases
|
||||
if (
|
||||
len(ordered_seq) > self.fill_bitmask_parallel_threshold
|
||||
len(structured_output_request_ids) > self.fill_bitmask_parallel_threshold
|
||||
and max_num_spec_tokens == 0
|
||||
):
|
||||
promises = []
|
||||
batch = []
|
||||
for req_id, _ in ordered_seq:
|
||||
for req_id in structured_output_request_ids:
|
||||
request = requests[req_id]
|
||||
structured_output_request = request.structured_output_request
|
||||
if TYPE_CHECKING:
|
||||
@ -230,7 +229,7 @@ class StructuredOutputManager:
|
||||
promise.result()
|
||||
else:
|
||||
# Fallback to serial filling of bitmasks for small-batch-size cases
|
||||
for req_id, _ in ordered_seq:
|
||||
for req_id in structured_output_request_ids:
|
||||
request = requests[req_id]
|
||||
structured_output_request = request.structured_output_request
|
||||
|
||||
@ -295,22 +294,21 @@ class StructuredOutputManager:
|
||||
assert request.structured_output_request.grammar is not None
|
||||
# by default, we should always advance
|
||||
# for cases that don't use thinking mode.
|
||||
if self.reasoner is not None:
|
||||
structured_req = request.structured_output_request
|
||||
|
||||
if structured_req.reasoning_ended:
|
||||
return True
|
||||
|
||||
# Check if reasoning ends in *this* step
|
||||
if self.reasoner.is_reasoning_end(request.all_token_ids):
|
||||
# Reasoning just ended, so we shouldn't advance til
|
||||
# next pass
|
||||
structured_req.reasoning_ended = True
|
||||
|
||||
return False
|
||||
else:
|
||||
if self.reasoner is None:
|
||||
return True
|
||||
|
||||
structured_req = request.structured_output_request
|
||||
if structured_req.reasoning_ended:
|
||||
return True
|
||||
|
||||
# Check if reasoning ends in *this* step
|
||||
if self.reasoner.is_reasoning_end(request.all_token_ids):
|
||||
# Reasoning just ended, so we shouldn't advance til
|
||||
# next pass
|
||||
structured_req.reasoning_ended = True
|
||||
|
||||
return False
|
||||
|
||||
def clear_backend(self) -> None:
|
||||
if self.backend is not None:
|
||||
self.backend.destroy()
|
||||
|
||||
@ -252,7 +252,7 @@ def serialize_guidance_grammar(
|
||||
def validate_guidance_grammar(
|
||||
sampling_params: SamplingParams, tokenizer: llguidance.LLTokenizer | None = None
|
||||
) -> None:
|
||||
tp, grm = get_structured_output_key(sampling_params)
|
||||
tp, grm = get_structured_output_key(sampling_params.structured_outputs)
|
||||
guidance_grm = serialize_guidance_grammar(tp, grm)
|
||||
err = llguidance.LLMatcher.validate_grammar(guidance_grm, tokenizer)
|
||||
if err:
|
||||
|
||||
@ -7,7 +7,7 @@ from concurrent.futures import Future
|
||||
from concurrent.futures._base import TimeoutError
|
||||
from typing import cast
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||
from vllm.v1.structured_output.backend_types import (
|
||||
StructuredOutputGrammar,
|
||||
StructuredOutputKey,
|
||||
@ -17,10 +17,19 @@ from vllm.v1.structured_output.backend_types import (
|
||||
|
||||
@dataclasses.dataclass
|
||||
class StructuredOutputRequest:
|
||||
sampling_params: SamplingParams
|
||||
params: StructuredOutputsParams
|
||||
_grammar: Future[StructuredOutputGrammar] | StructuredOutputGrammar | None = None
|
||||
reasoning_ended: bool | None = None
|
||||
|
||||
@staticmethod
|
||||
def from_sampling_params(
|
||||
sampling_params: SamplingParams | None,
|
||||
) -> "StructuredOutputRequest | None":
|
||||
if sampling_params is None:
|
||||
return None
|
||||
params = sampling_params.structured_outputs
|
||||
return StructuredOutputRequest(params=params) if params else None
|
||||
|
||||
def _check_grammar_completion(self) -> bool:
|
||||
# NOTE: We have to lazy import to gate circular imports
|
||||
from vllm.v1.request import RequestStatus
|
||||
@ -53,31 +62,28 @@ class StructuredOutputRequest:
|
||||
|
||||
@functools.cached_property
|
||||
def structured_output_key(self) -> StructuredOutputKey:
|
||||
return get_structured_output_key(self.sampling_params)
|
||||
return get_structured_output_key(self.params)
|
||||
|
||||
|
||||
def get_structured_output_key(sampling_params: SamplingParams) -> StructuredOutputKey:
|
||||
params = sampling_params.structured_outputs
|
||||
assert params is not None, "params can't be None."
|
||||
def get_structured_output_key(params: StructuredOutputsParams) -> StructuredOutputKey:
|
||||
if params.json is not None:
|
||||
if not isinstance(params.json, str):
|
||||
json_str = json.dumps(params.json)
|
||||
else:
|
||||
json_str = params.json
|
||||
return (StructuredOutputOptions.JSON, json_str)
|
||||
elif params.json_object:
|
||||
return (StructuredOutputOptions.JSON_OBJECT, "")
|
||||
elif params.regex is not None:
|
||||
return (StructuredOutputOptions.REGEX, params.regex)
|
||||
elif params.choice is not None:
|
||||
return StructuredOutputOptions.JSON, json_str
|
||||
if params.json_object:
|
||||
return StructuredOutputOptions.JSON_OBJECT, ""
|
||||
if params.regex is not None:
|
||||
return StructuredOutputOptions.REGEX, params.regex
|
||||
if params.choice is not None:
|
||||
if not isinstance(params.choice, str):
|
||||
json_str = json.dumps(params.choice)
|
||||
else:
|
||||
json_str = params.choice
|
||||
return (StructuredOutputOptions.CHOICE, json_str)
|
||||
elif params.grammar is not None:
|
||||
return (StructuredOutputOptions.GRAMMAR, params.grammar)
|
||||
elif params.structural_tag is not None:
|
||||
return (StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag)
|
||||
else:
|
||||
raise ValueError("No valid structured output parameter found")
|
||||
return StructuredOutputOptions.CHOICE, json_str
|
||||
if params.grammar is not None:
|
||||
return StructuredOutputOptions.GRAMMAR, params.grammar
|
||||
if params.structural_tag is not None:
|
||||
return StructuredOutputOptions.STRUCTURAL_TAG, params.structural_tag
|
||||
raise ValueError("No valid structured output parameter found")
|
||||
|
||||
@ -47,7 +47,6 @@ def apply_grammar_bitmask(
|
||||
scheduler_output: SchedulerOutput,
|
||||
input_batch: InputBatch,
|
||||
logits: torch.Tensor,
|
||||
device: torch.device,
|
||||
) -> None:
|
||||
"""
|
||||
Apply grammar bitmask to output logits of the model with xgrammar function.
|
||||
@ -56,7 +55,6 @@ def apply_grammar_bitmask(
|
||||
scheduler_output (SchedulerOutput): The result of engine scheduling.
|
||||
input_batch (InputBatch): The input of model runner.
|
||||
logits (torch.Tensor): The output logits of model forward.
|
||||
device (torch.device): The device that model runner running on.
|
||||
"""
|
||||
grammar_bitmask = scheduler_output.grammar_bitmask
|
||||
if grammar_bitmask is None:
|
||||
@ -91,10 +89,7 @@ def apply_grammar_bitmask(
|
||||
dtype=grammar_bitmask.dtype,
|
||||
)
|
||||
cumulative_index = 0
|
||||
seq = sorted(
|
||||
scheduler_output.structured_output_request_ids.items(), key=lambda x: x[1]
|
||||
)
|
||||
for req_id, _ in seq:
|
||||
for req_id in scheduler_output.structured_output_request_ids:
|
||||
num_spec_tokens = len(
|
||||
scheduler_output.scheduled_spec_decode_tokens.get(req_id, [])
|
||||
)
|
||||
@ -117,7 +112,7 @@ def apply_grammar_bitmask(
|
||||
|
||||
xgr.apply_token_bitmask_inplace(
|
||||
logits,
|
||||
grammar_bitmask.to(device, non_blocking=True),
|
||||
grammar_bitmask.to(logits.device, non_blocking=True),
|
||||
indices=out_indices if not skip_out_indices else None,
|
||||
)
|
||||
|
||||
|
||||
@ -2568,10 +2568,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
logits = model_output_broadcast_data["logits"]
|
||||
|
||||
# Apply structured output bitmasks if present
|
||||
if scheduler_output.grammar_bitmask is not None:
|
||||
apply_grammar_bitmask(
|
||||
scheduler_output, self.input_batch, logits, self.device
|
||||
)
|
||||
if scheduler_output.structured_output_request_ids:
|
||||
apply_grammar_bitmask(scheduler_output, self.input_batch, logits)
|
||||
|
||||
with record_function_or_nullcontext("Sample"):
|
||||
sampler_output = self._sample(logits, spec_decode_metadata)
|
||||
|
||||
@ -1963,12 +1963,8 @@ class TPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
|
||||
self.grammar_bitmask_cpu.zero_()
|
||||
self.require_structured_out_cpu.zero_()
|
||||
|
||||
sorted_struct_requests = sorted(
|
||||
scheduler_output.structured_output_request_ids.items(),
|
||||
key=lambda item: item[1],
|
||||
)
|
||||
cumulative_mask_idx = 0
|
||||
for req_id, _ in sorted_struct_requests:
|
||||
for req_id in scheduler_output.structured_output_request_ids:
|
||||
if req_id not in self.input_batch.req_id_to_index:
|
||||
continue
|
||||
batch_index = self.input_batch.req_id_to_index[req_id]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user