mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-18 20:27:17 +08:00
[Feature] Pydantic validation for speculative.py (#27156)
Signed-off-by: Navya Srivastava <navya.srivastava1707@gmail.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
570c3e1cd4
commit
faee3ccdc2
@ -5,7 +5,7 @@ import ast
|
||||
import hashlib
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from pydantic import SkipValidation, model_validator
|
||||
from pydantic import Field, SkipValidation, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import Self
|
||||
|
||||
@ -62,7 +62,7 @@ class SpeculativeConfig:
|
||||
enforce_eager: bool | None = None
|
||||
"""Override the default enforce_eager from model_config"""
|
||||
# General speculative decoding control
|
||||
num_speculative_tokens: SkipValidation[int] = None # type: ignore
|
||||
num_speculative_tokens: int = Field(default=None, gt=0)
|
||||
"""The number of speculative tokens, if provided. It will default to the
|
||||
number in the draft model config if present, otherwise, it is required."""
|
||||
model: str | None = None
|
||||
@ -76,7 +76,7 @@ class SpeculativeConfig:
|
||||
|
||||
If using `ngram` method, the related configuration `prompt_lookup_max` and
|
||||
`prompt_lookup_min` should be considered."""
|
||||
draft_tensor_parallel_size: int | None = None
|
||||
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
|
||||
"""The degree of the tensor parallelism for the draft model. Can only be 1
|
||||
or the same as the target model's tensor parallel size."""
|
||||
disable_logprobs: bool = True
|
||||
@ -89,7 +89,7 @@ class SpeculativeConfig:
|
||||
"""Quantization method that was used to quantize the draft model weights.
|
||||
If `None`, we assume the model weights are not quantized. Note that it only
|
||||
takes effect when using the draft model-based speculative method."""
|
||||
max_model_len: int | None = None
|
||||
max_model_len: int | None = Field(default=None, ge=1)
|
||||
"""The maximum model length of the draft model. Used when testing the
|
||||
ability to skip speculation for some sequences."""
|
||||
revision: str | None = None
|
||||
@ -102,7 +102,7 @@ class SpeculativeConfig:
|
||||
will use the default version."""
|
||||
|
||||
# Advanced control
|
||||
disable_by_batch_size: int | None = None
|
||||
disable_by_batch_size: int | None = Field(default=None, ge=2)
|
||||
"""Disable speculative decoding for new incoming requests when the number
|
||||
of enqueued requests is larger than this value, if provided."""
|
||||
disable_padded_drafter_batch: bool = False
|
||||
@ -112,10 +112,10 @@ class SpeculativeConfig:
|
||||
only affects the EAGLE method of speculation."""
|
||||
|
||||
# Ngram proposer configuration
|
||||
prompt_lookup_max: int | None = None
|
||||
prompt_lookup_max: int | None = Field(default=None, ge=1)
|
||||
"""Maximum size of ngram token window when using Ngram proposer, required
|
||||
when method is set to ngram."""
|
||||
prompt_lookup_min: int | None = None
|
||||
prompt_lookup_min: int | None = Field(default=None, ge=1)
|
||||
"""Minimum size of ngram token window when using Ngram proposer, if
|
||||
provided. Defaults to 1."""
|
||||
|
||||
@ -232,9 +232,8 @@ class SpeculativeConfig:
|
||||
|
||||
if self.model is None and self.num_speculative_tokens is not None:
|
||||
if self.method == "mtp":
|
||||
assert self.target_model_config is not None, (
|
||||
"target_model_config must be present for mtp"
|
||||
)
|
||||
if self.target_model_config is None:
|
||||
raise ValueError("target_model_config must be present for mtp")
|
||||
if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
|
||||
# FIXME(luccafong): cudgraph with v32 MTP is not supported,
|
||||
# remove this when the issue is fixed.
|
||||
@ -268,21 +267,21 @@ class SpeculativeConfig:
|
||||
self.prompt_lookup_min = 5
|
||||
self.prompt_lookup_max = 5
|
||||
elif self.prompt_lookup_min is None:
|
||||
assert self.prompt_lookup_max is not None
|
||||
if self.prompt_lookup_max is None:
|
||||
raise ValueError(
|
||||
"Either prompt_lookup_max or prompt_lookup_min must be "
|
||||
"provided when using the ngram method."
|
||||
)
|
||||
self.prompt_lookup_min = self.prompt_lookup_max
|
||||
elif self.prompt_lookup_max is None:
|
||||
assert self.prompt_lookup_min is not None
|
||||
if self.prompt_lookup_min is None:
|
||||
raise ValueError(
|
||||
"Either prompt_lookup_max or prompt_lookup_min must be "
|
||||
"provided when using the ngram method."
|
||||
)
|
||||
self.prompt_lookup_max = self.prompt_lookup_min
|
||||
|
||||
# Validate values
|
||||
if self.prompt_lookup_min < 1:
|
||||
raise ValueError(
|
||||
f"prompt_lookup_min={self.prompt_lookup_min} must be > 0"
|
||||
)
|
||||
if self.prompt_lookup_max < 1:
|
||||
raise ValueError(
|
||||
f"prompt_lookup_max={self.prompt_lookup_max} must be > 0"
|
||||
)
|
||||
if self.prompt_lookup_min > self.prompt_lookup_max:
|
||||
raise ValueError(
|
||||
f"prompt_lookup_min={self.prompt_lookup_min} must "
|
||||
@ -446,6 +445,7 @@ class SpeculativeConfig:
|
||||
self.target_parallel_config, self.draft_tensor_parallel_size
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _maybe_override_draft_max_model_len(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user