[Structured Outputs][V1] Skip when the model doesn't contain a tokenizer (#20365)

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Nick Hill <nhill@redhat.com>
Aaron Pham 2025-07-04 03:05:49 -04:00 committed by GitHub
parent a7bab0c9e5
commit 4a98edff1f
4 changed files with 128 additions and 31 deletions
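
In short: once a model is loaded with skip_tokenizer_init=True, any request that asks for structured (guided) outputs is rejected up front, because grammar compilation needs a tokenizer; everything else keeps working. A minimal sketch of the resulting user-facing behaviour, mirroring the end-to-end test added below (illustrative only, not part of the diff; it assumes the offline LLM API and the facebook/opt-125m model used in the tests, with the prompt passed as token IDs since no tokenizer is available):

# Sketch only: illustrates the behaviour this commit enforces.
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt
from vllm.sampling_params import GuidedDecodingParams

# No tokenizer is loaded, so grammar-based structured outputs cannot be compiled.
llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)

params = SamplingParams(
    max_tokens=16,
    guided_decoding=GuidedDecodingParams(regex="[0-9]+"),
)

try:
    # The prompt is given as token IDs because there is no tokenizer to encode text.
    llm.generate(TokensPrompt(prompt_token_ids=[0, 1, 2]), params)
except ValueError as err:
    # Expected: "Structured outputs requires a tokenizer so it can't be used
    # with 'skip_tokenizer_init'" (raised by the Processor change below).
    print(err)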

View File

@@ -9,7 +9,7 @@ import torch
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -17,6 +17,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.structured_output.request import StructuredOutputRequest

 EOS_TOKEN_ID = 50256
@@ -33,6 +34,7 @@ def create_scheduler(
     block_size: int = 16,
     max_model_len: Optional[int] = None,
     num_speculative_tokens: Optional[int] = None,
+    skip_tokenizer_init: bool = False,
 ) -> Scheduler:
     '''Create scheduler under test.
@@ -65,6 +67,7 @@ def create_scheduler(
         trust_remote_code=True,
         dtype="float16",
         seed=42,
+        skip_tokenizer_init=skip_tokenizer_init,
     )
     # Cache config, optionally force APC
     kwargs_cache = ({} if enable_prefix_caching is None else {
@@ -186,7 +189,7 @@ def test_get_num_unfinished_requests():
 ])
 def test_schedule(enable_prefix_caching: Optional[bool],
                   prompt_logprobs: Optional[int]):
     '''Test scheduling.
     Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
     '''
     scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
@@ -1408,7 +1411,7 @@ def create_requests_with_priority(
 def test_priority_scheduling_basic_ordering():
     """Test that requests are scheduled in priority order
     (lower value = higher priority)."""
     scheduler = create_scheduler_with_priority()
@@ -1437,7 +1440,7 @@ def test_priority_scheduling_basic_ordering():
 def test_priority_scheduling_arrival_time_tiebreaker():
     """Test that arrival time is used
     as tiebreaker when priorities are equal."""
     scheduler = create_scheduler_with_priority()
@@ -1495,7 +1498,7 @@ def test_priority_scheduling_mixed_priority_and_arrival():
 def test_priority_scheduling_preemption():
     """Test that priority scheduling preempts
     lower priority requests when memory is constrained."""
     # Create scheduler with very limited memory to force preemption
     scheduler = create_scheduler_with_priority(
@@ -1576,7 +1579,7 @@ def test_priority_scheduling_preemption():
 def test_priority_scheduling_no_preemption_when_space_available():
     """Test that preemption doesn't happen
     when there's space for new requests."""
     scheduler = create_scheduler_with_priority(
         max_num_seqs=3,  # Allow 3 concurrent requests
@@ -1626,7 +1629,7 @@ def test_priority_scheduling_no_preemption_when_space_available():
 def test_priority_scheduling_preemption_victim_selection():
     """Test that the correct victim is selected for
     preemption based on priority and arrival time."""
     # This test verifies the priority-based victim selection logic
     # by checking the waiting queue order after adding requests with different
@@ -1743,7 +1746,7 @@ def test_priority_scheduling_waiting_queue_order():
 def test_priority_scheduling_fcfs_fallback():
     """Test that FCFS behavior is maintained when all
     requests have same priority."""
     scheduler = create_scheduler_with_priority()
@@ -1811,7 +1814,7 @@ def test_priority_scheduling_with_limited_slots():
 def test_priority_scheduling_heap_property():
     """Test that the waiting queue maintains heap
     property for priority scheduling."""
     scheduler = create_scheduler_with_priority(
         max_num_seqs=1,  # Only one request can run at a time
@@ -1857,3 +1860,39 @@ def test_priority_scheduling_heap_property():
     # Verify requests were scheduled in priority order (lowest value first)
     expected_priorities = sorted(priorities)
     assert scheduled_priorities == expected_priorities
+
+
+def test_schedule_skip_tokenizer_init():
+    scheduler = create_scheduler(skip_tokenizer_init=True)
+    requests = create_requests(num_requests=5)
+    for request in requests:
+        scheduler.add_request(request)
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == len(requests)
+    assert output.grammar_bitmask is None
+
+
+def test_schedule_skip_tokenizer_init_structured_output_request():
+    scheduler = create_scheduler(skip_tokenizer_init=True)
+    guided_params = GuidedDecodingParams(regex="[0-9]+")
+    sampling_params = SamplingParams(
+        ignore_eos=False,
+        max_tokens=16,
+        guided_decoding=guided_params,
+    )
+    request = Request(
+        request_id="0",
+        prompt_token_ids=[0, 1],
+        multi_modal_inputs=None,
+        multi_modal_hashes=None,
+        multi_modal_placeholders=None,
+        sampling_params=sampling_params,
+        pooling_params=None,
+        eos_token_id=EOS_TOKEN_ID,
+        structured_output_request=StructuredOutputRequest(sampling_params),
+    )
+    scheduler.add_request(request)
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 0
+    assert len(scheduler.running) == 0
+    assert len(scheduler.waiting) == 1
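
A note on the first assertion above: grammar_bitmask being None is what tells the model runner to skip structured output masking altogether. A rough sketch of that downstream contract (apply_bitmask and the mask semantics here are hypothetical stand-ins, not code from this diff):

# Hypothetical stand-ins illustrating how a runner would consume grammar_bitmask.
import torch


def apply_bitmask(logits: torch.Tensor, bitmask: torch.Tensor) -> torch.Tensor:
    # Simplified masking: disallow tokens whose mask entry is 0.
    return logits.masked_fill(bitmask == 0, float("-inf"))


def maybe_apply_structured_output_mask(logits: torch.Tensor, grammar_bitmask):
    if grammar_bitmask is None:
        # skip_tokenizer_init deployments (and batches without structured
        # output requests) take this branch: sampling stays unconstrained.
        return logits
    return apply_bitmask(logits, grammar_bitmask)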

View File

@@ -1,19 +1,30 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations

 import random
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 import pytest

-from vllm import LLM, SamplingParams
+from vllm import LLM
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector

+if TYPE_CHECKING:
+    from tests.conftest import VllmRunner
+
 MODEL = "facebook/opt-125m"
 DTYPE = "half"


-def _vllm_model(apc: bool, vllm_runner, monkeypatch):
+def _vllm_model(
+    apc: bool,
+    vllm_runner: type[VllmRunner],
+    monkeypatch: pytest.MonkeyPatch,
+    *,
+    skip_tokenizer_init: bool = False,
+):
     """Set up VllmRunner instance."""
     monkeypatch.setenv("VLLM_USE_V1", "1")
     return vllm_runner(
@@ -23,6 +34,7 @@ def _vllm_model(apc: bool, vllm_runner, monkeypatch):
         enforce_eager=True,
         enable_prefix_caching=apc,
         gpu_memory_utilization=0.5,
+        skip_tokenizer_init=skip_tokenizer_init,
     )
@@ -45,9 +57,27 @@ def vllm_model_apc(vllm_runner, monkeypatch):
         yield vllm_model


+@pytest.fixture(
+    # Function scope decouples tests & allows
+    # env var adjustment via monkeypatch
+    scope="function",
+    # Prefix caching
+    params=[False, True])
+def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
+    """VllmRunner test fixture with APC."""
+    with _vllm_model(
+            request.param,
+            vllm_runner,
+            monkeypatch,
+            skip_tokenizer_init=True,
+    ) as vllm_model:
+        yield vllm_model
+
+
 def _get_test_sampling_params(
     prompt_list: list[str],
     seed: Optional[int] = 42,
+    structured_outputs: bool = False,
 ) -> tuple[list[SamplingParams], list[int]]:
     """Generate random sampling params for a batch."""
@@ -62,14 +92,34 @@ def _get_test_sampling_params(
     n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
     # High temperature to maximize the chance of unique completions
     return [
-        SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed)
-        for n in n_list
+        SamplingParams(
+            temperature=0.95,
+            top_p=0.95,
+            n=n,
+            seed=seed,
+            guided_decoding=GuidedDecodingParams(
+                regex="[0-9]+") if structured_outputs else None,
+        ) for n in n_list
     ], n_list


+def test_compatibility_with_skip_tokenizer_init(
+    vllm_model_skip_tokenizer_init: VllmRunner,
+    example_prompts: list[str],
+):
+    # Case 1: Structured output request should raise an error.
+    sampling_params_list, _ = _get_test_sampling_params(
+        example_prompts,
+        structured_outputs=True,
+    )
+    model: LLM = vllm_model_skip_tokenizer_init.model
+    with pytest.raises(ValueError):
+        _ = model.generate(example_prompts, sampling_params_list)
+
+
 def test_parallel_sampling(vllm_model, example_prompts) -> None:
     """Test passes if parallel sampling `n>1` yields `n` unique completions.

     Args:
         vllm_model: VllmRunner instance under test.
         example_prompt: test fixture providing prompts for testing.
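
Case 1 above exercises the rejection path. As a counterpart (a sketch, not part of the diff): plain sampling without guided_decoding is still expected to work on a tokenizer-less engine, as the new scheduler test in the first file shows, but prompts then have to be supplied as token IDs and outputs carry token IDs only, since detokenization also needs a tokenizer:

# Sketch of the still-supported path with skip_tokenizer_init=True.
from vllm import LLM, SamplingParams
from vllm.inputs import TokensPrompt

llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)

# No guided_decoding, and detokenization explicitly disabled: nothing here
# needs a tokenizer.
params = SamplingParams(temperature=0.95, top_p=0.95, max_tokens=16,
                        detokenize=False)

outputs = llm.generate(TokensPrompt(prompt_token_ids=[0, 1, 2]), params)
print(outputs[0].outputs[0].token_ids)  # generated token IDs; no text is produced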

View File

@@ -152,6 +152,11 @@ class Processor:
         if not params.guided_decoding or not self.decoding_config:
             return

+        if self.model_config.skip_tokenizer_init and params.guided_decoding:
+            raise ValueError(
+                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
+            )
+
         engine_level_backend = self.decoding_config.backend
         if params.guided_decoding.backend:
             # Request-level backend selection is not supported in V1.
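
A plain-Python reading of the control flow above (illustrative only): because of the early return, the new ValueError can only fire for requests that actually ask for guided decoding while a decoding config is present, so ordinary requests on a tokenizer-less engine are unaffected:

def would_reject(skip_tokenizer_init: bool, has_guided_decoding: bool,
                 has_decoding_config: bool) -> bool:
    # Mirrors the validation order above, with config objects reduced to flags.
    if not has_guided_decoding or not has_decoding_config:
        return False  # early return: nothing structured to validate
    return skip_tokenizer_init  # grammars need a tokenizer


assert would_reject(True, True, True)        # guided request, no tokenizer -> error
assert not would_reject(True, False, True)   # plain request is unaffected
assert not would_reject(False, True, True)   # tokenizer present -> proceed normally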

View File

@@ -40,22 +40,25 @@ class StructuredOutputManager:
         self._grammar_bitmask: Optional[torch.Tensor] = None
         self._full_mask = torch.tensor(-1, dtype=torch.int32)

-        # The default max_workers if not specified is the number of CPUs * 5,
-        # which is way too high since these tasks are CPU-bound, not I/O bound.
-        # We also know we would never dominate CPU usage with just grammar
-        # compilation, so we set it to half the number of CPUs.
-        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
-        self.executor = ThreadPoolExecutor(max_workers=max_workers)
-        self.tokenizer = init_tokenizer_from_configs(
-            model_config=self.vllm_config.model_config,
-            scheduler_config=self.vllm_config.scheduler_config,
-            lora_config=self.vllm_config.lora_config,
-        ).get_lora_tokenizer(None)
-        reasoning_backend = vllm_config.decoding_config.reasoning_backend
-        if reasoning_backend:
-            reasoner_cls = ReasoningParserManager.get_reasoning_parser(
-                reasoning_backend)
-            self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
+        if not self.vllm_config.model_config.skip_tokenizer_init:
+            # The default max_workers if not specified is the number of
+            # CPUs * 5, which is way too high since these tasks are CPU-bound,
+            # not I/O bound. We also know we would never dominate CPU usage
+            # with just grammar compilation, so we set it to half the number
+            # of CPUs.
+            max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
+            self.executor = ThreadPoolExecutor(max_workers=max_workers)
+            self.tokenizer = init_tokenizer_from_configs(
+                model_config=self.vllm_config.model_config,
+                scheduler_config=self.vllm_config.scheduler_config,
+                lora_config=self.vllm_config.lora_config,
+            ).get_lora_tokenizer(None)
+            reasoning_backend = \
+                self.vllm_config.decoding_config.reasoning_backend
+            if reasoning_backend:
+                reasoner_cls = ReasoningParserManager.get_reasoning_parser(
+                    reasoning_backend)
+                self.reasoner = reasoner_cls(tokenizer=self.tokenizer)

     def grammar_init(self, request: Request) -> None:
         if request.structured_output_request is None:
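
This guard is the "skip" in the commit title: with no tokenizer, the manager neither spins up its grammar-compilation thread pool nor loads a tokenizer or reasoning parser. A trimmed sketch of the same pattern (the class body is reduced to the attributes visible in this hunk; the explicit None defaults are added here for clarity, whereas the diff simply skips the assignments):

import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Optional


class StructuredOutputManagerSketch:
    # Simplified stand-in for vllm.v1.structured_output.StructuredOutputManager.

    def __init__(self, skip_tokenizer_init: bool,
                 tokenizer_factory: Callable[[], object]):
        self.executor: Optional[ThreadPoolExecutor] = None
        self.tokenizer = None
        self.reasoner = None
        if not skip_tokenizer_init:
            # Grammar compilation is CPU-bound, so cap the pool at half the
            # CPUs instead of the ThreadPoolExecutor default of CPUs * 5.
            max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
            self.executor = ThreadPoolExecutor(max_workers=max_workers)
            self.tokenizer = tokenizer_factory()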