mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-01 20:37:30 +08:00
[BugFix] Adding env variable to disable async grammar compilation (#29996)
Signed-off-by: Alec Solder <alecs@fb.com> Signed-off-by: Alec S <10566873+alecsolder@users.noreply.github.com> Co-authored-by: Alec Solder <alecs@fb.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: 22quinn <33176974+22quinn@users.noreply.github.com>
This commit is contained in:
parent
62b3333448
commit
65ee97288a
@ -1,9 +1,14 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from concurrent.futures import Future
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from vllm.config import StructuredOutputsConfig, VllmConfig
|
||||
from vllm.config.model import ModelConfig
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.sampling_params import SamplingParams, StructuredOutputsParams
|
||||
from vllm.v1.request import Request
|
||||
@ -116,3 +121,72 @@ def test_grammar_bitmask_with_specdec():
|
||||
) # EOS not the final token
|
||||
grammar_bitmask(request, prompt[i:]) # EOS not present
|
||||
grammar_bitmask(request, prompt[i:] + [tokenizer.eos_token_id])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("async_grammar", [True, False])
def test_grammar_init_async_and_sync(async_grammar):
    """Test grammar initialization works correctly in both async and sync modes.

    This test validates that the distributed_executor_backend config option
    correctly controls whether grammar compilation happens asynchronously
    (via executor.submit) or synchronously. When set to "external_launcher",
    grammar compilation is synchronous to avoid deadlocks.
    """
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)
    prompt = tokenizer.encode('{"a": "b"}')

    # Use "external_launcher" for sync mode, None for async mode
    executor_backend = None if async_grammar else "external_launcher"
    vllm_config = VllmConfig(
        model_config=ModelConfig(tokenizer=TOKENIZER),
        structured_outputs_config=StructuredOutputsConfig(backend="guidance"),
        parallel_config=ParallelConfig(distributed_executor_backend=executor_backend),
    )
    structured_output_manager = StructuredOutputManager(vllm_config)

    sampling_params = SamplingParams(
        structured_outputs=StructuredOutputsParams(
            json='{"type": "object"}',
        ),
    )
    sampling_params.structured_outputs._backend = "guidance"

    request = Request(
        "test_request",
        prompt_token_ids=prompt,
        sampling_params=sampling_params,
        pooling_params=None,
        eos_token_id=tokenizer.eos_token_id,
    )

    structured_output_manager.grammar_init(request)

    # Check the internal _grammar type immediately after init.
    # Before _check_grammar_completion is called, async mode should have a Future.
    raw_grammar = request.structured_output_request._grammar
    if async_grammar:
        assert isinstance(raw_grammar, Future), (
            "Async mode should store a Future before completion"
        )
    else:
        assert not isinstance(raw_grammar, Future), (
            "Sync mode should store the grammar directly, not a Future"
        )

    # Wait for grammar to be ready (handles both async and sync cases).
    # Use a monotonic clock for the timeout: time.time() is wall-clock and can
    # jump backward/forward (NTP, DST), which would corrupt the elapsed measure.
    deadline = time.monotonic() + 5  # 5-second timeout
    while not request.structured_output_request._check_grammar_completion():
        if time.monotonic() > deadline:
            pytest.fail("Grammar compilation timed out")
        time.sleep(0.01)

    # After completion, _grammar should no longer be a Future
    assert not isinstance(request.structured_output_request._grammar, Future)

    # Verify grammar is properly initialized and functional
    grammar = request.structured_output_request.grammar
    assert grammar is not None
    assert not grammar.is_terminated()

    # Verify the grammar can accept valid tokens
    assert grammar.accept_tokens(request.request_id, prompt)
|
||||
|
||||
@ -40,6 +40,16 @@ class StructuredOutputManager:
|
||||
self.reasoner: ReasoningParser | None = None
|
||||
self.vllm_config = vllm_config
|
||||
|
||||
# When in external_launcher mode, async grammar compilation causes deadlocks
|
||||
# due to external_launcher mode having a scheduler for each TP rank.
|
||||
# Async grammar compilation causes the WAITING_FOR_FSM → WAITING transition to
|
||||
# happen at different times on different TP ranks,
|
||||
# breaking the determinism assumption that external_launcher relies on.
|
||||
self._use_async_grammar_compilation = (
|
||||
vllm_config.parallel_config.distributed_executor_backend
|
||||
!= "external_launcher"
|
||||
)
|
||||
|
||||
self._grammar_bitmask: torch.Tensor | None = None
|
||||
self._full_mask = torch.tensor(-1, dtype=torch.int32)
|
||||
|
||||
@ -138,10 +148,13 @@ class StructuredOutputManager:
|
||||
else:
|
||||
raise ValueError(f"Unsupported structured output backend: {backend}")
|
||||
|
||||
grammar = self.executor.submit(self._async_create_grammar, request)
|
||||
if self._use_async_grammar_compilation:
|
||||
grammar = self.executor.submit(self._create_grammar, request)
|
||||
else:
|
||||
grammar = self._create_grammar(request) # type: ignore[assignment]
|
||||
request.structured_output_request.grammar = grammar # type: ignore[assignment]
|
||||
|
||||
def _async_create_grammar(
|
||||
def _create_grammar(
|
||||
self,
|
||||
request: Request,
|
||||
) -> StructuredOutputGrammar:
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user