[Feature] Pydantic validation for scheduler.py and structured_outputs.py (#26519)

Signed-off-by: Vinay Damodaran <vrdn@hey.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Authored by Vinay R Damodaran on 2025-10-31 11:05:50 -07:00; committed by GitHub.
parent 9e5bd3076e
commit 5e8862e9e0
4 changed files with 39 additions and 35 deletions


@@ -2,10 +2,11 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import hashlib
-from dataclasses import InitVar, field
+from collections.abc import Callable
+from dataclasses import InitVar
 from typing import Any, Literal
-from pydantic import SkipValidation, model_validator
+from pydantic import Field, field_validator, model_validator
 from pydantic.dataclasses import dataclass
 from typing_extensions import Self
@@ -31,28 +32,28 @@ class SchedulerConfig:
     runner_type: RunnerType = "generate"
     """The runner type to launch for the model."""
-    max_num_batched_tokens: SkipValidation[int] = None  # type: ignore
+    max_num_batched_tokens: int = Field(default=None, ge=1)
     """Maximum number of tokens to be processed in a single iteration.
     This config has no static default. If left unspecified by the user, it will
     be set in `EngineArgs.create_engine_config` based on the usage context."""
-    max_num_seqs: SkipValidation[int] = None  # type: ignore
+    max_num_seqs: int = Field(default=None, ge=1)
     """Maximum number of sequences to be processed in a single iteration.
     This config has no static default. If left unspecified by the user, it will
     be set in `EngineArgs.create_engine_config` based on the usage context."""
-    max_model_len: SkipValidation[int] = None  # type: ignore
+    max_model_len: int = Field(default=None, ge=1)
     """Maximum length of a sequence (including prompt and generated text). This
     is primarily set in `ModelConfig` and that value should be manually
     duplicated here."""
-    max_num_partial_prefills: int = 1
+    max_num_partial_prefills: int = Field(default=1, ge=1)
     """For chunked prefill, the maximum number of sequences that can be
     partially prefilled concurrently."""
-    max_long_partial_prefills: int = 1
+    max_long_partial_prefills: int = Field(default=1, ge=1)
     """For chunked prefill, the maximum number of prompts longer than
     long_prefill_token_threshold that will be prefilled concurrently. Setting
     this less than max_num_partial_prefills will allow shorter prompts to jump
@@ -62,7 +63,7 @@ class SchedulerConfig:
     """For chunked prefill, a request is considered long if the prompt is
     longer than this number of tokens."""
-    num_lookahead_slots: int = 0
+    num_lookahead_slots: int = Field(default=0, ge=0)
     """The number of slots to allocate per sequence per
     step, beyond the known token ids. This is used in speculative
     decoding to store KV activations of tokens which may or may not be
@@ -71,7 +72,7 @@ class SchedulerConfig:
     NOTE: This will be replaced by speculative config in the future; it is
     present to enable correctness tests until then."""
-    enable_chunked_prefill: SkipValidation[bool] = None  # type: ignore
+    enable_chunked_prefill: bool = Field(default=None)
     """If True, prefill requests can be chunked based
     on the remaining max_num_batched_tokens."""
@@ -86,14 +87,14 @@ class SchedulerConfig:
     """
     # TODO (ywang96): Make this configurable.
-    max_num_encoder_input_tokens: int = field(init=False)
+    max_num_encoder_input_tokens: int = Field(init=False)
     """Multimodal encoder compute budget, only used in V1.
     NOTE: This is not currently configurable. It will be overridden by
     max_num_batched_tokens in case max multimodal embedding size is larger."""
     # TODO (ywang96): Make this configurable.
-    encoder_cache_size: int = field(init=False)
+    encoder_cache_size: int = Field(init=False)
     """Multimodal encoder cache size, only used in V1.
     NOTE: This is not currently configurable. It will be overridden by
@@ -106,7 +107,7 @@ class SchedulerConfig:
     - "priority" means requests are handled based on given priority (lower
     value means earlier handling) and time of arrival deciding any ties)."""
-    chunked_prefill_enabled: bool = field(init=False)
+    chunked_prefill_enabled: bool = Field(init=False)
     """True if chunked prefill is enabled."""
     disable_chunked_mm_input: bool = False
@@ -155,6 +156,20 @@ class SchedulerConfig:
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
+    @field_validator(
+        "max_num_batched_tokens",
+        "max_num_seqs",
+        "max_model_len",
+        "enable_chunked_prefill",
+        mode="wrap",
+    )
+    @classmethod
+    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
+        """Skip validation if the value is `None` when initialisation is delayed."""
+        if value is None:
+            return value
+        return handler(value)
     def __post_init__(self, is_encoder_decoder: bool) -> None:
         if self.max_model_len is None:
             self.max_model_len = 8192
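
Note: this `mode="wrap"` validator is what lets the `Field(..., ge=1)` constraints above coexist with the delayed-initialisation `None` defaults. A minimal standalone sketch of the pattern (a toy `ToyConfig`, not vLLM's actual class):

from collections.abc import Callable
from typing import Any

from pydantic import Field, ValidationError, field_validator
from pydantic.dataclasses import dataclass

@dataclass
class ToyConfig:
    # None means "filled in later", like max_num_seqs in SchedulerConfig.
    max_num_seqs: int = Field(default=None, ge=1)

    @field_validator("max_num_seqs", mode="wrap")
    @classmethod
    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
        if value is None:
            return value       # bypass the int/ge=1 checks for the sentinel
        return handler(value)  # run normal validation otherwise

print(ToyConfig(max_num_seqs=None).max_num_seqs)  # None passes through
print(ToyConfig(max_num_seqs=4).max_num_seqs)     # 4 validates normally
try:
    ToyConfig(max_num_seqs=0)                     # ge=1 still enforced
except ValidationError as e:
    print(e)
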
@@ -260,19 +275,7 @@ class SchedulerConfig:
                 self.max_num_seqs * self.max_model_len,
             )
-        if self.num_lookahead_slots < 0:
-            raise ValueError(
-                "num_lookahead_slots "
-                f"({self.num_lookahead_slots}) must be greater than or "
-                "equal to 0."
-            )
-        if self.max_num_partial_prefills < 1:
-            raise ValueError(
-                f"max_num_partial_prefills ({self.max_num_partial_prefills}) "
-                "must be greater than or equal to 1."
-            )
-        elif self.max_num_partial_prefills > 1:
+        if self.max_num_partial_prefills > 1:
            if not self.chunked_prefill_enabled:
                 raise ValueError(
                     "Chunked prefill must be enabled to set "
@@ -286,13 +289,10 @@ class SchedulerConfig:
                     f"than the max_model_len ({self.max_model_len})."
                 )
-            if (self.max_long_partial_prefills < 1) or (
-                self.max_long_partial_prefills > self.max_num_partial_prefills
-            ):
+            if self.max_long_partial_prefills > self.max_num_partial_prefills:
                 raise ValueError(
-                    f"max_long_partial_prefills ({self.max_long_partial_prefills}) "
-                    "must be greater than or equal to 1 and less than or equal to "
-                    f"max_num_partial_prefills ({self.max_num_partial_prefills})."
+                    f"{self.max_long_partial_prefills=} must be less than or equal to "
+                    f"{self.max_num_partial_prefills=}."
                 )
         return self
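
The rewritten message relies on the f-string `=` specifier (Python 3.8+), which prints both the expression and its value, so the hand-built message can be dropped:

max_long_partial_prefills = 4
max_num_partial_prefills = 2
print(f"{max_long_partial_prefills=} must be less than or equal to "
      f"{max_num_partial_prefills=}")
# max_long_partial_prefills=4 must be less than or equal to max_num_partial_prefills=2
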


@@ -2,8 +2,9 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import hashlib
-from typing import Any, Literal
+from typing import Any, Literal, Self
+from pydantic import model_validator
 from pydantic.dataclasses import dataclass
 from vllm.config.utils import config
@@ -56,7 +57,8 @@ class StructuredOutputsConfig:
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str
-    def __post_init__(self):
+    @model_validator(mode="after")
+    def _validate_structured_output_config(self) -> Self:
         if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"):
             raise ValueError(
                 "disable_any_whitespace is only supported for "
@@ -67,3 +69,4 @@ class StructuredOutputsConfig:
                 "disable_additional_properties is only supported "
                 "for the guidance backend."
             )
+        return self
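
Moving the checks from `__post_init__` into an "after" model validator means they run after every field has been validated, and failures are reported as Pydantic `ValidationError`s alongside field-level errors. A minimal sketch of the pattern (a toy class; assumes Python 3.11+ for `typing.Self`):

from typing import Self

from pydantic import ValidationError, model_validator
from pydantic.dataclasses import dataclass

@dataclass
class ToyConfig:
    backend: str = "auto"
    disable_any_whitespace: bool = False

    @model_validator(mode="after")
    def _check_combination(self) -> Self:
        # Cross-field check: runs once all individual fields validate.
        if self.disable_any_whitespace and self.backend not in ("xgrammar", "guidance"):
            raise ValueError("disable_any_whitespace requires xgrammar or guidance")
        return self  # "after" validators must return the instance

ToyConfig(backend="xgrammar", disable_any_whitespace=True)  # ok
try:
    ToyConfig(disable_any_whitespace=True)  # backend "auto" -> error
except ValidationError as e:
    print(e)
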


@@ -1807,7 +1807,7 @@ class EngineArgs:
         incremental_prefill_supported = (
             pooling_type is not None
             and pooling_type.lower() == "last"
-            and is_causal
+            and bool(is_causal)
         )
         action = "Enabling" if incremental_prefill_supported else "Disabling"
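
Worth noting why the `bool()` matters: a chained `and` returns one of its operands rather than a canonical `True`/`False`, so without it the result takes whatever type `is_causal` happens to have. Wrapping the operand guarantees a plain built-in `bool`:

flag = 1                      # truthy, but an int
result = True and flag
print(result, type(result))   # 1 <class 'int'>
print(bool(flag))             # True
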


@@ -2,11 +2,12 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import json
-import re
 import uuid
 from collections.abc import Sequence
 from typing import Any
+import regex as re
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionRequest,
     DeltaFunctionCall,
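
The third-party `regex` package exposes an API broadly compatible with the stdlib `re` module, so aliasing it as `re` swaps the engine without touching existing call sites. A minimal sanity check of the drop-in usage (hypothetical pattern and input, not taken from this file):

import regex as re  # third-party "regex", stdlib-compatible API

m = re.search(r'\{.*\}', '<tool_call>{"name": "get_weather"}</tool_call>')
print(m.group(0) if m else None)  # {"name": "get_weather"}
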