[Structured Outputs][V1] Skipping with models doesn't contain tokenizers (#20365)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Aaron Pham 2025-07-04 03:05:49 -04:00 committed by GitHub
parent a7bab0c9e5
commit 4a98edff1f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 128 additions and 31 deletions

View File

@ -9,7 +9,7 @@ import torch
from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
SchedulerConfig, SpeculativeConfig, VllmConfig)
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@ -17,6 +17,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.structured_output import StructuredOutputManager
from vllm.v1.structured_output.request import StructuredOutputRequest
EOS_TOKEN_ID = 50256
@ -33,6 +34,7 @@ def create_scheduler(
block_size: int = 16,
max_model_len: Optional[int] = None,
num_speculative_tokens: Optional[int] = None,
skip_tokenizer_init: bool = False,
) -> Scheduler:
'''Create scheduler under test.
@ -65,6 +67,7 @@ def create_scheduler(
trust_remote_code=True,
dtype="float16",
seed=42,
skip_tokenizer_init=skip_tokenizer_init,
)
# Cache config, optionally force APC
kwargs_cache = ({} if enable_prefix_caching is None else {
@ -186,7 +189,7 @@ def test_get_num_unfinished_requests():
])
def test_schedule(enable_prefix_caching: Optional[bool],
prompt_logprobs: Optional[int]):
'''Test scheduling.
'''Test scheduling.
Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
'''
scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
@ -1408,7 +1411,7 @@ def create_requests_with_priority(
def test_priority_scheduling_basic_ordering():
"""Test that requests are scheduled in priority order
"""Test that requests are scheduled in priority order
(lower value = higher priority)."""
scheduler = create_scheduler_with_priority()
@ -1437,7 +1440,7 @@ def test_priority_scheduling_basic_ordering():
def test_priority_scheduling_arrival_time_tiebreaker():
"""Test that arrival time is used
"""Test that arrival time is used
as tiebreaker when priorities are equal."""
scheduler = create_scheduler_with_priority()
@ -1495,7 +1498,7 @@ def test_priority_scheduling_mixed_priority_and_arrival():
def test_priority_scheduling_preemption():
"""Test that priority scheduling preempts
"""Test that priority scheduling preempts
lower priority requests when memory is constrained."""
# Create scheduler with very limited memory to force preemption
scheduler = create_scheduler_with_priority(
@ -1576,7 +1579,7 @@ def test_priority_scheduling_preemption():
def test_priority_scheduling_no_preemption_when_space_available():
"""Test that preemption doesn't happen
"""Test that preemption doesn't happen
when there's space for new requests."""
scheduler = create_scheduler_with_priority(
max_num_seqs=3, # Allow 3 concurrent requests
@ -1626,7 +1629,7 @@ def test_priority_scheduling_no_preemption_when_space_available():
def test_priority_scheduling_preemption_victim_selection():
"""Test that the correct victim is selected for
"""Test that the correct victim is selected for
preemption based on priority and arrival time."""
# This test verifies the priority-based victim selection logic
# by checking the waiting queue order after adding requests with different
@ -1743,7 +1746,7 @@ def test_priority_scheduling_waiting_queue_order():
def test_priority_scheduling_fcfs_fallback():
"""Test that FCFS behavior is maintained when all
"""Test that FCFS behavior is maintained when all
requests have same priority."""
scheduler = create_scheduler_with_priority()
@ -1811,7 +1814,7 @@ def test_priority_scheduling_with_limited_slots():
def test_priority_scheduling_heap_property():
"""Test that the waiting queue maintains heap
"""Test that the waiting queue maintains heap
property for priority scheduling."""
scheduler = create_scheduler_with_priority(
max_num_seqs=1, # Only one request can run at a time
@ -1857,3 +1860,39 @@ def test_priority_scheduling_heap_property():
# Verify requests were scheduled in priority order (lowest value first)
expected_priorities = sorted(priorities)
assert scheduled_priorities == expected_priorities
def test_schedule_skip_tokenizer_init():
scheduler = create_scheduler(skip_tokenizer_init=True)
requests = create_requests(num_requests=5)
for request in requests:
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == len(requests)
assert output.grammar_bitmask is None
def test_schedule_skip_tokenizer_init_structured_output_request():
scheduler = create_scheduler(skip_tokenizer_init=True)
guided_params = GuidedDecodingParams(regex="[0-9]+")
sampling_params = SamplingParams(
ignore_eos=False,
max_tokens=16,
guided_decoding=guided_params,
)
request = Request(
request_id="0",
prompt_token_ids=[0, 1],
multi_modal_inputs=None,
multi_modal_hashes=None,
multi_modal_placeholders=None,
sampling_params=sampling_params,
pooling_params=None,
eos_token_id=EOS_TOKEN_ID,
structured_output_request=StructuredOutputRequest(sampling_params),
)
scheduler.add_request(request)
output = scheduler.schedule()
assert len(output.scheduled_new_reqs) == 0
assert len(scheduler.running) == 0
assert len(scheduler.waiting) == 1

View File

@ -1,19 +1,30 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import random
from typing import Optional
from typing import TYPE_CHECKING, Optional
import pytest
from vllm import LLM, SamplingParams
from vllm import LLM
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector
if TYPE_CHECKING:
from tests.conftest import VllmRunner
MODEL = "facebook/opt-125m"
DTYPE = "half"
def _vllm_model(apc: bool, vllm_runner, monkeypatch):
def _vllm_model(
apc: bool,
vllm_runner: type[VllmRunner],
monkeypatch: pytest.MonkeyPatch,
*,
skip_tokenizer_init: bool = False,
):
"""Set up VllmRunner instance."""
monkeypatch.setenv("VLLM_USE_V1", "1")
return vllm_runner(
@ -23,6 +34,7 @@ def _vllm_model(apc: bool, vllm_runner, monkeypatch):
enforce_eager=True,
enable_prefix_caching=apc,
gpu_memory_utilization=0.5,
skip_tokenizer_init=skip_tokenizer_init,
)
@ -45,9 +57,27 @@ def vllm_model_apc(vllm_runner, monkeypatch):
yield vllm_model
@pytest.fixture(
# Function scope decouples tests & allows
# env var adjustment via monkeypatch
scope="function",
# Prefix caching
params=[False, True])
def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
"""VllmRunner test fixture with APC."""
with _vllm_model(
request.param,
vllm_runner,
monkeypatch,
skip_tokenizer_init=True,
) as vllm_model:
yield vllm_model
def _get_test_sampling_params(
prompt_list: list[str],
seed: Optional[int] = 42,
structured_outputs: bool = False,
) -> tuple[list[SamplingParams], list[int]]:
"""Generate random sampling params for a batch."""
@ -62,14 +92,34 @@ def _get_test_sampling_params(
n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
# High temperature to maximize the chance of unique completions
return [
SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed)
for n in n_list
SamplingParams(
temperature=0.95,
top_p=0.95,
n=n,
seed=seed,
guided_decoding=GuidedDecodingParams(
regex="[0-9]+") if structured_outputs else None,
) for n in n_list
], n_list
def test_compatibility_with_skip_tokenizer_init(
vllm_model_skip_tokenizer_init: VllmRunner,
example_prompts: list[str],
):
# Case 1: Structured output request should raise an error.
sampling_params_list, _ = _get_test_sampling_params(
example_prompts,
structured_outputs=True,
)
model: LLM = vllm_model_skip_tokenizer_init.model
with pytest.raises(ValueError):
_ = model.generate(example_prompts, sampling_params_list)
def test_parallel_sampling(vllm_model, example_prompts) -> None:
"""Test passes if parallel sampling `n>1` yields `n` unique completions.
Args:
vllm_model: VllmRunner instance under test.
example_prompt: test fixture providing prompts for testing.

View File

@ -152,6 +152,11 @@ class Processor:
if not params.guided_decoding or not self.decoding_config:
return
if self.model_config.skip_tokenizer_init and params.guided_decoding:
raise ValueError(
"Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'" # noqa: E501
)
engine_level_backend = self.decoding_config.backend
if params.guided_decoding.backend:
# Request-level backend selection is not supported in V1.

View File

@ -40,22 +40,25 @@ class StructuredOutputManager:
self._grammar_bitmask: Optional[torch.Tensor] = None
self._full_mask = torch.tensor(-1, dtype=torch.int32)
# The default max_workers if not specified is the number of CPUs * 5,
# which is way too high since these tasks are CPU-bound, not I/O bound.
# We also know we would never dominate CPU usage with just grammar
# compilation, so we set it to half the number of CPUs.
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.tokenizer = init_tokenizer_from_configs(
model_config=self.vllm_config.model_config,
scheduler_config=self.vllm_config.scheduler_config,
lora_config=self.vllm_config.lora_config,
).get_lora_tokenizer(None)
reasoning_backend = vllm_config.decoding_config.reasoning_backend
if reasoning_backend:
reasoner_cls = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
if not self.vllm_config.model_config.skip_tokenizer_init:
# The default max_workers if not specified is the number of
# CPUs * 5, which is way too high since these tasks are CPU-bound,
# not I/O bound. We also know we would never dominate CPU usage
# with just grammar compilation, so we set it to half the number
# of CPUs.
max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
self.executor = ThreadPoolExecutor(max_workers=max_workers)
self.tokenizer = init_tokenizer_from_configs(
model_config=self.vllm_config.model_config,
scheduler_config=self.vllm_config.scheduler_config,
lora_config=self.vllm_config.lora_config,
).get_lora_tokenizer(None)
reasoning_backend = \
self.vllm_config.decoding_config.reasoning_backend
if reasoning_backend:
reasoner_cls = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
def grammar_init(self, request: Request) -> None:
if request.structured_output_request is None: