Remove SkipValidation from ModelConfig (#30695)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor 2025-12-15 17:34:08 +00:00 committed by GitHub
parent 17fec3af09
commit 970713d4a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -8,7 +8,7 @@ from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal, cast, get_args from typing import TYPE_CHECKING, Any, Literal, cast, get_args
import torch import torch
from pydantic import ConfigDict, SkipValidation, field_validator, model_validator from pydantic import ConfigDict, Field, field_validator, model_validator
from pydantic.dataclasses import dataclass from pydantic.dataclasses import dataclass
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from transformers.configuration_utils import ALLOWED_LAYER_TYPES from transformers.configuration_utils import ALLOWED_LAYER_TYPES
@ -109,7 +109,7 @@ class ModelConfig:
"""Convert the model using adapters defined in """Convert the model using adapters defined in
[vllm.model_executor.models.adapters][]. The most common use case is to [vllm.model_executor.models.adapters][]. The most common use case is to
adapt a text generation model to be used for pooling tasks.""" adapt a text generation model to be used for pooling tasks."""
tokenizer: SkipValidation[str] = None # type: ignore tokenizer: str = Field(default=None)
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model """Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used.""" name or path will be used."""
tokenizer_mode: TokenizerMode | str = "auto" tokenizer_mode: TokenizerMode | str = "auto"
@ -164,7 +164,7 @@ class ModelConfig:
"""The specific revision to use for the tokenizer on the Hugging Face Hub. """The specific revision to use for the tokenizer on the Hugging Face Hub.
It can be a branch name, a tag name, or a commit id. If unspecified, will It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version.""" use the default version."""
max_model_len: SkipValidation[int] = None # type: ignore max_model_len: int = Field(default=None, gt=0)
"""Model context length (prompt and output). If unspecified, will be """Model context length (prompt and output). If unspecified, will be
automatically derived from the model config. automatically derived from the model config.
@ -175,7 +175,7 @@ class ModelConfig:
- 25.6k -> 25,600""" - 25.6k -> 25,600"""
spec_target_max_model_len: int | None = None spec_target_max_model_len: int | None = None
"""Specify the maximum length for spec decoding draft models.""" """Specify the maximum length for spec decoding draft models."""
quantization: SkipValidation[QuantizationMethods | None] = None quantization: QuantizationMethods | str | None = None
"""Method used to quantize the weights. If `None`, we first check the """Method used to quantize the weights. If `None`, we first check the
`quantization_config` attribute in the model config file. If that is `quantization_config` attribute in the model config file. If that is
`None`, we assume the model weights are not quantized and use `dtype` to `None`, we assume the model weights are not quantized and use `dtype` to
@ -597,6 +597,14 @@ class ModelConfig:
self._verify_cuda_graph() self._verify_cuda_graph()
self._verify_bnb_config() self._verify_bnb_config()
@field_validator("tokenizer", "max_model_len", mode="wrap")
@classmethod
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
    """Let `None` through unvalidated; validate every other value.

    A `None` value is returned as-is because initialisation of these
    fields may be delayed. Any other value is passed to `handler` and
    its result is returned.
    """
    return value if value is None else handler(value)
@field_validator("tokenizer_mode", mode="after") @field_validator("tokenizer_mode", mode="after")
def _lowercase_tokenizer_mode(cls, tokenizer_mode: str) -> str: def _lowercase_tokenizer_mode(cls, tokenizer_mode: str) -> str:
return tokenizer_mode.lower() return tokenizer_mode.lower()
@ -610,13 +618,14 @@ class ModelConfig:
@model_validator(mode="after") @model_validator(mode="after")
def validate_model_config_after(self: "ModelConfig") -> "ModelConfig": def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
"""Called after __post_init__"""
if not isinstance(self.tokenizer, str): if not isinstance(self.tokenizer, str):
raise ValueError( raise ValueError(
f"tokenizer must be a string, got " f"tokenizer must be a string, got "
f"{type(self.tokenizer).__name__}: {self.tokenizer!r}. " f"{type(self.tokenizer).__name__}: {self.tokenizer!r}. "
"Please provide a valid tokenizer path or HuggingFace model ID." "Please provide a valid tokenizer path or HuggingFace model ID."
) )
if not isinstance(self.max_model_len, int) or self.max_model_len <= 0: if not isinstance(self.max_model_len, int):
raise ValueError( raise ValueError(
f"max_model_len must be a positive integer, " f"max_model_len must be a positive integer, "
f"got {type(self.max_model_len).__name__}: {self.max_model_len!r}. " f"got {type(self.max_model_len).__name__}: {self.max_model_len!r}. "