[Feature] Pydantic validation for speculative.py (#27156)

Signed-off-by: Navya Srivastava <navya.srivastava1707@gmail.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Navya Srivastava 2025-10-23 05:19:33 -07:00 committed by GitHub
parent 570c3e1cd4
commit faee3ccdc2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -5,7 +5,7 @@ import ast
 import hashlib
 from typing import TYPE_CHECKING, Any, Literal
-from pydantic import SkipValidation, model_validator
+from pydantic import Field, SkipValidation, model_validator
 from pydantic.dataclasses import dataclass
 from typing_extensions import Self
@@ -62,7 +62,7 @@ class SpeculativeConfig:
     enforce_eager: bool | None = None
     """Override the default enforce_eager from model_config"""
     # General speculative decoding control
-    num_speculative_tokens: SkipValidation[int] = None  # type: ignore
+    num_speculative_tokens: int = Field(default=None, gt=0)
     """The number of speculative tokens, if provided. It will default to the
     number in the draft model config if present, otherwise, it is required."""
     model: str | None = None
@ -76,7 +76,7 @@ class SpeculativeConfig:
If using `ngram` method, the related configuration `prompt_lookup_max` and If using `ngram` method, the related configuration `prompt_lookup_max` and
`prompt_lookup_min` should be considered.""" `prompt_lookup_min` should be considered."""
draft_tensor_parallel_size: int | None = None draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
"""The degree of the tensor parallelism for the draft model. Can only be 1 """The degree of the tensor parallelism for the draft model. Can only be 1
or the same as the target model's tensor parallel size.""" or the same as the target model's tensor parallel size."""
disable_logprobs: bool = True disable_logprobs: bool = True
@@ -89,7 +89,7 @@ class SpeculativeConfig:
     """Quantization method that was used to quantize the draft model weights.
     If `None`, we assume the model weights are not quantized. Note that it only
     takes effect when using the draft model-based speculative method."""
-    max_model_len: int | None = None
+    max_model_len: int | None = Field(default=None, ge=1)
    """The maximum model length of the draft model. Used when testing the
     ability to skip speculation for some sequences."""
     revision: str | None = None
@@ -102,7 +102,7 @@ class SpeculativeConfig:
     will use the default version."""
     # Advanced control
-    disable_by_batch_size: int | None = None
+    disable_by_batch_size: int | None = Field(default=None, ge=2)
     """Disable speculative decoding for new incoming requests when the number
     of enqueued requests is larger than this value, if provided."""
     disable_padded_drafter_batch: bool = False
@@ -112,10 +112,10 @@ class SpeculativeConfig:
     only affects the EAGLE method of speculation."""
     # Ngram proposer configuration
-    prompt_lookup_max: int | None = None
+    prompt_lookup_max: int | None = Field(default=None, ge=1)
     """Maximum size of ngram token window when using Ngram proposer, required
     when method is set to ngram."""
-    prompt_lookup_min: int | None = None
+    prompt_lookup_min: int | None = Field(default=None, ge=1)
     """Minimum size of ngram token window when using Ngram proposer, if
     provided. Defaults to 1."""
@@ -232,9 +232,8 @@ class SpeculativeConfig:
         if self.model is None and self.num_speculative_tokens is not None:
             if self.method == "mtp":
-                assert self.target_model_config is not None, (
-                    "target_model_config must be present for mtp"
-                )
+                if self.target_model_config is None:
+                    raise ValueError("target_model_config must be present for mtp")
                 if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
                     # FIXME(luccafong): cudgraph with v32 MTP is not supported,
                     # remove this when the issue is fixed.
@@ -268,21 +267,21 @@ class SpeculativeConfig:
                 self.prompt_lookup_min = 5
                 self.prompt_lookup_max = 5
             elif self.prompt_lookup_min is None:
-                assert self.prompt_lookup_max is not None
+                if self.prompt_lookup_max is None:
+                    raise ValueError(
+                        "Either prompt_lookup_max or prompt_lookup_min must be "
+                        "provided when using the ngram method."
+                    )
                 self.prompt_lookup_min = self.prompt_lookup_max
             elif self.prompt_lookup_max is None:
-                assert self.prompt_lookup_min is not None
+                if self.prompt_lookup_min is None:
+                    raise ValueError(
+                        "Either prompt_lookup_max or prompt_lookup_min must be "
+                        "provided when using the ngram method."
+                    )
                 self.prompt_lookup_max = self.prompt_lookup_min

-            # Validate values
-            if self.prompt_lookup_min < 1:
-                raise ValueError(
-                    f"prompt_lookup_min={self.prompt_lookup_min} must be > 0"
-                )
-            if self.prompt_lookup_max < 1:
-                raise ValueError(
-                    f"prompt_lookup_max={self.prompt_lookup_max} must be > 0"
-                )
             if self.prompt_lookup_min > self.prompt_lookup_max:
                 raise ValueError(
                     f"prompt_lookup_min={self.prompt_lookup_min} must "
@@ -446,6 +445,7 @@ class SpeculativeConfig:
                     self.target_parallel_config, self.draft_tensor_parallel_size
                 )
             )
+        return self

     @staticmethod
     def _maybe_override_draft_max_model_len(