mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-28 15:37:12 +08:00
[V1] Prevent xgrammar from breaking TPU support (#14575)
Signed-off-by: Russell Bryant <rbryant@redhat.com>
This commit is contained in:
parent
432d6dad15
commit
04421dff8a
@ -4,6 +4,7 @@ import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Optional, Union
|
||||
|
||||
import vllm.platforms
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.inputs import (INPUT_REGISTRY, InputRegistry, ProcessorInputs,
|
||||
PromptType, SingletonInputsAdapter)
|
||||
@ -133,6 +134,9 @@ class Processor:
|
||||
if self.vllm_config.speculative_config:
|
||||
raise ValueError("Structured output is not supported with "
|
||||
"speculative decoding.")
|
||||
if vllm.platforms.current_platform.is_tpu():
|
||||
raise ValueError("Structured output is not supported on TPU.")
|
||||
|
||||
validate_structured_output_request(params)
|
||||
|
||||
def process_inputs(
|
||||
|
||||
@ -17,6 +17,7 @@ from vllm.v1.structured_output.grammar import (Grammar, StructuredOutputKey,
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
import xgrammar as xgr
|
||||
|
||||
from vllm.v1.request import Request
|
||||
@ -53,8 +54,7 @@ class StructuredOutputManager:
|
||||
# compilation, so we set it to half the number of CPUs.
|
||||
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
|
||||
self.executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self._grammar_bitmask = xgr.allocate_token_bitmask(
|
||||
self.vllm_config.scheduler_config.max_num_seqs, self.vocab_size)
|
||||
self._grammar_bitmask: Optional[torch.Tensor] = None
|
||||
|
||||
def __getitem__(self, key: StructuredOutputKey) -> Optional[Grammar]:
|
||||
# We need to pop and re-insert the grammar here for LRU cache
|
||||
@ -134,6 +134,11 @@ class StructuredOutputManager:
|
||||
if not structured_output_request_ids:
|
||||
return None
|
||||
|
||||
if self._grammar_bitmask is None:
|
||||
self._grammar_bitmask = xgr.allocate_token_bitmask(
|
||||
self.vllm_config.scheduler_config.max_num_seqs,
|
||||
self.vocab_size)
|
||||
|
||||
# Fill the bitmask using the index of each request equal to its
|
||||
# position in the batch. Resize the bitmask down to the size of
|
||||
# the batch.
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user