[V1] Delay all xgrammar usage until needed (#14616)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
commit 61a01b27a7
parent 53056731fd
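The change below applies a plain deferred-initialization pattern: everything expensive (tokenizer construction, xgrammar setup, bitmask allocation) moves out of __init__ into a _delayed_init() that runs only when the first structured-output request arrives, so engines that never use structured output never pay the cost. A minimal sketch of the pattern, with hypothetical names (LazyManager, handle, and json standing in for the real tokenizer and xgrammar work):

    class LazyManager:
        def __init__(self) -> None:
            # Cheap construction only: record state, defer the heavy work.
            self.init_complete = False
            self._backend = None

        def _delayed_init(self) -> None:
            """Initialization delayed until we know it is needed."""
            import json  # stands in for a costly import such as xgrammar
            self._backend = json.dumps  # stands in for real setup
            self.init_complete = True

        def handle(self, payload: object) -> str:
            if not self.init_complete:
                self._delayed_init()
            return self._backend(payload)

    print(LazyManager().handle({"a": 1}))  # {"a": 1}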
@@ -14,7 +14,6 @@ from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions
 if TYPE_CHECKING:
     import numpy as np
     import numpy.typing as npt
-    import torch
     import xgrammar as xgr

     from vllm.v1.request import Request
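The TYPE_CHECKING guard in the hunk above is what makes the deferral possible: type checkers see the imports for annotations, while the interpreter never executes them. A small self-contained illustration of the same idiom (numpy here is just an example dependency, not part of this commit):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        import numpy as np  # seen by the type checker, never imported at runtime

    def zeros(n: int) -> "np.ndarray":
        # The runtime import is deferred to first use, as in this commit.
        import numpy as np
        return np.zeros(n)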
@@ -27,14 +26,18 @@ logger = init_logger(__name__)
 class StructuredOutputManager:

     def __init__(self, vllm_config: VllmConfig):
-        tokenizer_group = init_tokenizer_from_configs(
-            model_config=vllm_config.model_config,
-            scheduler_config=vllm_config.scheduler_config,
-            parallel_config=vllm_config.parallel_config,
-            lora_config=vllm_config.lora_config)  # type: ignore[arg-type]
-        tokenizer_group.ping()
         self.vocab_size = vllm_config.model_config.get_vocab_size()
         self.vllm_config = vllm_config
+        self.init_complete = False
+
+    def _delayed_init(self):
+        """Initialization delayed until we know it is needed."""
+        tokenizer_group = init_tokenizer_from_configs(
+            model_config=self.vllm_config.model_config,
+            scheduler_config=self.vllm_config.scheduler_config,
+            parallel_config=self.vllm_config.parallel_config,
+            lora_config=self.vllm_config.lora_config)  # type: ignore[arg-type]
+        tokenizer_group.ping()

         tokenizer = tokenizer_group.get_lora_tokenizer(None)
         tokenizer_info = xgr.TokenizerInfo.from_huggingface(
@@ -47,12 +50,21 @@ class StructuredOutputManager:
         # compilation, so we set it to half the number of CPUs.
         max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
         self.executor = ThreadPoolExecutor(max_workers=max_workers)
-        self._grammar_bitmask: Optional[torch.Tensor] = None
+        self._grammar_bitmask = xgr.allocate_token_bitmask(
+            self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)
+
+        self.init_complete = True

     def grammar_init(self, request: Request) -> None:
         if request.structured_output_request is None:
             return

+        # The first time this is called, we need to finish initialization
+        # of xgrammar. We defer it to avoid the import of xgrammar and
+        # initialization cost if it is not going to be used.
+        if not self.init_complete:
+            self._delayed_init()
+
         grammar: Future[Grammar] = self.executor.submit(
             self._async_create_grammar, request)
         request.structured_output_request.grammar = grammar  # type: ignore[assignment]
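The grammar_init path above also keeps compilation off the hot path: it hands the work to a thread pool and stores the resulting Future, so the caller blocks only if it asks for the result before compilation finishes. A minimal sketch of that submit-now, resolve-later pattern (compile_grammar is a hypothetical stand-in for the real compilation):

    from concurrent.futures import Future, ThreadPoolExecutor

    def compile_grammar(spec: str) -> str:
        # Hypothetical stand-in for slow grammar compilation.
        return f"compiled({spec})"

    executor = ThreadPoolExecutor(max_workers=2)
    pending: Future[str] = executor.submit(compile_grammar, "json_schema")
    # Scheduling continues immediately; this line blocks only if needed:
    print(pending.result())  # compiled(json_schema)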
@@ -100,11 +112,6 @@ class StructuredOutputManager:
         if not structured_output_request_ids:
             return None

-        if self._grammar_bitmask is None:
-            self._grammar_bitmask = xgr.allocate_token_bitmask(
-                self.vllm_config.scheduler_config.max_num_seqs,
-                self.vocab_size)
-
         # Fill the bitmask using the index of each request equal to its
         # position in the batch. Resize the bitmask down to the size of
         # the batch.
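For scale, the bitmask that the last hunk stops allocating lazily (it is now built once in _delayed_init) holds one bit per vocabulary token per sequence, packed into 32-bit words. A rough equivalent of the allocation, assuming xgrammar's packed-int32 layout; treat the real xgr.allocate_token_bitmask as the source of truth for dtype and shape:

    import math
    import torch

    def allocate_token_bitmask_sketch(max_num_seqs: int,
                                      vocab_size: int) -> torch.Tensor:
        # 32 vocabulary tokens per int32 word, one row per sequence.
        words_per_row = math.ceil(vocab_size / 32)
        return torch.zeros(max_num_seqs, words_per_row, dtype=torch.int32)

    mask = allocate_token_bitmask_sketch(max_num_seqs=8, vocab_size=32000)
    print(mask.shape)  # torch.Size([8, 1000])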