[Feature] Pydantic validation for speculative.py (#27156)

Signed-off-by: Navya Srivastava <navya.srivastava1707@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Navya Srivastava 2025-10-23 05:19:33 -07:00 committed by GitHub
parent 570c3e1cd4
commit faee3ccdc2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -5,7 +5,7 @@ import ast
import hashlib
from typing import TYPE_CHECKING, Any, Literal
from pydantic import SkipValidation, model_validator
from pydantic import Field, SkipValidation, model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self
@ -62,7 +62,7 @@ class SpeculativeConfig:
enforce_eager: bool | None = None
"""Override the default enforce_eager from model_config"""
# General speculative decoding control
num_speculative_tokens: SkipValidation[int] = None # type: ignore
num_speculative_tokens: int = Field(default=None, gt=0)
"""The number of speculative tokens, if provided. It will default to the
number in the draft model config if present, otherwise, it is required."""
model: str | None = None
@ -76,7 +76,7 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered."""
draft_tensor_parallel_size: int | None = None
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size."""
disable_logprobs: bool = True
@ -89,7 +89,7 @@ class SpeculativeConfig:
"""Quantization method that was used to quantize the draft model weights.
If `None`, we assume the model weights are not quantized. Note that it only
takes effect when using the draft model-based speculative method."""
max_model_len: int | None = None
max_model_len: int | None = Field(default=None, ge=1)
"""The maximum model length of the draft model. Used when testing the
ability to skip speculation for some sequences."""
revision: str | None = None
@ -102,7 +102,7 @@ class SpeculativeConfig:
will use the default version."""
# Advanced control
disable_by_batch_size: int | None = None
disable_by_batch_size: int | None = Field(default=None, ge=2)
"""Disable speculative decoding for new incoming requests when the number
of enqueued requests is larger than this value, if provided."""
disable_padded_drafter_batch: bool = False
@ -112,10 +112,10 @@ class SpeculativeConfig:
only affects the EAGLE method of speculation."""
# Ngram proposer configuration
prompt_lookup_max: int | None = None
prompt_lookup_max: int | None = Field(default=None, ge=1)
"""Maximum size of ngram token window when using Ngram proposer, required
when method is set to ngram."""
prompt_lookup_min: int | None = None
prompt_lookup_min: int | None = Field(default=None, ge=1)
"""Minimum size of ngram token window when using Ngram proposer, if
provided. Defaults to 1."""
@ -232,9 +232,8 @@ class SpeculativeConfig:
if self.model is None and self.num_speculative_tokens is not None:
if self.method == "mtp":
assert self.target_model_config is not None, (
"target_model_config must be present for mtp"
)
if self.target_model_config is None:
raise ValueError("target_model_config must be present for mtp")
if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
# FIXME(luccafong): cudagraph with v32 MTP is not supported,
# remove this when the issue is fixed.
@ -268,21 +267,21 @@ class SpeculativeConfig:
self.prompt_lookup_min = 5
self.prompt_lookup_max = 5
elif self.prompt_lookup_min is None:
assert self.prompt_lookup_max is not None
if self.prompt_lookup_max is None:
raise ValueError(
"Either prompt_lookup_max or prompt_lookup_min must be "
"provided when using the ngram method."
)
self.prompt_lookup_min = self.prompt_lookup_max
elif self.prompt_lookup_max is None:
assert self.prompt_lookup_min is not None
if self.prompt_lookup_min is None:
raise ValueError(
"Either prompt_lookup_max or prompt_lookup_min must be "
"provided when using the ngram method."
)
self.prompt_lookup_max = self.prompt_lookup_min
# Validate values
if self.prompt_lookup_min < 1:
raise ValueError(
f"prompt_lookup_min={self.prompt_lookup_min} must be > 0"
)
if self.prompt_lookup_max < 1:
raise ValueError(
f"prompt_lookup_max={self.prompt_lookup_max} must be > 0"
)
if self.prompt_lookup_min > self.prompt_lookup_max:
raise ValueError(
f"prompt_lookup_min={self.prompt_lookup_min} must "
@ -446,6 +445,7 @@ class SpeculativeConfig:
self.target_parallel_config, self.draft_tensor_parallel_size
)
)
return self
@staticmethod
def _maybe_override_draft_max_model_len(