[V1] Prevent xgrammar from breaking TPU support (#14575)

Signed-off-by: Russell Bryant <rbryant@redhat.com>
Authored by Russell Bryant on 2025-03-10 19:06:19 -04:00; committed by GitHub
parent 432d6dad15
commit 04421dff8a
2 changed files with 11 additions and 2 deletions

vllm/v1/engine/processor.py

@@ -4,6 +4,7 @@ import time
 from collections.abc import Mapping
 from typing import Optional, Union
 
+import vllm.platforms
 from vllm.config import VllmConfig
 from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
                          PromptType, SingletonInputsAdapter)
@@ -133,6 +134,9 @@ class Processor:
         if self.vllm_config.speculative_config:
             raise ValueError("Structured output is not supported with "
                              "speculative decoding.")
+        if vllm.platforms.current_platform.is_tpu():
+            raise ValueError("Structured output is not supported on TPU.")
+
         validate_structured_output_request(params)
 
     def process_inputs(
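
A minimal sketch of the gating pattern this hunk adds, factored into a standalone helper (the helper name is a placeholder; current_platform and is_tpu() are the real vLLM interfaces shown above):

    import vllm.platforms

    def ensure_structured_output_supported() -> None:
        # current_platform resolves to the active platform object
        # (CUDA, ROCm, TPU, ...); is_tpu() is one of its capability probes.
        if vllm.platforms.current_platform.is_tpu():
            raise ValueError("Structured output is not supported on TPU.")

Failing fast during input validation means a TPU deployment rejects the request before any xgrammar machinery is touched.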

vllm/v1/structured_output/__init__.py

@@ -17,6 +17,7 @@ from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey,
 if TYPE_CHECKING:
     import numpy as np
     import numpy.typing as npt
+    import torch
     import xgrammar as xgr
 
     from vllm.v1.request import Request
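
Moving torch behind TYPE_CHECKING works because PEP 526 annotations on local and attribute assignments are never evaluated at runtime. A self-contained sketch of the pattern (Manager is a placeholder class name):

    from typing import TYPE_CHECKING, Optional

    if TYPE_CHECKING:
        import torch  # visible to type checkers only; never imported at runtime

    class Manager:
        def __init__(self) -> None:
            # PEP 526: this annotation is not evaluated when __init__ runs,
            # so the line executes even where torch is unavailable.
            self._grammar_bitmask: Optional[torch.Tensor] = None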
@@ -53,8 +54,7 @@ class StructuredOutputManager:
             # compilation, so we set it to half the number of CPUs.
             max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
             self.executor = ThreadPoolExecutor(max_workers=max_workers)
-            self._grammar_bitmask = xgr.allocate_token_bitmask(
-                self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)
+            self._grammar_bitmask: Optional[torch.Tensor] = None
 
     def __getitem__(self, key: StructuredOutputKey) -> Optional[Grammar]:
         # We need to pop and re-insert the grammar here for LRU cache
@@ -134,6 +134,11 @@ class StructuredOutputManager:
         if not structured_output_request_ids:
             return None
 
+        if self._grammar_bitmask is None:
+            self._grammar_bitmask = xgr.allocate_token_bitmask(
+                self.vllm_config.scheduler_config.max_num_seqs,
+                self.vocab_size)
+
         # Fill the bitmask using the index of each request equal to its
         # position in the batch. Resize the bitmask down to the size of
         # the batch.
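
Taken together, the last two hunks turn the bitmask into a lazily initialized attribute: allocation happens on the first structured-output request, so TPU deployments (which now reject such requests up front) never call into xgrammar at all. A minimal sketch of the pattern with placeholder names, mirroring the fields used above:

    import xgrammar as xgr

    class StructuredOutputManagerSketch:
        def __init__(self, max_num_seqs: int, vocab_size: int) -> None:
            self.max_num_seqs = max_num_seqs
            self.vocab_size = vocab_size
            self._grammar_bitmask = None  # deferred until first use

        def grammar_bitmask(self):
            if self._grammar_bitmask is None:
                # Allocate once for the largest possible batch; later calls
                # reuse the same tensor rather than reallocating per step.
                self._grammar_bitmask = xgr.allocate_token_bitmask(
                    self.max_num_seqs, self.vocab_size)
            return self._grammar_bitmask

Sizing the bitmask for max_num_seqs up front lets each scheduling step resize it down to the live batch, as the comment in the final hunk describes.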