mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-12 00:04:44 +08:00
[Feature] Pydantic validation for speculative.py (#27156)
Signed-off-by: Navya Srivastava <navya.srivastava1707@gmail.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
570c3e1cd4
commit
faee3ccdc2
@ -5,7 +5,7 @@ import ast
|
|||||||
import hashlib
|
import hashlib
|
||||||
from typing import TYPE_CHECKING, Any, Literal
|
from typing import TYPE_CHECKING, Any, Literal
|
||||||
|
|
||||||
from pydantic import SkipValidation, model_validator
|
from pydantic import Field, SkipValidation, model_validator
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
@ -62,7 +62,7 @@ class SpeculativeConfig:
|
|||||||
enforce_eager: bool | None = None
|
enforce_eager: bool | None = None
|
||||||
"""Override the default enforce_eager from model_config"""
|
"""Override the default enforce_eager from model_config"""
|
||||||
# General speculative decoding control
|
# General speculative decoding control
|
||||||
num_speculative_tokens: SkipValidation[int] = None # type: ignore
|
num_speculative_tokens: int = Field(default=None, gt=0)
|
||||||
"""The number of speculative tokens, if provided. It will default to the
|
"""The number of speculative tokens, if provided. It will default to the
|
||||||
number in the draft model config if present, otherwise, it is required."""
|
number in the draft model config if present, otherwise, it is required."""
|
||||||
model: str | None = None
|
model: str | None = None
|
||||||
@ -76,7 +76,7 @@ class SpeculativeConfig:
|
|||||||
|
|
||||||
If using `ngram` method, the related configuration `prompt_lookup_max` and
|
If using `ngram` method, the related configuration `prompt_lookup_max` and
|
||||||
`prompt_lookup_min` should be considered."""
|
`prompt_lookup_min` should be considered."""
|
||||||
draft_tensor_parallel_size: int | None = None
|
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
|
||||||
"""The degree of the tensor parallelism for the draft model. Can only be 1
|
"""The degree of the tensor parallelism for the draft model. Can only be 1
|
||||||
or the same as the target model's tensor parallel size."""
|
or the same as the target model's tensor parallel size."""
|
||||||
disable_logprobs: bool = True
|
disable_logprobs: bool = True
|
||||||
@ -89,7 +89,7 @@ class SpeculativeConfig:
|
|||||||
"""Quantization method that was used to quantize the draft model weights.
|
"""Quantization method that was used to quantize the draft model weights.
|
||||||
If `None`, we assume the model weights are not quantized. Note that it only
|
If `None`, we assume the model weights are not quantized. Note that it only
|
||||||
takes effect when using the draft model-based speculative method."""
|
takes effect when using the draft model-based speculative method."""
|
||||||
max_model_len: int | None = None
|
max_model_len: int | None = Field(default=None, ge=1)
|
||||||
"""The maximum model length of the draft model. Used when testing the
|
"""The maximum model length of the draft model. Used when testing the
|
||||||
ability to skip speculation for some sequences."""
|
ability to skip speculation for some sequences."""
|
||||||
revision: str | None = None
|
revision: str | None = None
|
||||||
@ -102,7 +102,7 @@ class SpeculativeConfig:
|
|||||||
will use the default version."""
|
will use the default version."""
|
||||||
|
|
||||||
# Advanced control
|
# Advanced control
|
||||||
disable_by_batch_size: int | None = None
|
disable_by_batch_size: int | None = Field(default=None, ge=2)
|
||||||
"""Disable speculative decoding for new incoming requests when the number
|
"""Disable speculative decoding for new incoming requests when the number
|
||||||
of enqueued requests is larger than this value, if provided."""
|
of enqueued requests is larger than this value, if provided."""
|
||||||
disable_padded_drafter_batch: bool = False
|
disable_padded_drafter_batch: bool = False
|
||||||
@ -112,10 +112,10 @@ class SpeculativeConfig:
|
|||||||
only affects the EAGLE method of speculation."""
|
only affects the EAGLE method of speculation."""
|
||||||
|
|
||||||
# Ngram proposer configuration
|
# Ngram proposer configuration
|
||||||
prompt_lookup_max: int | None = None
|
prompt_lookup_max: int | None = Field(default=None, ge=1)
|
||||||
"""Maximum size of ngram token window when using Ngram proposer, required
|
"""Maximum size of ngram token window when using Ngram proposer, required
|
||||||
when method is set to ngram."""
|
when method is set to ngram."""
|
||||||
prompt_lookup_min: int | None = None
|
prompt_lookup_min: int | None = Field(default=None, ge=1)
|
||||||
"""Minimum size of ngram token window when using Ngram proposer, if
|
"""Minimum size of ngram token window when using Ngram proposer, if
|
||||||
provided. Defaults to 1."""
|
provided. Defaults to 1."""
|
||||||
|
|
||||||
@ -232,9 +232,8 @@ class SpeculativeConfig:
|
|||||||
|
|
||||||
if self.model is None and self.num_speculative_tokens is not None:
|
if self.model is None and self.num_speculative_tokens is not None:
|
||||||
if self.method == "mtp":
|
if self.method == "mtp":
|
||||||
assert self.target_model_config is not None, (
|
if self.target_model_config is None:
|
||||||
"target_model_config must be present for mtp"
|
raise ValueError("target_model_config must be present for mtp")
|
||||||
)
|
|
||||||
if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
|
if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
|
||||||
# FIXME(luccafong): cudagraph with v32 MTP is not supported,
|
# FIXME(luccafong): cudagraph with v32 MTP is not supported,
|
||||||
# remove this when the issue is fixed.
|
# remove this when the issue is fixed.
|
||||||
@ -268,21 +267,21 @@ class SpeculativeConfig:
|
|||||||
self.prompt_lookup_min = 5
|
self.prompt_lookup_min = 5
|
||||||
self.prompt_lookup_max = 5
|
self.prompt_lookup_max = 5
|
||||||
elif self.prompt_lookup_min is None:
|
elif self.prompt_lookup_min is None:
|
||||||
assert self.prompt_lookup_max is not None
|
if self.prompt_lookup_max is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Either prompt_lookup_max or prompt_lookup_min must be "
|
||||||
|
"provided when using the ngram method."
|
||||||
|
)
|
||||||
self.prompt_lookup_min = self.prompt_lookup_max
|
self.prompt_lookup_min = self.prompt_lookup_max
|
||||||
elif self.prompt_lookup_max is None:
|
elif self.prompt_lookup_max is None:
|
||||||
assert self.prompt_lookup_min is not None
|
if self.prompt_lookup_min is None:
|
||||||
|
raise ValueError(
|
||||||
|
"Either prompt_lookup_max or prompt_lookup_min must be "
|
||||||
|
"provided when using the ngram method."
|
||||||
|
)
|
||||||
self.prompt_lookup_max = self.prompt_lookup_min
|
self.prompt_lookup_max = self.prompt_lookup_min
|
||||||
|
|
||||||
# Validate values
|
# Validate values
|
||||||
if self.prompt_lookup_min < 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"prompt_lookup_min={self.prompt_lookup_min} must be > 0"
|
|
||||||
)
|
|
||||||
if self.prompt_lookup_max < 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"prompt_lookup_max={self.prompt_lookup_max} must be > 0"
|
|
||||||
)
|
|
||||||
if self.prompt_lookup_min > self.prompt_lookup_max:
|
if self.prompt_lookup_min > self.prompt_lookup_max:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"prompt_lookup_min={self.prompt_lookup_min} must "
|
f"prompt_lookup_min={self.prompt_lookup_min} must "
|
||||||
@ -446,6 +445,7 @@ class SpeculativeConfig:
|
|||||||
self.target_parallel_config, self.draft_tensor_parallel_size
|
self.target_parallel_config, self.draft_tensor_parallel_size
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _maybe_override_draft_max_model_len(
|
def _maybe_override_draft_max_model_len(
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user