[V1] Delay all xgrammar usage until needed (#14616)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
Russell Bryant 2025-03-11 16:21:33 -04:00 committed by GitHub
parent 53056731fd
commit 61a01b27a7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -14,7 +14,6 @@ from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions
if TYPE_CHECKING:
import numpy as np
import numpy.typing as npt
import torch
import xgrammar as xgr
from vllm.v1.request import Request
@ -27,14 +26,18 @@ logger = init_logger(__name__)
class StructuredOutputManager:
def __init__(self, vllm_config: VllmConfig):
tokenizer_group = init_tokenizer_from_configs(
model_config=vllm_config.model_config,
scheduler_config=vllm_config.scheduler_config,
parallel_config=vllm_config.parallel_config,
lora_config=vllm_config.lora_config) # type: ignore[arg-type]
tokenizer_group.ping()
self.vocab_size = vllm_config.model_config.get_vocab_size()
self.vllm_config = vllm_config
self.init_complete = False
def _delayed_init(self):
"""Initialization delayed until we know it is needed."""
tokenizer_group = init_tokenizer_from_configs(
model_config=self.vllm_config.model_config,
scheduler_config=self.vllm_config.scheduler_config,
parallel_config=self.vllm_config.parallel_config,
lora_config=self.vllm_config.lora_config) # type: ignore[arg-type]
tokenizer_group.ping()
tokenizer = tokenizer_group.get_lora_tokenizer(None)
tokenizer_info = xgr.TokenizerInfo.from_huggingface(
@ -47,12 +50,21 @@ class StructuredOutputManager:
# compilation, so we set it to half the number of CPUs.
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self._grammar_bitmask: Optional[torch.Tensor] = None
self._grammar_bitmask = xgr.allocate_token_bitmask(
self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)
self.init_complete = True
def grammar_init(self, request: Request) -> None:
if request.structured_output_request is None:
return
# The first time this is called, we need to finish initialization
# of xgrammar. We defer it to avoid the import of xgrammar and
# initialization cost if it is not going to be used.
if not self.init_complete:
self._delayed_init()
grammar: Future[Grammar] = self.executor.submit(
self._async_create_grammar, request)
request.structured_output_request.grammar = grammar # type: ignore[assignment]
@ -100,11 +112,6 @@ class StructuredOutputManager:
if not structured_output_request_ids:
return None
if self._grammar_bitmask is None:
self._grammar_bitmask = xgr.allocate_token_bitmask(
self.vllm_config.scheduler_config.max_num_seqs,
self.vocab_size)
# Fill the bitmask using the index of each request equal to its
# position in the batch. Resize the bitmask down to the size of
# the batch.