[V1] Delay all xgrammar usage until needed (#14616)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Russell Bryant committed 2025-03-11 16:21:33 -04:00 (via GitHub)
parent 53056731fd
commit 61a01b27a7


@@ -14,7 +14,6 @@ from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions
 if TYPE_CHECKING:
     import numpy as np
     import numpy.typing as npt
-    import torch
     import xgrammar as xgr

     from vllm.v1.request import Request
@@ -27,14 +26,18 @@ logger = init_logger(__name__)
 class StructuredOutputManager:

     def __init__(self, vllm_config: VllmConfig):
-        tokenizer_group = init_tokenizer_from_configs(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            parallel_config=vllm_config.parallel_config,
-            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
-        tokenizer_group.ping()
         self.vocab_size = vllm_config.model_config.get_vocab_size()
         self.vllm_config = vllm_config
+        self.init_complete = False
+
+    def _delayed_init(self):
+        """Initialization delayed until we know it is needed."""
+        tokenizer_group = init_tokenizer_from_configs(
+            model_config=self.vllm_config.model_config,
+            scheduler_config=self.vllm_config.scheduler_config,
+            parallel_config=self.vllm_config.parallel_config,
+            lora_config=self.vllm_config.lora_config)  # type: ignore[arg-type]
+        tokenizer_group.ping()

         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         tokenizer_info = xgr.TokenizerInfo.from_huggingface(
@@ -47,12 +50,21 @@ class StructuredOutputManager:
         # compilation, so we set it to half the number of CPUs.
         max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
-        self._grammar_bitmask: Optional[torch.Tensor] = None
+        self._grammar_bitmask = xgr.allocate_token_bitmask(
+            self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)
+
+        self.init_complete = True

     def grammar_init(self, request: Request) -> None:
         if request.structured_output_request is None:
             return

+        # The first time this is called, we need to finish initialization
+        # of xgrammar. We defer it to avoid the import of xgrammar and
+        # initialization cost if it is not going to be used.
+        if not self.init_complete:
+            self._delayed_init()
+
         grammar: Future[Grammar] = self.executor.submit(
             self._async_create_grammar, request)
         request.structured_output_request.grammar = grammar  # type: ignore[assignment]
@@ -100,11 +112,6 @@ class StructuredOutputManager:
         if not structured_output_request_ids:
             return None

-        if self._grammar_bitmask is None:
-            self._grammar_bitmask = xgr.allocate_token_bitmask(
-                self.vllm_config.scheduler_config.max_num_seqs,
-                self.vocab_size)
-
         # Fill the bitmask using the index of each request equal to its
         # position in the batch. Resize the bitmask down to the size of
         # the batch.
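
The shape of the change is ordinary lazy initialization: keep __init__ cheap, and run the expensive setup (the xgrammar import, tokenizer group construction, thread pool, and bitmask allocation) at most once, on first use. Below is a minimal standalone sketch of that pattern; LazyManager, handle(), and the body of _delayed_init are illustrative placeholders, not vLLM's real API.

# A minimal sketch of the deferred-initialization pattern this commit
# introduces. The names here are illustrative, not vLLM's API; only the
# structure mirrors StructuredOutputManager above.
import multiprocessing
from concurrent.futures import ThreadPoolExecutor


class LazyManager:

    def __init__(self) -> None:
        # Keep construction cheap: no heavy imports, no thread pool yet.
        self.init_complete = False

    def _delayed_init(self) -> None:
        """Expensive setup, run at most once and only on first use."""
        # In the real code this is where xgrammar is imported and the
        # tokenizer info and token bitmask are built.
        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.init_complete = True

    def handle(self, fn, *args) -> None:
        # Mirrors grammar_init(): finish initialization on first request,
        # so callers that never use the feature pay nothing.
        if not self.init_complete:
            self._delayed_init()
        self.executor.submit(fn, *args)

Like the diff, the sketch leaves the init_complete check unlocked, which assumes first use arrives from a single thread; if concurrent first calls were possible, the check-then-init sequence would need a lock.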