diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 43cb95fb47f95..1ba557977707f 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -137,7 +137,7 @@ class EngineCore: req = Request.from_engine_core_request(request) if req.use_structured_output: # Start grammar compilation asynchronously - self.structured_output_manager.populate_cache(req) + self.structured_output_manager.grammar_init(req) self.scheduler.add_request(req) diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 3f828e0854765..fd1e6feed6a0b 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -1,18 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations -import copy import multiprocessing -from collections import OrderedDict -from concurrent.futures import ThreadPoolExecutor +from concurrent.futures import Future, ThreadPoolExecutor from typing import TYPE_CHECKING, Optional from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs from vllm.utils import LazyLoader -from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey, - StructuredOutputOptions) +from vllm.v1.structured_output.grammar import Grammar, StructuredOutputOptions if TYPE_CHECKING: import numpy as np @@ -29,7 +26,7 @@ logger = init_logger(__name__) class StructuredOutputManager: - def __init__(self, vllm_config: VllmConfig, max_cache_size: int = 500): + def __init__(self, vllm_config: VllmConfig): tokenizer_group = init_tokenizer_from_configs( model_config=vllm_config.model_config, scheduler_config=vllm_config.scheduler_config, @@ -44,10 +41,6 @@ class StructuredOutputManager: tokenizer, vocab_size=self.vocab_size) self.compiler = xgr.GrammarCompiler(tokenizer_info, max_threads=8) - self.max_cache_size = max_cache_size - self.request_key_to_grammar: OrderedDict[StructuredOutputKey, - Grammar] = OrderedDict() - # The default max_workers if not specified is the number of CPUs * 5, # which is way too high since these tasks are CPU-bound, not I/O bound. # We also know we would never dominate CPU usage with just grammar @@ -56,51 +49,22 @@ class StructuredOutputManager: self.executor = ThreadPoolExecutor(max_workers=max_workers) self._grammar_bitmask: Optional[torch.Tensor] = None - def __getitem__(self, key: StructuredOutputKey) -> Optional[Grammar]: - # We need to pop and re-insert the grammar here for LRU cache - # of request_key_to_grammar - if key in self.request_key_to_grammar: - # Move accessed item to the end (most recently used) - value = self.request_key_to_grammar.pop(key) - if value is not None: - self.request_key_to_grammar[key] = value - return value - return None - - def populate_cache(self, request: Request) -> None: + def grammar_init(self, request: Request) -> None: if request.structured_output_request is None: return - grammar = self.request_key_to_grammar.get( - request.structured_output_request.structured_output_key) - if grammar: - request.structured_output_request.grammar = copy.copy(grammar) - return - request.structured_output_request.grammar = self.cache(request) + grammar: Future[Grammar] = self.executor.submit( + self._async_create_grammar, request) + request.structured_output_request.grammar = grammar # type: ignore[assignment] - def cache(self, request: Request): - return self.executor.submit(self._executor_loop, request) - - def _executor_loop(self, request: Request) -> Grammar: - # NOTE: The structured_output_request should never be - # None in this case, but mypy can't infer this - # correctly, so we need to ignore the error here. + def _async_create_grammar(self, request: Request) -> Grammar: key = request.structured_output_request.structured_output_key # type: ignore[union-attr] - grammar = self.request_key_to_grammar.get(key) - if grammar is not None: - return copy.copy(grammar) - grammar = self.initialize_grammar(key) - # If cache is full, remove the least recently used item - if len(self.request_key_to_grammar) >= self.max_cache_size: - self.request_key_to_grammar.popitem(last=False) - self.request_key_to_grammar[key] = grammar - return copy.copy(grammar) - def initialize_grammar(self, key: StructuredOutputKey) -> Grammar: # Note that the request was validated in the engine core client, # so at this point we know it is a supported type of request. # - # TODO: we still need to handle xgrammar compilation failures + # TODO: we still need to handle xgrammar compilation failures, + # though it should be unlikely as we test that up front as well. request_type, grammar_spec = key if request_type == StructuredOutputOptions.JSON: