mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-07 13:23:07 +08:00
[V1] Refactor Structured Output for multiple backends (#14694)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
parent
46c759c165
commit
3a1e648158
@ -119,16 +119,21 @@ class Processor:
|
||||
def _validate_structured_output(self, params: SamplingParams) -> None:
|
||||
if not params.guided_decoding or not self.decoding_config:
|
||||
return
|
||||
if self.decoding_config.guided_decoding_backend != "xgrammar":
|
||||
raise ValueError(
|
||||
"Only xgrammar structured output is supported in V1.")
|
||||
if (params.guided_decoding.backend
|
||||
and params.guided_decoding.backend != 'xgrammar'):
|
||||
raise ValueError(
|
||||
"Only xgrammar structured output is supported in V1.")
|
||||
if self.vllm_config.speculative_config:
|
||||
raise ValueError("Structured output is not supported with "
|
||||
"speculative decoding.")
|
||||
|
||||
supported_backends = ["xgrammar"]
|
||||
engine_level_backend = self.decoding_config.guided_decoding_backend
|
||||
if engine_level_backend not in supported_backends:
|
||||
raise ValueError(f"Only {supported_backends} structured output is "
|
||||
"supported in V1.")
|
||||
if params.guided_decoding.backend:
|
||||
if params.guided_decoding.backend != engine_level_backend:
|
||||
raise ValueError("Request-level structured output backend "
|
||||
"must match engine-level backend. "
|
||||
f"{params.guided_decoding.backend}"
|
||||
f" != {engine_level_backend}")
|
||||
else:
|
||||
params.guided_decoding.backend = engine_level_backend
|
||||
|
||||
if vllm.platforms.current_platform.is_tpu():
|
||||
raise ValueError("Structured output is not supported on TPU.")
|
||||
|
||||
|
||||
@ -7,75 +7,27 @@ from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils import LazyLoader
|
||||
from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions
|
||||
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
|
||||
StructuredOutputGrammar)
|
||||
from vllm.v1.structured_output.backend_xgrammar import XgrammarBackend
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import xgrammar as xgr
|
||||
import torch
|
||||
|
||||
from vllm.v1.request import Request
|
||||
else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class StructuredOutputManager:
|
||||
"""Engine-level manager for structured output requests."""
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
self.backend: Optional[StructuredOutputBackend] = None
|
||||
self.vllm_config = vllm_config
|
||||
self.init_complete = False
|
||||
|
||||
def _delayed_init(self):
|
||||
"""Initialization delayed until we know it is needed."""
|
||||
tokenizer_group = init_tokenizer_from_configs(
|
||||
model_config=self.vllm_config.model_config,
|
||||
scheduler_config=self.vllm_config.scheduler_config,
|
||||
parallel_config=self.vllm_config.parallel_config,
|
||||
lora_config=self.vllm_config.lora_config) # type: ignore[arg-type]
|
||||
tokenizer_group.ping()
|
||||
|
||||
tokenizer = tokenizer_group.get_lora_tokenizer(None)
|
||||
self.vocab_size = self.vllm_config.model_config.get_vocab_size()
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
# NOTE: ideally, xgrammar should handle this accordingly.
|
||||
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
|
||||
try:
|
||||
encoded_vocab = [
|
||||
token for token, _ in sorted(
|
||||
tokenizer.get_vocab().items(),
|
||||
key=lambda x: x[1],
|
||||
)
|
||||
]
|
||||
stop_token_ids = None
|
||||
if hasattr(
|
||||
tokenizer,
|
||||
"eos_token_id",
|
||||
) and tokenizer.eos_token_id is not None:
|
||||
stop_token_ids = [tokenizer.eos_token_id]
|
||||
except AttributeError as e:
|
||||
raise ValueError(
|
||||
f"Cannot get the vocabulary of the tokenizer "
|
||||
f"{type(tokenizer)}. The tokenizer should have a "
|
||||
"get_vocab method.") from e
|
||||
tokenizer_info = xgr.TokenizerInfo(
|
||||
encoded_vocab=encoded_vocab,
|
||||
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
|
||||
vocab_type=xgr.VocabType.BYTE_FALLBACK,
|
||||
vocab_size=self.vocab_size,
|
||||
stop_token_ids=stop_token_ids,
|
||||
add_prefix_space=True,
|
||||
)
|
||||
else:
|
||||
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
|
||||
tokenizer,
|
||||
vocab_size=self.vocab_size,
|
||||
)
|
||||
self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
|
||||
self._grammar_bitmask: Optional[torch.Tensor] = None
|
||||
|
||||
# The default max_workers if not specified is the number of CPUs * 5,
|
||||
# which is way too high since these tasks are CPU-bound, not I/O bound.
|
||||
@ -83,28 +35,30 @@ class StructuredOutputManager:
|
||||
# compilation, so we set it to half the number of CPUs.
|
||||
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self._grammar_bitmask = xgr.allocate_token_bitmask(
|
||||
self.vllm_config.scheduler_config.max_num_seqs,
|
||||
self.vocab_size,
|
||||
)
|
||||
|
||||
self.init_complete = True
|
||||
|
||||
def grammar_init(self, request: Request) -> None:
|
||||
if request.structured_output_request is None:
|
||||
return
|
||||
|
||||
# The first time this is called, we need to finish initialization
|
||||
# of xgrammar. We defer it to avoid the import of xgrammar and
|
||||
# initialization cost if it is not going to be used.
|
||||
if not self.init_complete:
|
||||
self._delayed_init()
|
||||
# Initialize the backend the first time it is needed.
|
||||
#
|
||||
# NOTE: We only support a single backend. We do NOT support different
|
||||
# backends on a per-request basis in V1 (for now, anyway...).
|
||||
if self.backend is None:
|
||||
backend_name = request.sampling_params.guided_decoding.backend_name
|
||||
if backend_name == "xgrammar":
|
||||
self.backend = XgrammarBackend(self.vllm_config)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported structured output backend: {backend_name}")
|
||||
|
||||
grammar: Future[Grammar] = self.executor.submit(
|
||||
self._async_create_grammar, request)
|
||||
grammar: Future[StructuredOutputGrammar] = self.executor.submit(
|
||||
self._async_create_grammar, request, self.backend)
|
||||
request.structured_output_request.grammar = grammar # type: ignore[assignment]
|
||||
|
||||
def _async_create_grammar(self, request: Request) -> Grammar:
|
||||
def _async_create_grammar(
|
||||
self, request: Request,
|
||||
backend: StructuredOutputBackend) -> StructuredOutputGrammar:
|
||||
key = request.structured_output_request.structured_output_key # type: ignore[union-attr]
|
||||
|
||||
# Note that the request was validated in the engine core client,
|
||||
@ -114,28 +68,8 @@ class StructuredOutputManager:
|
||||
# though it should be unlikely as we test that up front as well.
|
||||
request_type, grammar_spec = key
|
||||
|
||||
if request_type == StructuredOutputOptions.JSON:
|
||||
# TODO -- allow any_whitespace to be configurable
|
||||
# pending merge of https://github.com/vllm-project/vllm/pull/12744
|
||||
ctx = self.compiler.compile_json_schema(grammar_spec,
|
||||
any_whitespace=False)
|
||||
elif request_type == StructuredOutputOptions.JSON_OBJECT:
|
||||
ctx = self.compiler.compile_builtin_json_grammar()
|
||||
elif request_type == StructuredOutputOptions.GRAMMAR:
|
||||
ctx = self.compiler.compile_grammar(grammar_spec)
|
||||
elif request_type == StructuredOutputOptions.REGEX:
|
||||
ctx = self.compiler.compile_regex(grammar_spec)
|
||||
else:
|
||||
logger.error("Validation should have already occurred. "
|
||||
"Please file an issue.")
|
||||
raise ValueError(
|
||||
f"grammar is not of valid supported types. ({request_type!s})")
|
||||
|
||||
return Grammar(
|
||||
matcher=xgr.GrammarMatcher(ctx),
|
||||
vocab_size=self.vocab_size,
|
||||
ctx=ctx,
|
||||
)
|
||||
assert self.backend is not None
|
||||
return self.backend.compile_grammar(request_type, grammar_spec)
|
||||
|
||||
def grammar_bitmask(
|
||||
self,
|
||||
@ -147,6 +81,11 @@ class StructuredOutputManager:
|
||||
if not structured_output_request_ids:
|
||||
return None
|
||||
|
||||
if self._grammar_bitmask is None:
|
||||
assert self.backend is not None
|
||||
self._grammar_bitmask = self.backend.allocate_token_bitmask(
|
||||
self.vllm_config.scheduler_config.max_num_seqs)
|
||||
|
||||
# Fill the bitmask using the index of each request equal to its
|
||||
# position in the batch. Resize the bitmask down to the size of
|
||||
# the batch.
|
||||
@ -154,7 +93,7 @@ class StructuredOutputManager:
|
||||
for req_id, batch_index in structured_output_request_ids.items():
|
||||
request = requests[req_id].structured_output_request
|
||||
assert request is not None and request.grammar is not None
|
||||
if not request.grammar.matcher.is_terminated():
|
||||
if not request.grammar.is_terminated():
|
||||
request.grammar.fill_bitmask(bitmask_tensor, batch_index)
|
||||
if batch_len < self._grammar_bitmask.shape[0]:
|
||||
bitmask_tensor = self._grammar_bitmask[:batch_len]
|
||||
|
||||
89
vllm/v1/structured_output/backend_types.py
Normal file
89
vllm/v1/structured_output/backend_types.py
Normal file
@ -0,0 +1,89 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class StructuredOutputOptions(enum.Enum):
|
||||
JSON = enum.auto()
|
||||
JSON_OBJECT = enum.auto()
|
||||
REGEX = enum.auto()
|
||||
GRAMMAR = enum.auto()
|
||||
CHOICE = enum.auto()
|
||||
|
||||
|
||||
StructuredOutputKey = tuple[StructuredOutputOptions, str]
|
||||
|
||||
|
||||
class StructuredOutputGrammar(ABC):
|
||||
"""Request-level backend for structured output requests."""
|
||||
|
||||
@abstractmethod
|
||||
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
|
||||
"""
|
||||
Determines whether the provided tokens are accepted for the
|
||||
given request.
|
||||
|
||||
Args:
|
||||
request_id (str): The unique identifier for the request.
|
||||
tokens (list[int]): A list of token IDs to evaluate.
|
||||
|
||||
Returns:
|
||||
bool: True if the tokens are accepted, False otherwise.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def fill_bitmask(self, bitmask: torch.Tensor, batch_index: int) -> None:
|
||||
"""
|
||||
Fills the bitmask for a specific batch index.
|
||||
|
||||
Args:
|
||||
bitmask (torch.Tensor): The bitmask to fill
|
||||
batch_index (int): The index in the bitmask to fill
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def is_terminated(self) -> bool:
|
||||
"""
|
||||
Checks whether the structured output process has terminated.
|
||||
|
||||
Returns:
|
||||
bool: True if the process is terminated, False otherwise.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def reset(self):
|
||||
"""
|
||||
Resets the state of the structured output grammar.
|
||||
"""
|
||||
|
||||
|
||||
class StructuredOutputBackend(ABC):
|
||||
"""Engine-level backend for structured output requests."""
|
||||
|
||||
@abstractmethod
|
||||
def compile_grammar(self, request_type: StructuredOutputOptions,
|
||||
grammar_spec: str) -> StructuredOutputGrammar:
|
||||
"""
|
||||
Compiles a grammar specification into a structured output grammar.
|
||||
|
||||
Args:
|
||||
request_type (StructuredOutputOptions): The type of structured
|
||||
output request.
|
||||
grammar_spec (str): The grammar specification to compile.
|
||||
|
||||
Returns:
|
||||
StructuredOutputGrammar: The compiled structured output grammar.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def allocate_token_bitmask(self, max_num_seqs: int):
|
||||
"""
|
||||
Allocates a token bitmask for the specified maximum number of sequences.
|
||||
|
||||
Args:
|
||||
max_num_seqs (int): The maximum number of sequences for which
|
||||
to allocate the bitmask.
|
||||
"""
|
||||
143
vllm/v1/structured_output/backend_xgrammar.py
Normal file
143
vllm/v1/structured_output/backend_xgrammar.py
Normal file
@ -0,0 +1,143 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
|
||||
from vllm.utils import LazyLoader
|
||||
from vllm.v1.structured_output.backend_types import (StructuredOutputBackend,
|
||||
StructuredOutputGrammar,
|
||||
StructuredOutputOptions)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgrammar as xgr
|
||||
else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class XgrammarBackend(StructuredOutputBackend):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
self.vllm_config = vllm_config
|
||||
tokenizer_group = init_tokenizer_from_configs(
|
||||
model_config=vllm_config.model_config,
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
parallel_config=vllm_config.parallel_config,
|
||||
lora_config=vllm_config.lora_config) # type: ignore[arg-type]
|
||||
tokenizer_group.ping()
|
||||
|
||||
tokenizer = tokenizer_group.get_lora_tokenizer(None)
|
||||
self.vocab_size = vllm_config.model_config.get_vocab_size()
|
||||
if isinstance(tokenizer, MistralTokenizer):
|
||||
# NOTE: ideally, xgrammar should handle this accordingly.
|
||||
# refer to https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98
|
||||
try:
|
||||
encoded_vocab = [
|
||||
token for token, _ in sorted(
|
||||
tokenizer.get_vocab().items(),
|
||||
key=lambda x: x[1],
|
||||
)
|
||||
]
|
||||
stop_token_ids = None
|
||||
if hasattr(
|
||||
tokenizer,
|
||||
"eos_token_id",
|
||||
) and tokenizer.eos_token_id is not None:
|
||||
stop_token_ids = [tokenizer.eos_token_id]
|
||||
except AttributeError as e:
|
||||
raise ValueError(
|
||||
f"Cannot get the vocabulary of the tokenizer "
|
||||
f"{type(tokenizer)}. The tokenizer should have a "
|
||||
"get_vocab method.") from e
|
||||
tokenizer_info = xgr.TokenizerInfo( # type: ignore
|
||||
encoded_vocab=encoded_vocab,
|
||||
# NOTE: https://github.com/mlc-ai/xgrammar/blob/5e141f6ff1ca02bc31f9e512e68b61f2a8ae88e5/tests/python/test_tokenizer_info.py#L43 # noqa: E501
|
||||
vocab_type=xgr.VocabType.BYTE_FALLBACK,
|
||||
vocab_size=self.vocab_size,
|
||||
stop_token_ids=stop_token_ids,
|
||||
add_prefix_space=True,
|
||||
)
|
||||
else:
|
||||
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
|
||||
tokenizer,
|
||||
vocab_size=self.vocab_size,
|
||||
)
|
||||
self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8)
|
||||
|
||||
def compile_grammar(self, request_type: StructuredOutputOptions,
|
||||
grammar_spec: str) -> StructuredOutputGrammar:
|
||||
if request_type == StructuredOutputOptions.JSON:
|
||||
ctx = self.compiler.compile_json_schema(grammar_spec,
|
||||
any_whitespace=False)
|
||||
elif request_type == StructuredOutputOptions.JSON_OBJECT:
|
||||
ctx = self.compiler.compile_builtin_json_grammar()
|
||||
elif request_type == StructuredOutputOptions.GRAMMAR:
|
||||
ctx = self.compiler.compile_grammar(grammar_spec)
|
||||
elif request_type == StructuredOutputOptions.REGEX:
|
||||
ctx = self.compiler.compile_regex(grammar_spec)
|
||||
else:
|
||||
logger.error(
|
||||
"Validation should have already occurred. Please file an issue."
|
||||
)
|
||||
raise ValueError(
|
||||
f"grammar is not of valid supported types. ({request_type!s})")
|
||||
|
||||
return XgrammarGrammar(
|
||||
matcher=xgr.GrammarMatcher(ctx),
|
||||
vocab_size=self.vocab_size,
|
||||
ctx=ctx,
|
||||
)
|
||||
|
||||
def allocate_token_bitmask(self, max_num_seqs: int):
|
||||
return xgr.allocate_token_bitmask(max_num_seqs, self.vocab_size)
|
||||
|
||||
|
||||
@dataclass
|
||||
class XgrammarGrammar(StructuredOutputGrammar):
|
||||
# NOTE: This would be a generic-enough class for
|
||||
# supporting different backends, in the future.
|
||||
# For now, just xgrammar.
|
||||
#
|
||||
# TODO: support max_rollback_tokens
|
||||
# https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
|
||||
# for jump-forward decoding
|
||||
|
||||
vocab_size: int
|
||||
matcher: xgr.GrammarMatcher = field(hash=False)
|
||||
ctx: xgr.CompiledGrammar = field(hash=False)
|
||||
num_processed_tokens: int = field(default_factory=lambda: 0,
|
||||
repr=False,
|
||||
hash=False,
|
||||
init=False)
|
||||
|
||||
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
|
||||
"""Accepts a list of tokens and advances the FSM.
|
||||
|
||||
Returns True if the FSM was advanced successfully.
|
||||
Returns False if the FSM failed to advance.
|
||||
"""
|
||||
for token in tokens:
|
||||
if not self.matcher.accept_token(token):
|
||||
logger.error(
|
||||
"Failed to advance FSM for request %s "
|
||||
"for tokens %s. Please file an issue.", request_id, token)
|
||||
return False
|
||||
self.num_processed_tokens += 1
|
||||
return True
|
||||
|
||||
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> None:
|
||||
self.matcher.fill_next_token_bitmask(bitmask, idx)
|
||||
|
||||
def is_terminated(self) -> bool:
|
||||
return self.matcher.is_terminated()
|
||||
|
||||
def reset(self):
|
||||
self.num_processed_tokens = 0
|
||||
self.matcher.reset()
|
||||
@ -1,77 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import LazyLoader
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import xgrammar as xgr
|
||||
else:
|
||||
xgr = LazyLoader("xgr", globals(), "xgrammar")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class StructuredOutputOptions(enum.Enum):
|
||||
JSON = enum.auto()
|
||||
JSON_OBJECT = enum.auto()
|
||||
REGEX = enum.auto()
|
||||
GRAMMAR = enum.auto()
|
||||
CHOICE = enum.auto()
|
||||
|
||||
|
||||
StructuredOutputKey = tuple[StructuredOutputOptions, str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class Grammar:
|
||||
# NOTE: This would be a generic-enough class for
|
||||
# supporting different backends, in the future.
|
||||
# For now, just xgrammar.
|
||||
#
|
||||
# TODO: support max_rollback_tokens
|
||||
# https://xgrammar.mlc.ai/docs/api/python/index.html#xgrammar.GrammarMatcher.find_jump_forward_string
|
||||
# for jump-forward decoding
|
||||
|
||||
vocab_size: int
|
||||
matcher: xgr.GrammarMatcher = field(hash=False)
|
||||
ctx: xgr.CompiledGrammar = field(hash=False)
|
||||
num_processed_tokens: int = field(default_factory=lambda: 0,
|
||||
repr=False,
|
||||
hash=False,
|
||||
init=False)
|
||||
|
||||
def accept_tokens(self, request_id: str, tokens: list[int]) -> bool:
|
||||
"""Accepts a list of tokens and advances the FSM.
|
||||
|
||||
Returns True if the FSM was advanced successfully.
|
||||
Returns False if the FSM failed to advance.
|
||||
"""
|
||||
for token in tokens:
|
||||
if not self.matcher.accept_token(token):
|
||||
logger.error(
|
||||
"Failed to advance FSM for request %s "
|
||||
"for tokens %s. Please file an issue.", request_id, token)
|
||||
return False
|
||||
self.num_processed_tokens += 1
|
||||
return True
|
||||
|
||||
def fill_bitmask(self, bitmask: torch.Tensor, idx: int) -> bool:
|
||||
return self.matcher.fill_next_token_bitmask(bitmask, idx)
|
||||
|
||||
def reset(self):
|
||||
self.num_processed_tokens = 0
|
||||
self.matcher.reset()
|
||||
|
||||
def __copy__(self):
|
||||
return Grammar(
|
||||
matcher=xgr.GrammarMatcher(self.ctx),
|
||||
vocab_size=self.vocab_size,
|
||||
ctx=self.ctx,
|
||||
)
|
||||
@ -9,15 +9,17 @@ from concurrent.futures._base import TimeoutError
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey,
|
||||
StructuredOutputOptions)
|
||||
from vllm.v1.structured_output.backend_types import (StructuredOutputGrammar,
|
||||
StructuredOutputKey,
|
||||
StructuredOutputOptions)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class StructuredOutputRequest:
|
||||
|
||||
sampling_params: SamplingParams
|
||||
_grammar: Optional[Union[Future[Grammar], Grammar]] = None
|
||||
_grammar: Optional[Union[Future[StructuredOutputGrammar],
|
||||
StructuredOutputGrammar]] = None
|
||||
|
||||
def _check_grammar_completion(self) -> bool:
|
||||
# NOTE: We have to lazy import to gate circular imports
|
||||
@ -37,12 +39,16 @@ class StructuredOutputRequest:
|
||||
return self._check_grammar_completion()
|
||||
|
||||
@property
|
||||
def grammar(self) -> Optional[Grammar]:
|
||||
def grammar(self) -> Optional[StructuredOutputGrammar]:
|
||||
completed = self._check_grammar_completion()
|
||||
return cast(Optional[Grammar], self._grammar) if completed else None
|
||||
return cast(Optional[StructuredOutputGrammar],
|
||||
self._grammar) if completed else None
|
||||
|
||||
@grammar.setter
|
||||
def grammar(self, grammar: Union[Grammar, Future[Grammar]]) -> None:
|
||||
def grammar(
|
||||
self, grammar: Union[StructuredOutputGrammar,
|
||||
Future[StructuredOutputGrammar]]
|
||||
) -> None:
|
||||
self._grammar = grammar
|
||||
|
||||
@functools.cached_property
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user