mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 04:47:03 +08:00
Remove SkipValidation from ModelConfig (#30695)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
17fec3af09
commit
970713d4a4
@ -8,7 +8,7 @@ from functools import cached_property
|
|||||||
from typing import TYPE_CHECKING, Any, Literal, cast, get_args
|
from typing import TYPE_CHECKING, Any, Literal, cast, get_args
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pydantic import ConfigDict, SkipValidation, field_validator, model_validator
|
from pydantic import ConfigDict, Field, field_validator, model_validator
|
||||||
from pydantic.dataclasses import dataclass
|
from pydantic.dataclasses import dataclass
|
||||||
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
|
||||||
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
|
from transformers.configuration_utils import ALLOWED_LAYER_TYPES
|
||||||
@ -109,7 +109,7 @@ class ModelConfig:
|
|||||||
"""Convert the model using adapters defined in
|
"""Convert the model using adapters defined in
|
||||||
[vllm.model_executor.models.adapters][]. The most common use case is to
|
[vllm.model_executor.models.adapters][]. The most common use case is to
|
||||||
adapt a text generation model to be used for pooling tasks."""
|
adapt a text generation model to be used for pooling tasks."""
|
||||||
tokenizer: SkipValidation[str] = None # type: ignore
|
tokenizer: str = Field(default=None)
|
||||||
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
|
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
|
||||||
name or path will be used."""
|
name or path will be used."""
|
||||||
tokenizer_mode: TokenizerMode | str = "auto"
|
tokenizer_mode: TokenizerMode | str = "auto"
|
||||||
@ -164,7 +164,7 @@ class ModelConfig:
|
|||||||
"""The specific revision to use for the tokenizer on the Hugging Face Hub.
|
"""The specific revision to use for the tokenizer on the Hugging Face Hub.
|
||||||
It can be a branch name, a tag name, or a commit id. If unspecified, will
|
It can be a branch name, a tag name, or a commit id. If unspecified, will
|
||||||
use the default version."""
|
use the default version."""
|
||||||
max_model_len: SkipValidation[int] = None # type: ignore
|
max_model_len: int = Field(default=None, gt=0)
|
||||||
"""Model context length (prompt and output). If unspecified, will be
|
"""Model context length (prompt and output). If unspecified, will be
|
||||||
automatically derived from the model config.
|
automatically derived from the model config.
|
||||||
|
|
||||||
@ -175,7 +175,7 @@ class ModelConfig:
|
|||||||
- 25.6k -> 25,600"""
|
- 25.6k -> 25,600"""
|
||||||
spec_target_max_model_len: int | None = None
|
spec_target_max_model_len: int | None = None
|
||||||
"""Specify the maximum length for spec decoding draft models."""
|
"""Specify the maximum length for spec decoding draft models."""
|
||||||
quantization: SkipValidation[QuantizationMethods | None] = None
|
quantization: QuantizationMethods | str | None = None
|
||||||
"""Method used to quantize the weights. If `None`, we first check the
|
"""Method used to quantize the weights. If `None`, we first check the
|
||||||
`quantization_config` attribute in the model config file. If that is
|
`quantization_config` attribute in the model config file. If that is
|
||||||
`None`, we assume the model weights are not quantized and use `dtype` to
|
`None`, we assume the model weights are not quantized and use `dtype` to
|
||||||
@ -597,6 +597,14 @@ class ModelConfig:
|
|||||||
self._verify_cuda_graph()
|
self._verify_cuda_graph()
|
||||||
self._verify_bnb_config()
|
self._verify_bnb_config()
|
||||||
|
|
||||||
|
@field_validator("tokenizer", "max_model_len", mode="wrap")
|
||||||
|
@classmethod
|
||||||
|
def _skip_none_validation(cls, value: Any, handler: Callable) -> Any:
|
||||||
|
"""Skip validation if the value is `None` when initialisation is delayed."""
|
||||||
|
if value is None:
|
||||||
|
return value
|
||||||
|
return handler(value)
|
||||||
|
|
||||||
@field_validator("tokenizer_mode", mode="after")
|
@field_validator("tokenizer_mode", mode="after")
|
||||||
def _lowercase_tokenizer_mode(cls, tokenizer_mode: str) -> str:
|
def _lowercase_tokenizer_mode(cls, tokenizer_mode: str) -> str:
|
||||||
return tokenizer_mode.lower()
|
return tokenizer_mode.lower()
|
||||||
@ -610,13 +618,14 @@ class ModelConfig:
|
|||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
|
def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
|
||||||
|
"""Called after __post_init__"""
|
||||||
if not isinstance(self.tokenizer, str):
|
if not isinstance(self.tokenizer, str):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"tokenizer must be a string, got "
|
f"tokenizer must be a string, got "
|
||||||
f"{type(self.tokenizer).__name__}: {self.tokenizer!r}. "
|
f"{type(self.tokenizer).__name__}: {self.tokenizer!r}. "
|
||||||
"Please provide a valid tokenizer path or HuggingFace model ID."
|
"Please provide a valid tokenizer path or HuggingFace model ID."
|
||||||
)
|
)
|
||||||
if not isinstance(self.max_model_len, int) or self.max_model_len <= 0:
|
if not isinstance(self.max_model_len, int):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"max_model_len must be a positive integer, "
|
f"max_model_len must be a positive integer, "
|
||||||
f"got {type(self.max_model_len).__name__}: {self.max_model_len!r}. "
|
f"got {type(self.max_model_len).__name__}: {self.max_model_len!r}. "
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user