From faee3ccdc21dd6b3a1cfb5ae2146163f6567bdaa Mon Sep 17 00:00:00 2001 From: Navya Srivastava <143343265+Navya1707@users.noreply.github.com> Date: Thu, 23 Oct 2025 05:19:33 -0700 Subject: [PATCH] [Feature] Pydantic validation for speculative.py (#27156) Signed-off-by: Navya Srivastava Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> --- vllm/config/speculative.py | 40 +++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index a5bc4d1fa3c07..4c7b7369ed4b5 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -5,7 +5,7 @@ import ast import hashlib from typing import TYPE_CHECKING, Any, Literal -from pydantic import SkipValidation, model_validator +from pydantic import Field, SkipValidation, model_validator from pydantic.dataclasses import dataclass from typing_extensions import Self @@ -62,7 +62,7 @@ class SpeculativeConfig: enforce_eager: bool | None = None """Override the default enforce_eager from model_config""" # General speculative decoding control - num_speculative_tokens: SkipValidation[int] = None # type: ignore + num_speculative_tokens: int = Field(default=None, gt=0) """The number of speculative tokens, if provided. It will default to the number in the draft model config if present, otherwise, it is required.""" model: str | None = None @@ -76,7 +76,7 @@ class SpeculativeConfig: If using `ngram` method, the related configuration `prompt_lookup_max` and `prompt_lookup_min` should be considered.""" - draft_tensor_parallel_size: int | None = None + draft_tensor_parallel_size: int | None = Field(default=None, ge=1) """The degree of the tensor parallelism for the draft model. Can only be 1 or the same as the target model's tensor parallel size.""" disable_logprobs: bool = True @@ -89,7 +89,7 @@ class SpeculativeConfig: """Quantization method that was used to quantize the draft model weights. If `None`, we assume the model weights are not quantized. Note that it only takes effect when using the draft model-based speculative method.""" - max_model_len: int | None = None + max_model_len: int | None = Field(default=None, ge=1) """The maximum model length of the draft model. Used when testing the ability to skip speculation for some sequences.""" revision: str | None = None @@ -102,7 +102,7 @@ class SpeculativeConfig: will use the default version.""" # Advanced control - disable_by_batch_size: int | None = None + disable_by_batch_size: int | None = Field(default=None, ge=2) """Disable speculative decoding for new incoming requests when the number of enqueued requests is larger than this value, if provided.""" disable_padded_drafter_batch: bool = False @@ -112,10 +112,10 @@ class SpeculativeConfig: only affects the EAGLE method of speculation.""" # Ngram proposer configuration - prompt_lookup_max: int | None = None + prompt_lookup_max: int | None = Field(default=None, ge=1) """Maximum size of ngram token window when using Ngram proposer, required when method is set to ngram.""" - prompt_lookup_min: int | None = None + prompt_lookup_min: int | None = Field(default=None, ge=1) """Minimum size of ngram token window when using Ngram proposer, if provided. Defaults to 1.""" @@ -232,9 +232,8 @@ class SpeculativeConfig: if self.model is None and self.num_speculative_tokens is not None: if self.method == "mtp": - assert self.target_model_config is not None, ( - "target_model_config must be present for mtp" - ) + if self.target_model_config is None: + raise ValueError("target_model_config must be present for mtp") if self.target_model_config.hf_text_config.model_type == "deepseek_v32": # FIXME(luccafong): cudgraph with v32 MTP is not supported, # remove this when the issue is fixed. @@ -268,21 +267,21 @@ class SpeculativeConfig: self.prompt_lookup_min = 5 self.prompt_lookup_max = 5 elif self.prompt_lookup_min is None: - assert self.prompt_lookup_max is not None + if self.prompt_lookup_max is None: + raise ValueError( + "Either prompt_lookup_max or prompt_lookup_min must be " + "provided when using the ngram method." + ) self.prompt_lookup_min = self.prompt_lookup_max elif self.prompt_lookup_max is None: - assert self.prompt_lookup_min is not None + if self.prompt_lookup_min is None: + raise ValueError( + "Either prompt_lookup_max or prompt_lookup_min must be " + "provided when using the ngram method." + ) self.prompt_lookup_max = self.prompt_lookup_min # Validate values - if self.prompt_lookup_min < 1: - raise ValueError( - f"prompt_lookup_min={self.prompt_lookup_min} must be > 0" - ) - if self.prompt_lookup_max < 1: - raise ValueError( - f"prompt_lookup_max={self.prompt_lookup_max} must be > 0" - ) if self.prompt_lookup_min > self.prompt_lookup_max: raise ValueError( f"prompt_lookup_min={self.prompt_lookup_min} must " @@ -446,6 +445,7 @@ class SpeculativeConfig: self.target_parallel_config, self.draft_tensor_parallel_size ) ) + return self @staticmethod def _maybe_override_draft_max_model_len(