diff --git a/vllm/v1/engine/processor.py b/vllm/v1/engine/processor.py index 5c940cce5d5b1..56846030ac49f 100644 --- a/vllm/v1/engine/processor.py +++ b/vllm/v1/engine/processor.py @@ -4,6 +4,7 @@ import time from collections.abc import Mapping from typing import Optional, Union +import vllm.platforms from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs, PromptType, SingletonInputsAdapter) @@ -133,6 +134,9 @@ class Processor: if self.vllm_config.speculative_config: raise ValueError("Structured output is not supported with " "speculative decoding.") + if vllm.platforms.current_platform.is_tpu(): + raise ValueError("Structured output is not supported on TPU.") + validate_structured_output_request(params) def process_inputs( diff --git a/vllm/v1/structured_output/__init__.py b/vllm/v1/structured_output/__init__.py index 0c2e0ac2aa73c..1f6e35643927f 100644 --- a/vllm/v1/structured_output/__init__.py +++ b/vllm/v1/structured_output/__init__.py @@ -17,6 +17,7 @@ from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey, if TYPE_CHECKING: import numpy as np import numpy.typing as npt + import torch import xgrammar as xgr from vllm.v1.request import Request @@ -53,8 +54,7 @@ class StructuredOutputManager: # compilation, so we set it to half the number of CPUs. max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2) self.executor = ThreadPoolExecutor(max_workers=max_workers) - self._grammar_bitmask = xgr.allocate_token_bitmask( - self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size) + self._grammar_bitmask: Optional[torch.Tensor] = None def __getitem__(self, key: StructuredOutputKey) -> Optional[Grammar]: # We need to pop and re-insert the grammar here for LRU cache @@ -134,6 +134,11 @@ class StructuredOutputManager: if not structured_output_request_ids: return None + if self._grammar_bitmask is None: + self._grammar_bitmask = xgr.allocate_token_bitmask( + self.vllm_config.scheduler_config.max_num_seqs, + self.vocab_size) + # Fill the bitmask using the index of each request equal to its # position in the batch. Resize the bitmask down to the size of # the batch.