mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-30 00:51:51 +08:00
Remove SkipValidation from ModelConfig (#30695)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
17fec3af09
commit
970713d4a4
@@ -8,7 +8,7 @@ from functools import cached_property
 from typing import TYPE_CHECKING, Any, Literal, cast, get_args

 import torch
-from pydantic import ConfigDict, SkipValidation, field_validator, model_validator
+from pydantic import ConfigDict, Field, field_validator, model_validator
 from pydantic.dataclasses import dataclass
 from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
 from transformers.configuration_utils import ALLOWED_LAYER_TYPES
@@ -109,7 +109,7 @@ class ModelConfig:
     """Convert the model using adapters defined in
     [vllm.model_executor.models.adapters][]. The most common use case is to
     adapt a text generation model to be used for pooling tasks."""
-    tokenizer: SkipValidation[str] = None  # type: ignore
+    tokenizer: str = Field(default=None)
     """Name or path of the Hugging Face tokenizer to use. If unspecified, model
     name or path will be used."""
     tokenizer_mode: TokenizerMode | str = "auto"
@@ -164,7 +164,7 @@ class ModelConfig:
     """The specific revision to use for the tokenizer on the Hugging Face Hub.
     It can be a branch name, a tag name, or a commit id. If unspecified, will
     use the default version."""
-    max_model_len: SkipValidation[int] = None  # type: ignore
+    max_model_len: int = Field(default=None, gt=0)
     """Model context length (prompt and output). If unspecified, will be
     automatically derived from the model config.

@@ -175,7 +175,7 @@ class ModelConfig:
     - 25.6k -> 25,600"""
     spec_target_max_model_len: int | None = None
     """Specify the maximum length for spec decoding draft models."""
-    quantization: SkipValidation[QuantizationMethods | None] = None
+    quantization: QuantizationMethods | str | None = None
     """Method used to quantize the weights. If `None`, we first check the
     `quantization_config` attribute in the model config file. If that is
     `None`, we assume the model weights are not quantized and use `dtype` to
@@ -597,6 +597,14 @@ class ModelConfig:
         self._verify_cuda_graph()
         self._verify_bnb_config()

+    @field_validator("tokenizer", "max_model_len", mode="wrap")
+    @classmethod
+    def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
+        """Skip validation if the value is `None` when initialisation is delayed."""
+        if value is None:
+            return value
+        return handler(value)
+
     @field_validator("tokenizer_mode", mode="after")
     def _lowercase_tokenizer_mode(cls, tokenizer_mode: str) -> str:
         return tokenizer_mode.lower()
@@ -610,13 +618,14 @@ class ModelConfig:

     @model_validator(mode="after")
     def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
         """Called after __post_init__"""
         if not isinstance(self.tokenizer, str):
             raise ValueError(
                 f"tokenizer must be a string, got "
                 f"{type(self.tokenizer).__name__}: {self.tokenizer!r}. "
                 "Please provide a valid tokenizer path or HuggingFace model ID."
             )
-        if not isinstance(self.max_model_len, int) or self.max_model_len <= 0:
+        if not isinstance(self.max_model_len, int):
             raise ValueError(
                 f"max_model_len must be a positive integer, "
                 f"got {type(self.max_model_len).__name__}: {self.max_model_len!r}. "
||||
Loading…
x
Reference in New Issue
Block a user