[Structured Outputs][V1] Skip structured outputs for models that don't have a tokenizer (#20365)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Nick Hill <nhill@redhat.com>
commit 4a98edff1f
parent a7bab0c9e5
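For orientation, here is a minimal sketch (not part of this commit) of the behavior the change enforces: with skip_tokenizer_init=True, a request that asks for structured outputs is rejected with a ValueError before any scheduling happens, while plain sampling requests are not affected by this check. The model name comes from the tests in this diff; the token IDs are made-up placeholders, since no tokenizer is available to encode text prompts.

# Illustrative only; not from the diff.
from vllm import LLM, TokensPrompt
from vllm.sampling_params import GuidedDecodingParams, SamplingParams

llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)
prompt = TokensPrompt(prompt_token_ids=[0, 1, 2])  # placeholder token IDs

# A structured-output request needs a tokenizer, so it is rejected up front.
try:
    llm.generate(
        prompt,
        SamplingParams(max_tokens=16,
                       guided_decoding=GuidedDecodingParams(regex="[0-9]+")),
    )
except ValueError as err:
    print(f"rejected as expected: {err}")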
@@ -9,7 +9,7 @@ import torch
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -17,6 +17,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.structured_output.request import StructuredOutputRequest

 EOS_TOKEN_ID = 50256

@@ -33,6 +34,7 @@ def create_scheduler(
     block_size: int = 16,
     max_model_len: Optional[int] = None,
     num_speculative_tokens: Optional[int] = None,
+    skip_tokenizer_init: bool = False,
 ) -> Scheduler:
     '''Create scheduler under test.

@@ -65,6 +67,7 @@ def create_scheduler(
         trust_remote_code=True,
         dtype="float16",
         seed=42,
+        skip_tokenizer_init=skip_tokenizer_init,
     )
     # Cache config, optionally force APC
     kwargs_cache = ({} if enable_prefix_caching is None else {
@@ -186,7 +189,7 @@ def test_get_num_unfinished_requests():
 ])
 def test_schedule(enable_prefix_caching: Optional[bool],
                   prompt_logprobs: Optional[int]):
     '''Test scheduling.
     Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
     '''
     scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
@@ -1408,7 +1411,7 @@ def create_requests_with_priority(


 def test_priority_scheduling_basic_ordering():
     """Test that requests are scheduled in priority order
     (lower value = higher priority)."""
     scheduler = create_scheduler_with_priority()

@@ -1437,7 +1440,7 @@ def test_priority_scheduling_basic_ordering():


 def test_priority_scheduling_arrival_time_tiebreaker():
     """Test that arrival time is used
     as tiebreaker when priorities are equal."""
     scheduler = create_scheduler_with_priority()

@@ -1495,7 +1498,7 @@ def test_priority_scheduling_mixed_priority_and_arrival():


 def test_priority_scheduling_preemption():
     """Test that priority scheduling preempts
     lower priority requests when memory is constrained."""
     # Create scheduler with very limited memory to force preemption
     scheduler = create_scheduler_with_priority(
@@ -1576,7 +1579,7 @@ def test_priority_scheduling_preemption():


 def test_priority_scheduling_no_preemption_when_space_available():
     """Test that preemption doesn't happen
     when there's space for new requests."""
     scheduler = create_scheduler_with_priority(
         max_num_seqs=3,  # Allow 3 concurrent requests
@@ -1626,7 +1629,7 @@ def test_priority_scheduling_no_preemption_when_space_available():


 def test_priority_scheduling_preemption_victim_selection():
     """Test that the correct victim is selected for
     preemption based on priority and arrival time."""
     # This test verifies the priority-based victim selection logic
     # by checking the waiting queue order after adding requests with different
@@ -1743,7 +1746,7 @@ def test_priority_scheduling_waiting_queue_order():


 def test_priority_scheduling_fcfs_fallback():
     """Test that FCFS behavior is maintained when all
     requests have same priority."""
     scheduler = create_scheduler_with_priority()

@@ -1811,7 +1814,7 @@ def test_priority_scheduling_with_limited_slots():


 def test_priority_scheduling_heap_property():
     """Test that the waiting queue maintains heap
     property for priority scheduling."""
     scheduler = create_scheduler_with_priority(
         max_num_seqs=1,  # Only one request can run at a time
@@ -1857,3 +1860,39 @@ def test_priority_scheduling_heap_property():
     # Verify requests were scheduled in priority order (lowest value first)
     expected_priorities = sorted(priorities)
     assert scheduled_priorities == expected_priorities
+
+
+def test_schedule_skip_tokenizer_init():
+    scheduler = create_scheduler(skip_tokenizer_init=True)
+    requests = create_requests(num_requests=5)
+    for request in requests:
+        scheduler.add_request(request)
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == len(requests)
+    assert output.grammar_bitmask is None
+
+
+def test_schedule_skip_tokenizer_init_structured_output_request():
+    scheduler = create_scheduler(skip_tokenizer_init=True)
+    guided_params = GuidedDecodingParams(regex="[0-9]+")
+    sampling_params = SamplingParams(
+        ignore_eos=False,
+        max_tokens=16,
+        guided_decoding=guided_params,
+    )
+    request = Request(
+        request_id="0",
+        prompt_token_ids=[0, 1],
+        multi_modal_inputs=None,
+        multi_modal_hashes=None,
+        multi_modal_placeholders=None,
+        sampling_params=sampling_params,
+        pooling_params=None,
+        eos_token_id=EOS_TOKEN_ID,
+        structured_output_request=StructuredOutputRequest(sampling_params),
+    )
+    scheduler.add_request(request)
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 0
+    assert len(scheduler.running) == 0
+    assert len(scheduler.waiting) == 1

@@ -1,19 +1,30 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations

 import random
-from typing import Optional
+from typing import TYPE_CHECKING, Optional

 import pytest

-from vllm import LLM, SamplingParams
+from vllm import LLM
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector

+if TYPE_CHECKING:
+    from tests.conftest import VllmRunner
+
 MODEL = "facebook/opt-125m"
 DTYPE = "half"


-def _vllm_model(apc: bool, vllm_runner, monkeypatch):
+def _vllm_model(
+    apc: bool,
+    vllm_runner: type[VllmRunner],
+    monkeypatch: pytest.MonkeyPatch,
+    *,
+    skip_tokenizer_init: bool = False,
+):
     """Set up VllmRunner instance."""
     monkeypatch.setenv("VLLM_USE_V1", "1")
     return vllm_runner(
@@ -23,6 +34,7 @@ def _vllm_model(apc: bool, vllm_runner, monkeypatch):
         enforce_eager=True,
         enable_prefix_caching=apc,
         gpu_memory_utilization=0.5,
+        skip_tokenizer_init=skip_tokenizer_init,
     )

@@ -45,9 +57,27 @@ def vllm_model_apc(vllm_runner, monkeypatch):
         yield vllm_model


+@pytest.fixture(
+    # Function scope decouples tests & allows
+    # env var adjustment via monkeypatch
+    scope="function",
+    # Prefix caching
+    params=[False, True])
+def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
+    """VllmRunner test fixture with APC."""
+    with _vllm_model(
+            request.param,
+            vllm_runner,
+            monkeypatch,
+            skip_tokenizer_init=True,
+    ) as vllm_model:
+        yield vllm_model
+
+
 def _get_test_sampling_params(
     prompt_list: list[str],
     seed: Optional[int] = 42,
+    structured_outputs: bool = False,
 ) -> tuple[list[SamplingParams], list[int]]:
     """Generate random sampling params for a batch."""

@@ -62,14 +92,34 @@ def _get_test_sampling_params(
     n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
     # High temperature to maximize the chance of unique completions
     return [
-        SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed)
-        for n in n_list
+        SamplingParams(
+            temperature=0.95,
+            top_p=0.95,
+            n=n,
+            seed=seed,
+            guided_decoding=GuidedDecodingParams(
+                regex="[0-9]+") if structured_outputs else None,
+        ) for n in n_list
     ], n_list


+def test_compatibility_with_skip_tokenizer_init(
+    vllm_model_skip_tokenizer_init: VllmRunner,
+    example_prompts: list[str],
+):
+    # Case 1: Structured output request should raise an error.
+    sampling_params_list, _ = _get_test_sampling_params(
+        example_prompts,
+        structured_outputs=True,
+    )
+    model: LLM = vllm_model_skip_tokenizer_init.model
+    with pytest.raises(ValueError):
+        _ = model.generate(example_prompts, sampling_params_list)
+
+
 def test_parallel_sampling(vllm_model, example_prompts) -> None:
     """Test passes if parallel sampling `n>1` yields `n` unique completions.

     Args:
       vllm_model: VllmRunner instance under test.
       example_prompt: test fixture providing prompts for testing.

@@ -152,6 +152,11 @@ class Processor:
         if not params.guided_decoding or not self.decoding_config:
             return

+        if self.model_config.skip_tokenizer_init and params.guided_decoding:
+            raise ValueError(
+                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
+            )
+
         engine_level_backend = self.decoding_config.backend
         if params.guided_decoding.backend:
             # Request-level backend selection is not supported in V1.

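Taken on its own, the new guard reduces to the check sketched below (the function name is hypothetical; `model_config` and `params` stand in for the attributes the Processor consults). Requests without guided decoding never reach it because of the early return above.

# Sketch only: mirrors the added guard outside of the Processor class.
from vllm.config import ModelConfig
from vllm.sampling_params import SamplingParams


def check_structured_output_allowed(model_config: ModelConfig,
                                     params: SamplingParams) -> None:
    """Raise if structured outputs are requested but the tokenizer is skipped."""
    if model_config.skip_tokenizer_init and params.guided_decoding:
        raise ValueError(
            "Structured outputs requires a tokenizer so it can't be used "
            "with 'skip_tokenizer_init'")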
@@ -40,22 +40,25 @@ class StructuredOutputManager:
         self._grammar_bitmask: Optional[torch.Tensor] = None
         self._full_mask = torch.tensor(-1, dtype=torch.int32)

-        # The default max_workers if not specified is the number of CPUs * 5,
-        # which is way too high since these tasks are CPU-bound, not I/O bound.
-        # We also know we would never dominate CPU usage with just grammar
-        # compilation, so we set it to half the number of CPUs.
-        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
-        self.executor = ThreadPoolExecutor(max_workers=max_workers)
-        self.tokenizer = init_tokenizer_from_configs(
-            model_config=self.vllm_config.model_config,
-            scheduler_config=self.vllm_config.scheduler_config,
-            lora_config=self.vllm_config.lora_config,
-        ).get_lora_tokenizer(None)
-        reasoning_backend = vllm_config.decoding_config.reasoning_backend
-        if reasoning_backend:
-            reasoner_cls = ReasoningParserManager.get_reasoning_parser(
-                reasoning_backend)
-            self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
+        if not self.vllm_config.model_config.skip_tokenizer_init:
+            # The default max_workers if not specified is the number of
+            # CPUs * 5, which is way too high since these tasks are CPU-bound,
+            # not I/O bound. We also know we would never dominate CPU usage
+            # with just grammar compilation, so we set it to half the number
+            # of CPUs.
+            max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
+            self.executor = ThreadPoolExecutor(max_workers=max_workers)
+            self.tokenizer = init_tokenizer_from_configs(
+                model_config=self.vllm_config.model_config,
+                scheduler_config=self.vllm_config.scheduler_config,
+                lora_config=self.vllm_config.lora_config,
+            ).get_lora_tokenizer(None)
+            reasoning_backend = \
+                self.vllm_config.decoding_config.reasoning_backend
+            if reasoning_backend:
+                reasoner_cls = ReasoningParserManager.get_reasoning_parser(
+                    reasoning_backend)
+                self.reasoner = reasoner_cls(tokenizer=self.tokenizer)

     def grammar_init(self, request: Request) -> None:
         if request.structured_output_request is None:
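The constructor change above is a conditional-initialization pattern: the thread pool, tokenizer, and optional reasoner are only built when the tokenizer is actually going to be initialized. A stripped-down sketch of that shape follows; the class name and the load_tokenizer parameter are illustrative stand-ins, not the real vLLM types.

# Illustration of the guarded setup; not the real StructuredOutputManager.
import multiprocessing
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Optional


class ManagerSketch:

    def __init__(self, skip_tokenizer_init: bool,
                 load_tokenizer: Callable[[], object]) -> None:
        self.executor: Optional[ThreadPoolExecutor] = None
        self.tokenizer: Optional[object] = None
        if not skip_tokenizer_init:
            # Grammar compilation is CPU-bound, so cap the pool at roughly
            # half the CPUs (the same sizing rule as the code above).
            max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
            self.executor = ThreadPoolExecutor(max_workers=max_workers)
            self.tokenizer = load_tokenizer()

When the tokenizer is skipped these members stay unset, which lines up with the scheduler-level test earlier in this diff asserting that output.grammar_bitmask is None.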