[Structured Outputs][V1] Skip structured outputs for models that don't contain tokenizers (#20365)
Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: Nick Hill <nhill@redhat.com>
commit 4a98edff1f
parent a7bab0c9e5
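In short: structured outputs (guided decoding) need the tokenizer's vocabulary, so when a model is loaded with skip_tokenizer_init=True the V1 Processor now rejects guided-decoding requests up front, and the StructuredOutputManager skips its tokenizer, executor, and reasoner setup entirely. A minimal usage sketch of the new behavior (the model name and token-ID prompt are illustrative placeholders, not taken from this diff):

from vllm import LLM, SamplingParams
from vllm.sampling_params import GuidedDecodingParams

# Load the model without initializing a tokenizer; prompts must then be
# supplied as token IDs rather than text.
llm = LLM(model="facebook/opt-125m", skip_tokenizer_init=True)

# Guided decoding requires the tokenizer, so after this change the request
# is expected to fail fast with a ValueError from the Processor instead of
# breaking later inside the structured-output backend.
params = SamplingParams(
    max_tokens=16,
    guided_decoding=GuidedDecodingParams(regex="[0-9]+"),
)
try:
    llm.generate({"prompt_token_ids": [0, 1, 2]}, params)
except ValueError as err:
    print(f"rejected as expected: {err}")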
@@ -9,7 +9,7 @@ import torch
 from vllm.config import (CacheConfig, KVTransferConfig, ModelConfig,
                          SchedulerConfig, SpeculativeConfig, VllmConfig)
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
-from vllm.sampling_params import SamplingParams
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.core.sched.output import CachedRequestData, SchedulerOutput
 from vllm.v1.core.sched.scheduler import Scheduler
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
@@ -17,6 +17,7 @@ from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.request import Request, RequestStatus
 from vllm.v1.structured_output import StructuredOutputManager
+from vllm.v1.structured_output.request import StructuredOutputRequest
 
 EOS_TOKEN_ID = 50256
 
@@ -33,6 +34,7 @@ def create_scheduler(
     block_size: int = 16,
     max_model_len: Optional[int] = None,
     num_speculative_tokens: Optional[int] = None,
+    skip_tokenizer_init: bool = False,
 ) -> Scheduler:
     '''Create scheduler under test.
 
@@ -65,6 +67,7 @@ def create_scheduler(
         trust_remote_code=True,
         dtype="float16",
         seed=42,
+        skip_tokenizer_init=skip_tokenizer_init,
     )
     # Cache config, optionally force APC
     kwargs_cache = ({} if enable_prefix_caching is None else {
@@ -186,7 +189,7 @@ def test_get_num_unfinished_requests():
 ])
 def test_schedule(enable_prefix_caching: Optional[bool],
                   prompt_logprobs: Optional[int]):
     '''Test scheduling.
     Two cases: default APC/no prompt logprobs; APC=True + prompt logprobs
     '''
     scheduler = create_scheduler(enable_prefix_caching=enable_prefix_caching)
@@ -1408,7 +1411,7 @@ def create_requests_with_priority(
 
 
 def test_priority_scheduling_basic_ordering():
     """Test that requests are scheduled in priority order
     (lower value = higher priority)."""
     scheduler = create_scheduler_with_priority()
 
@@ -1437,7 +1440,7 @@ def test_priority_scheduling_basic_ordering():
 
 
 def test_priority_scheduling_arrival_time_tiebreaker():
     """Test that arrival time is used
     as tiebreaker when priorities are equal."""
     scheduler = create_scheduler_with_priority()
 
@@ -1495,7 +1498,7 @@ def test_priority_scheduling_mixed_priority_and_arrival():
 
 
 def test_priority_scheduling_preemption():
     """Test that priority scheduling preempts
     lower priority requests when memory is constrained."""
     # Create scheduler with very limited memory to force preemption
     scheduler = create_scheduler_with_priority(
@@ -1576,7 +1579,7 @@ def test_priority_scheduling_preemption():
 
 
 def test_priority_scheduling_no_preemption_when_space_available():
     """Test that preemption doesn't happen
     when there's space for new requests."""
     scheduler = create_scheduler_with_priority(
         max_num_seqs=3,  # Allow 3 concurrent requests
@@ -1626,7 +1629,7 @@ def test_priority_scheduling_no_preemption_when_space_available():
 
 
 def test_priority_scheduling_preemption_victim_selection():
     """Test that the correct victim is selected for
     preemption based on priority and arrival time."""
     # This test verifies the priority-based victim selection logic
     # by checking the waiting queue order after adding requests with different
@@ -1743,7 +1746,7 @@ def test_priority_scheduling_waiting_queue_order():
 
 
 def test_priority_scheduling_fcfs_fallback():
     """Test that FCFS behavior is maintained when all
     requests have same priority."""
     scheduler = create_scheduler_with_priority()
 
@@ -1811,7 +1814,7 @@ def test_priority_scheduling_with_limited_slots():
 
 
 def test_priority_scheduling_heap_property():
     """Test that the waiting queue maintains heap
     property for priority scheduling."""
     scheduler = create_scheduler_with_priority(
         max_num_seqs=1,  # Only one request can run at a time
@@ -1857,3 +1860,39 @@ def test_priority_scheduling_heap_property():
     # Verify requests were scheduled in priority order (lowest value first)
     expected_priorities = sorted(priorities)
     assert scheduled_priorities == expected_priorities
+
+
+def test_schedule_skip_tokenizer_init():
+    scheduler = create_scheduler(skip_tokenizer_init=True)
+    requests = create_requests(num_requests=5)
+    for request in requests:
+        scheduler.add_request(request)
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == len(requests)
+    assert output.grammar_bitmask is None
+
+
+def test_schedule_skip_tokenizer_init_structured_output_request():
+    scheduler = create_scheduler(skip_tokenizer_init=True)
+    guided_params = GuidedDecodingParams(regex="[0-9]+")
+    sampling_params = SamplingParams(
+        ignore_eos=False,
+        max_tokens=16,
+        guided_decoding=guided_params,
+    )
+    request = Request(
+        request_id="0",
+        prompt_token_ids=[0, 1],
+        multi_modal_inputs=None,
+        multi_modal_hashes=None,
+        multi_modal_placeholders=None,
+        sampling_params=sampling_params,
+        pooling_params=None,
+        eos_token_id=EOS_TOKEN_ID,
+        structured_output_request=StructuredOutputRequest(sampling_params),
+    )
+    scheduler.add_request(request)
+    output = scheduler.schedule()
+    assert len(output.scheduled_new_reqs) == 0
+    assert len(scheduler.running) == 0
+    assert len(scheduler.waiting) == 1
@@ -1,19 +1,30 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations
 
 import random
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 import pytest
 
-from vllm import LLM, SamplingParams
+from vllm import LLM
+from vllm.sampling_params import GuidedDecodingParams, SamplingParams
 from vllm.v1.metrics.reader import Counter, Gauge, Histogram, Metric, Vector
 
+if TYPE_CHECKING:
+    from tests.conftest import VllmRunner
+
 MODEL = "facebook/opt-125m"
 DTYPE = "half"
 
 
-def _vllm_model(apc: bool, vllm_runner, monkeypatch):
+def _vllm_model(
+    apc: bool,
+    vllm_runner: type[VllmRunner],
+    monkeypatch: pytest.MonkeyPatch,
+    *,
+    skip_tokenizer_init: bool = False,
+):
     """Set up VllmRunner instance."""
     monkeypatch.setenv("VLLM_USE_V1", "1")
     return vllm_runner(
@@ -23,6 +34,7 @@ def _vllm_model(apc: bool, vllm_runner, monkeypatch):
         enforce_eager=True,
         enable_prefix_caching=apc,
         gpu_memory_utilization=0.5,
+        skip_tokenizer_init=skip_tokenizer_init,
     )
 
 
@@ -45,9 +57,27 @@ def vllm_model_apc(vllm_runner, monkeypatch):
         yield vllm_model
 
 
+@pytest.fixture(
+    # Function scope decouples tests & allows
+    # env var adjustment via monkeypatch
+    scope="function",
+    # Prefix caching
+    params=[False, True])
+def vllm_model_skip_tokenizer_init(vllm_runner, request, monkeypatch):
+    """VllmRunner test fixture with APC."""
+    with _vllm_model(
+            request.param,
+            vllm_runner,
+            monkeypatch,
+            skip_tokenizer_init=True,
+    ) as vllm_model:
+        yield vllm_model
+
+
 def _get_test_sampling_params(
     prompt_list: list[str],
     seed: Optional[int] = 42,
+    structured_outputs: bool = False,
 ) -> tuple[list[SamplingParams], list[int]]:
     """Generate random sampling params for a batch."""
 
@@ -62,14 +92,34 @@ def _get_test_sampling_params(
     n_list = [get_mostly_n_gt1() for _ in range(len(prompt_list))]
     # High temperature to maximize the chance of unique completions
     return [
-        SamplingParams(temperature=0.95, top_p=0.95, n=n, seed=seed)
-        for n in n_list
+        SamplingParams(
+            temperature=0.95,
+            top_p=0.95,
+            n=n,
+            seed=seed,
+            guided_decoding=GuidedDecodingParams(
+                regex="[0-9]+") if structured_outputs else None,
+        ) for n in n_list
     ], n_list
 
 
+def test_compatibility_with_skip_tokenizer_init(
+    vllm_model_skip_tokenizer_init: VllmRunner,
+    example_prompts: list[str],
+):
+    # Case 1: Structured output request should raise an error.
+    sampling_params_list, _ = _get_test_sampling_params(
+        example_prompts,
+        structured_outputs=True,
+    )
+    model: LLM = vllm_model_skip_tokenizer_init.model
+    with pytest.raises(ValueError):
+        _ = model.generate(example_prompts, sampling_params_list)
+
+
 def test_parallel_sampling(vllm_model, example_prompts) -> None:
     """Test passes if parallel sampling `n>1` yields `n` unique completions.
 
     Args:
         vllm_model: VllmRunner instance under test.
         example_prompt: test fixture providing prompts for testing.
@@ -152,6 +152,11 @@ class Processor:
         if not params.guided_decoding or not self.decoding_config:
             return
 
+        if self.model_config.skip_tokenizer_init and params.guided_decoding:
+            raise ValueError(
+                "Structured outputs requires a tokenizer so it can't be used with 'skip_tokenizer_init'"  # noqa: E501
+            )
+
         engine_level_backend = self.decoding_config.backend
         if params.guided_decoding.backend:
             # Request-level backend selection is not supported in V1.
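The Processor-level check above boils down to a simple guard; a simplified, standalone sketch of the same logic (not vLLM's actual code, names are illustrative):

def validate_structured_output_support(skip_tokenizer_init: bool,
                                       wants_guided_decoding: bool) -> None:
    # Grammar compilation needs the tokenizer's vocabulary, so requesting
    # structured outputs on a tokenizer-less model is a user error.
    if skip_tokenizer_init and wants_guided_decoding:
        raise ValueError(
            "Structured outputs requires a tokenizer so it can't be used "
            "with 'skip_tokenizer_init'")

Failing at request-validation time keeps the error actionable for the caller rather than surfacing later as a backend failure during grammar compilation.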
@@ -40,22 +40,25 @@ class StructuredOutputManager:
         self._grammar_bitmask: Optional[torch.Tensor] = None
         self._full_mask = torch.tensor(-1, dtype=torch.int32)
 
-        # The default max_workers if not specified is the number of CPUs * 5,
-        # which is way too high since these tasks are CPU-bound, not I/O bound.
-        # We also know we would never dominate CPU usage with just grammar
-        # compilation, so we set it to half the number of CPUs.
-        max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
-        self.executor = ThreadPoolExecutor(max_workers=max_workers)
-        self.tokenizer = init_tokenizer_from_configs(
-            model_config=self.vllm_config.model_config,
-            scheduler_config=self.vllm_config.scheduler_config,
-            lora_config=self.vllm_config.lora_config,
-        ).get_lora_tokenizer(None)
-        reasoning_backend = vllm_config.decoding_config.reasoning_backend
-        if reasoning_backend:
-            reasoner_cls = ReasoningParserManager.get_reasoning_parser(
-                reasoning_backend)
-            self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
+        if not self.vllm_config.model_config.skip_tokenizer_init:
+            # The default max_workers if not specified is the number of
+            # CPUs * 5, which is way too high since these tasks are CPU-bound,
+            # not I/O bound. We also know we would never dominate CPU usage
+            # with just grammar compilation, so we set it to half the number
+            # of CPUs.
+            max_workers = max(1, (multiprocessing.cpu_count() + 1) // 2)
+            self.executor = ThreadPoolExecutor(max_workers=max_workers)
+            self.tokenizer = init_tokenizer_from_configs(
+                model_config=self.vllm_config.model_config,
+                scheduler_config=self.vllm_config.scheduler_config,
+                lora_config=self.vllm_config.lora_config,
+            ).get_lora_tokenizer(None)
+            reasoning_backend = \
+                self.vllm_config.decoding_config.reasoning_backend
+            if reasoning_backend:
+                reasoner_cls = ReasoningParserManager.get_reasoning_parser(
+                    reasoning_backend)
+                self.reasoner = reasoner_cls(tokenizer=self.tokenizer)
 
     def grammar_init(self, request: Request) -> None:
         if request.structured_output_request is None: