Improve configs - ModelConfig (#17130)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-04-30 03:38:22 +01:00 committed by GitHub
parent 2c4f59afc3
commit 13698db634
36 changed files with 490 additions and 648 deletions

View File

@@ -738,7 +738,7 @@ class VllmRunner:
    - `block_size`: Set to `16` instead of `None` to reduce memory usage.
    - `enable_chunked_prefill`: Set to `False` instead of `None` for
      test reproducibility.
-    - `enforce_eager`: Set to `False` instead of `None` to test CUDA graph.
+    - `enforce_eager`: Set to `False` to test CUDA graph.
    """

    def __init__(

View File

@@ -8,7 +8,7 @@ from typing import Literal, Optional
import pytest

-from vllm.config import PoolerConfig, config
+from vllm.config import config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
                                   get_type, is_not_builtin, is_type,
                                   literal_to_kwargs, nullable_kvs,
@@ -222,17 +222,6 @@ def test_prefix_cache_default():
    assert not engine_args.enable_prefix_caching


-def test_valid_pooling_config():
-    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
-    args = parser.parse_args([
-        '--override-pooler-config',
-        '{"pooling_type": "MEAN"}',
-    ])
-    engine_args = EngineArgs.from_cli_args(args=args)
-    assert engine_args.override_pooler_config == PoolerConfig(
-        pooling_type="MEAN", )
-
-
@pytest.mark.parametrize(
    ("arg"),
    [

View File

@@ -14,7 +14,7 @@ import torch.nn.functional as F
from vllm.model_executor.layers.linear import LinearBase  # noqa: E501
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import (
-    get_quantization_config, register_quantization_config)
+    QuantizationMethods, get_quantization_config, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
    QuantizationConfig)
@@ -54,7 +54,7 @@ class CustomQuantConfig(QuantizationConfig):
        """Initialize the quantization config."""
        self.num_bits = num_bits

-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
        """Name of the quantization method."""
        return "custom_quant"

View File

@@ -185,7 +185,7 @@ def test_get_pooling_config():
        revision=None,
    )

-    pooling_config = model_config._init_pooler_config(None)
+    pooling_config = model_config._init_pooler_config()
    assert pooling_config is not None

    assert pooling_config.normalize
@@ -205,11 +205,12 @@ def test_get_pooling_config_from_args():
        dtype="float16",
        revision=None)

-    override_config = PoolerConfig(pooling_type='CLS', normalize=True)
+    override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
+    model_config.override_pooler_config = override_pooler_config

-    pooling_config = model_config._init_pooler_config(override_config)
+    pooling_config = model_config._init_pooler_config()
    assert pooling_config is not None
-    assert asdict(pooling_config) == asdict(override_config)
+    assert asdict(pooling_config) == asdict(override_pooler_config)


@pytest.mark.skipif(current_platform.is_rocm(),
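The updated test captures the new call pattern: `_init_pooler_config` no longer takes the override as an argument but reads it from the config object itself. A minimal sketch of that flow, assuming an already-constructed `ModelConfig` instance named `model_config` (construction details omitted here, so treat this as illustrative rather than a test from the commit):

```python
from dataclasses import asdict

from vllm.config import PoolerConfig

# The override now lives on the config object rather than being passed in.
override = PoolerConfig(pooling_type="CLS", normalize=True)
model_config.override_pooler_config = override  # assumes an existing ModelConfig

# No argument anymore; the method reads self.override_pooler_config.
pooling_config = model_config._init_pooler_config()
assert pooling_config is not None
assert asdict(pooling_config) == asdict(override)
```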

View File

@@ -16,9 +16,8 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
                          replace)
from importlib.util import find_spec
from pathlib import Path
-from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
-                    Optional, Protocol, TypeVar, Union, cast, get_args,
-                    get_origin)
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
+                    Protocol, TypeVar, Union, cast, get_args, get_origin)

import torch
from pydantic import BaseModel, Field, PrivateAttr
@@ -211,103 +210,190 @@ def get_field(cls: ConfigType, name: str) -> Field:
        f"{cls.__name__}.{name} must have a default value or default factory.")
-class ModelConfig:
-    """Configuration for the model.
-
-    Args:
-        model: Name or path of the huggingface model to use.
-            It is also used as the content for `model_name` tag in metrics
-            output when `served_model_name` is not specified.
-        task: The task to use the model for. Each vLLM instance only supports
-            one task, even if the same model can be used for multiple tasks.
-            When the model only supports one task, "auto" can be used to select
-            it; otherwise, you must specify explicitly which task to use.
-        tokenizer: Name or path of the huggingface tokenizer to use.
-        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
-            available, "slow" will always use the slow tokenizer,
-            "mistral" will always use the tokenizer from `mistral_common`, and
-            "custom" will use --tokenizer to select the preregistered tokenizer.
-        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
-            downloading the model and tokenizer.
-        allowed_local_media_path: Allowing API requests to read local images or
-            videos from directories specified by the server file system.
-            This is a security risk. Should only be enabled in trusted
-            environments.
-        dtype: Data type for model weights and activations. The "auto" option
-            will use FP16 precision for FP32 and FP16 models, and BF16 precision
-            for BF16 models.
-        seed: Random seed for reproducibility.
-        revision: The specific model version to use. It can be a branch name,
-            a tag name, or a commit id. If unspecified, will use the default
-            version.
-        code_revision: The specific revision to use for the model code on
-            Hugging Face Hub. It can be a branch name, a tag name, or a
-            commit id. If unspecified, will use the default version.
-        tokenizer_revision: The specific tokenizer version to use. It can be a
-            branch name, a tag name, or a commit id. If unspecified, will use
-            the default version.
-        max_model_len: Maximum length of a sequence (including prompt and
-            output). If None, will be derived from the model.
-        spec_target_max_model_len: Specify the the maximum length for spec
-            decoding draft models.
-        quantization: Quantization method that was used to quantize the model
-            weights. If None, we assume the model weights are not quantized.
-        enforce_eager: Whether to enforce eager execution. If True, we will
-            disable CUDA graph and always execute the model in eager mode.
-            If False, we will use CUDA graph and eager execution in hybrid.
-            If None, the user did not specify, so default to False.
-        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
-            When a sequence has context length larger than this, we fall back
-            to eager mode. Additionally for encoder-decoder models, if the
-            sequence length of the encoder input is larger than this, we fall
-            back to the eager mode.
-        max_logprobs: Maximum number of log probabilities. Defaults to 20.
-        disable_sliding_window: Whether to disable sliding window. If True,
-            we will disable the sliding window functionality of the model.
-            If the model does not support sliding window, this argument is
-            ignored.
-        skip_tokenizer_init: If true, skip initialization of tokenizer and
-            detokenizer.
-        served_model_name: The model name used in metrics tag `model_name`,
-            matches the model name exposed via the APIs. If multiple model
-            names provided, the first name will be used. If not specified,
-            the model name will be the same as `model`.
-        limit_mm_per_prompt: Maximum number of data items per modality
-            per prompt. Only applicable for multimodal models.
-        mm_processor_kwargs: Overrides for the multi-modal processor obtained
-            from `AutoProcessor.from_pretrained`.
-        disable_mm_preprocessor_cache: If True, disable caching of the
-            processed multi-modal inputs.
-        use_async_output_proc: Whether to use async output processor.
-            Defaults to True.
-        config_format: The config format which shall be loaded.
-            Defaults to 'auto' which defaults to 'hf'.
-        hf_token: The token to use as HTTP bearer authorization for remote files
-            . If `True`, will use the token generated when running
-            `huggingface-cli login` (stored in `~/.huggingface`).
-        hf_overrides: If a dictionary, contains arguments to be forwarded to the
-            HuggingFace config. If a callable, it is called to update the
-            HuggingFace config.
-        override_neuron_config: Initialize non default neuron config or
-            override default neuron config that are specific to Neuron devices,
-            this argument will be used to configure the neuron config that
-            can not be gathered from the vllm arguments.
-        override_pooler_config: Initialize non default pooling config or
-            override default pooling config for the pooling model.
-        logits_processor_pattern: Optional regex pattern specifying valid
-            logits processor qualified names that can be passed with the
-            `logits_processors` extra completion argument. Defaults to None,
-            which allows no processors.
-        generation_config: Configuration parameter file for generation.
-        model_impl: Which implementation of the model to use:
-            "auto" will try to use the vLLM implementation if it exists and
-            fall back to the Transformers implementation if no vLLM
-            implementation is available.
-            "vllm" will use the vLLM model implementation.
-            "transformers" will use the Transformers model implementation.
-        override_generation_config: Override the generation config with the
-            given config.
-    """
+TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
+ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
+
+
+@config
+@dataclass
+class ModelConfig:
+    """Configuration for the model."""
+
+    model: str = "facebook/opt-125m"
+    """Name or path of the Hugging Face model to use. It is also used as the
+    content for `model_name` tag in metrics output when `served_model_name` is
+    not specified."""
+    task: Literal[TaskOption, Literal["draft"]] = "auto"
+    """The task to use the model for. Each vLLM instance only supports one
+    task, even if the same model can be used for multiple tasks. When the model
+    only supports one task, "auto" can be used to select it; otherwise, you
+    must specify explicitly which task to use."""
+    tokenizer: str = None  # type: ignore
+    """Name or path of the Hugging Face tokenizer to use. If unspecified, model
+    name or path will be used."""
+    tokenizer_mode: TokenizerMode = "auto"
+    """Tokenizer mode:\n
+    - "auto" will use the fast tokenizer if available.\n
+    - "slow" will always use the slow tokenizer.\n
+    - "mistral" will always use the tokenizer from `mistral_common`.\n
+    - "custom" will use --tokenizer to select the preregistered tokenizer."""
+    trust_remote_code: bool = False
+    """Trust remote code (e.g., from HuggingFace) when downloading the model
+    and tokenizer."""
+    dtype: Union[ModelDType, torch.dtype] = "auto"
+    """Data type for model weights and activations:\n
+    - "auto" will use FP16 precision for FP32 and FP16 models, and BF16
+    precision for BF16 models.\n
+    - "half" for FP16. Recommended for AWQ quantization.\n
+    - "float16" is the same as "half".\n
+    - "bfloat16" for a balance between precision and range.\n
+    - "float" is shorthand for FP32 precision.\n
+    - "float32" for FP32 precision."""
+    seed: Optional[int] = None
+    """Random seed for reproducibility."""
+    hf_config_path: Optional[str] = None
+    """Name or path of the Hugging Face config to use. If unspecified, model
+    name or path will be used."""
+    allowed_local_media_path: str = ""
+    """Allowing API requests to read local images or videos from directories
+    specified by the server file system. This is a security risk. Should only
+    be enabled in trusted environments."""
+    revision: Optional[str] = None
+    """The specific model version to use. It can be a branch name, a tag name,
+    or a commit id. If unspecified, will use the default version."""
+    code_revision: Optional[str] = None
+    """The specific revision to use for the model code on the Hugging Face Hub.
+    It can be a branch name, a tag name, or a commit id. If unspecified, will
+    use the default version."""
+    rope_scaling: dict[str, Any] = field(default_factory=dict)
+    """RoPE scaling configuration in JSON format. For example,
+    `{"rope_type":"dynamic","factor":2.0}`."""
+    rope_theta: Optional[float] = None
+    """RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE
+    theta improves the performance of the scaled model."""
+    tokenizer_revision: Optional[str] = None
+    """The specific revision to use for the tokenizer on the Hugging Face Hub.
+    It can be a branch name, a tag name, or a commit id. If unspecified, will
+    use the default version."""
+    max_model_len: int = None  # type: ignore
+    """Model context length (prompt and output). If unspecified, will be
+    automatically derived from the model config.
+
+    When passing via `--max-model-len`, supports k/m/g/K/M/G in human-readable
+    format. Examples:\n
+    - 1k -> 1000\n
+    - 1K -> 1024\n
+    - 25.6k -> 25,600"""
+    spec_target_max_model_len: Optional[int] = None
+    """Specify the the maximum length for spec decoding draft models."""
+    quantization: Optional[QuantizationMethods] = None
+    """Method used to quantize the weights. If `None`, we first check the
+    `quantization_config` attribute in the model config file. If that is
+    `None`, we assume the model weights are not quantized and use `dtype` to
+    determine the data type of the weights."""
+    enforce_eager: bool = False
+    """Whether to always use eager-mode PyTorch. If True, we will disable CUDA
+    graph and always execute the model in eager mode. If False, we will use
+    CUDA graph and eager execution in hybrid for maximal performance and
+    flexibility."""
+    max_seq_len_to_capture: int = 8192
+    """Maximum sequence len covered by CUDA graphs. When a sequence has context
+    length larger than this, we fall back to eager mode. Additionally for
+    encoder-decoder models, if the sequence length of the encoder input is
+    larger than this, we fall back to the eager mode."""
+    max_logprobs: int = 20
+    """Maximum number of log probabilities to return when `logprobs` is
+    specified in `SamplingParams`. The default value comes the default for the
+    OpenAI Chat Completions API."""
+    disable_sliding_window: bool = False
+    """Whether to disable sliding window. If True, we will disable the sliding
+    window functionality of the model, capping to sliding window size. If the
+    model does not support sliding window, this argument is ignored."""
+    disable_cascade_attn: bool = False
+    """Disable cascade attention for V1. While cascade attention does not
+    change the mathematical correctness, disabling it could be useful for
+    preventing potential numerical issues. Note that even if this is set to
+    False, cascade attention will be only used when the heuristic tells that
+    it's beneficial."""
+    skip_tokenizer_init: bool = False
+    """Skip initialization of tokenizer and detokenizer. Expects valid
+    `prompt_token_ids` and `None` for prompt from the input. The generated
+    output will contain token ids."""
+    served_model_name: Optional[Union[str, list[str]]] = None
+    """The model name(s) used in the API. If multiple names are provided, the
+    server will respond to any of the provided names. The model name in the
+    model field of a response will be the first name in this list. If not
+    specified, the model name will be the same as the `--model` argument. Noted
+    that this name(s) will also be used in `model_name` tag content of
+    prometheus metrics, if multiple names provided, metrics tag will take the
+    first one."""
+    limit_mm_per_prompt: dict[str, int] = field(default_factory=dict)
+    """Maximum number of data items per modality per prompt. Only applicable
+    for multimodal models."""
+    use_async_output_proc: bool = True
+    """Whether to use async output processor."""
+    config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value
+    """The format of the model config to load:\n
+    - "auto" will try to load the config in hf format if available else it
+    will try to load in mistral format.\n
+    - "hf" will load the config in hf format.\n
+    - "mistral" will load the config in mistral format."""
+    hf_token: Optional[Union[bool, str]] = None
+    """The token to use as HTTP bearer authorization for remote files . If
+    `True`, will use the token generated when running `huggingface-cli login`
+    (stored in `~/.huggingface`)."""
+    hf_overrides: HfOverrides = field(default_factory=dict)
+    """If a dictionary, contains arguments to be forwarded to the Hugging Face
+    config. If a callable, it is called to update the HuggingFace config. When
+    specified via CLI, the argument must be a valid JSON string."""
+    mm_processor_kwargs: Optional[dict[str, Any]] = None
+    """Arguments to be forwarded to the model's processor for multi-modal data,
+    e.g., image processor. Overrides for the multi-modal processor obtained
+    from `AutoProcessor.from_pretrained`. The available overrides depend on the
+    model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
+    When specified via CLI, the argument must be a valid JSON string."""
+    disable_mm_preprocessor_cache: bool = False
+    """If `True`, disable caching of the multi-modal preprocessor/mapper (not
+    recommended)."""
+    override_neuron_config: dict[str, Any] = field(default_factory=dict)
+    """Initialize non-default neuron config or override default neuron config
+    that are specific to Neuron devices, this argument will be used to
+    configure the neuron config that can not be gathered from the vllm
+    arguments. e.g. `{"cast_logits_dtype": "bloat16"}`. When specified via CLI,
+    the argument must be a valid JSON string."""
+    pooler_config: Optional["PoolerConfig"] = field(init=False)
+    """Pooler config which controls the behaviour of output pooling in pooling
+    models."""
+    override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
+    """Initialize non-default pooling config or override default pooling config
+    for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`.
+    When specified via CLI, the argument must be a valid JSON string."""
+    logits_processor_pattern: Optional[str] = None
+    """Optional regex pattern specifying valid logits processor qualified names
+    that can be passed with the `logits_processors` extra completion argument.
+    Defaults to `None`, which allows no processors."""
+    generation_config: str = "auto"
+    """The folder path to the generation config. Defaults to `"auto"`, the
+    generation config will be loaded from model path. If set to `"vllm"`, no
+    generation config is loaded, vLLM defaults will be used. If set to a folder
+    path, the generation config will be loaded from the specified folder path.
+    If `max_new_tokens` is specified in generation config, then it sets a
+    server-wide limit on the number of output tokens for all requests."""
+    override_generation_config: dict[str, Any] = field(default_factory=dict)
+    """Overrides or sets generation config. e.g. `{"temperature": 0.5}`. If
+    used with `--generation-config auto`, the override parameters will be
+    merged with the default config from the model. If used with
+    `--generation-config vllm`, only the override parameters are used.
+    When specified via CLI, the argument must be a valid JSON string."""
+    enable_sleep_mode: bool = False
+    """Enable sleep mode for the engine (only cuda platform is supported)."""
+    model_impl: Union[str, ModelImpl] = ModelImpl.AUTO.value
+    """Which implementation of the model to use:\n
+    - "auto" will try to use the vLLM implementation, if it exists, and fall
+    back to the Transformers implementation if no vLLM implementation is
+    available.\n
+    - "vllm" will use the vLLM model implementation.\n
+    - "transformers" will use the Transformers model implementation."""
    def compute_hash(self) -> str:
        """
@@ -342,92 +428,43 @@ class ModelConfig:
        assert_hashable(str_factors)
        return hashlib.sha256(str(factors).encode()).hexdigest()
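Since every field above now has a default and an attribute docstring, ModelConfig can be constructed like any other dataclass, overriding only the fields of interest; `__post_init__` (shown in the next hunk) then resolves the tokenizer, HF config, dtype and max length. A hedged sketch of direct construction (field names are taken from the definitions above; resolving the HF config in `__post_init__` still needs network access or a local HF cache):

```python
from vllm.config import ModelConfig

# Only the overridden fields need to be passed; everything else keeps the
# defaults declared above (model="facebook/opt-125m", dtype="auto", ...).
model_config = ModelConfig(
    model="facebook/opt-125m",
    dtype="bfloat16",
    max_model_len=2048,
    enforce_eager=True,
)
print(model_config.tokenizer)  # falls back to the model name in __post_init__
```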
def __init__( def __post_init__(self) -> None:
self, self.model = maybe_model_redirect(self.model)
model: str, # The tokenizer is consistent with the model by default.
task: Literal[TaskOption, Literal["draft"]], if self.tokenizer is None:
tokenizer: str, self.tokenizer = self.model
tokenizer_mode: str, if self.tokenizer_revision is None:
trust_remote_code: bool, self.tokenizer_revision = self.revision
dtype: Union[str, torch.dtype], self.tokenizer = maybe_model_redirect(self.tokenizer)
seed: int,
hf_config_path: Optional[str] = None,
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict[str, Any]] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
disable_cascade_attn: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, list[str]]] = None,
limit_mm_per_prompt: Optional[dict[str, int]] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
disable_mm_preprocessor_cache: bool = False,
use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_token: Optional[Union[bool, str]] = None,
hf_overrides: Optional[HfOverrides] = None,
override_neuron_config: Optional[dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None,
generation_config: str = "auto",
enable_sleep_mode: bool = False,
override_generation_config: Optional[dict[str, Any]] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
self.model = maybe_model_redirect(model)
self.tokenizer = maybe_model_redirect(tokenizer)
self.hf_config_path = hf_config_path if isinstance(self.hf_config_path, str):
if isinstance(hf_config_path, str): self.hf_config_path = maybe_model_redirect(self.hf_config_path)
self.hf_config_path = maybe_model_redirect(hf_config_path)
self.tokenizer_mode = tokenizer_mode if callable(self.hf_overrides):
self.trust_remote_code = trust_remote_code
self.allowed_local_media_path = allowed_local_media_path
self.seed = seed
self.revision = revision
self.code_revision = code_revision
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
self.model_impl = model_impl
if hf_overrides is None:
hf_overrides = {}
if callable(hf_overrides):
hf_overrides_kw = {} hf_overrides_kw = {}
hf_overrides_fn = hf_overrides hf_overrides_fn = self.hf_overrides
else: else:
hf_overrides_kw = hf_overrides hf_overrides_kw = self.hf_overrides
hf_overrides_fn = None hf_overrides_fn = None
if rope_scaling is not None: if self.rope_scaling:
hf_override: dict[str, Any] = {"rope_scaling": rope_scaling} hf_override: dict[str, Any] = {"rope_scaling": self.rope_scaling}
hf_overrides_kw.update(hf_override) hf_overrides_kw.update(hf_override)
hf_overrides_str = json.dumps(hf_overrides) hf_overrides_str = json.dumps(hf_overrides_kw)
msg = ( msg = (
"`--rope-scaling` will be removed in a future release. " "`--rope-scaling` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") f"'Please instead use `--hf-overrides '{hf_overrides_str}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2) warnings.warn(DeprecationWarning(msg), stacklevel=2)
if rope_theta is not None: if self.rope_theta is not None:
hf_override = {"rope_theta": rope_theta} hf_override = {"rope_theta": self.rope_theta}
hf_overrides_kw.update(hf_override) hf_overrides_kw.update(hf_override)
hf_overrides_str = json.dumps(hf_overrides) hf_overrides_str = json.dumps(hf_overrides_kw)
msg = ( msg = (
"`--rope-theta` will be removed in a future release. " "`--rope-theta` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_overrides_str}'`") f"'Please instead use `--hf-overrides '{hf_overrides_str}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2) warnings.warn(DeprecationWarning(msg), stacklevel=2)
self.maybe_pull_model_tokenizer_for_s3(model, tokenizer) self.maybe_pull_model_tokenizer_for_s3(self.model, self.tokenizer)
if (backend := envs.VLLM_ATTENTION_BACKEND if (backend := envs.VLLM_ATTENTION_BACKEND
) and backend == "FLASHINFER" and find_spec("flashinfer") is None: ) and backend == "FLASHINFER" and find_spec("flashinfer") is None:
@ -437,20 +474,6 @@ class ModelConfig:
"https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501 "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501
"for instructions on how to install it.") "for instructions on how to install it.")
# The tokenizer version is consistent with the model version by default.
if tokenizer_revision is None:
self.tokenizer_revision = revision
else:
self.tokenizer_revision = tokenizer_revision
self.quantization = quantization
self.enforce_eager = enforce_eager
self.max_seq_len_to_capture = max_seq_len_to_capture
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
self.disable_cascade_attn = disable_cascade_attn
self.skip_tokenizer_init = skip_tokenizer_init
self.enable_sleep_mode = enable_sleep_mode
from vllm.platforms import current_platform from vllm.platforms import current_platform
if (self.enable_sleep_mode if (self.enable_sleep_mode
@ -458,9 +481,12 @@ class ModelConfig:
raise ValueError( raise ValueError(
"Sleep mode is not supported on current platform.") "Sleep mode is not supported on current platform.")
if isinstance(self.config_format, str):
self.config_format = ConfigFormat(self.config_format)
hf_config = get_config(self.hf_config_path or self.model, hf_config = get_config(self.hf_config_path or self.model,
trust_remote_code, revision, code_revision, self.trust_remote_code, self.revision,
config_format) self.code_revision, self.config_format)
if hf_overrides_kw: if hf_overrides_kw:
logger.info("Overriding HF config with %s", hf_overrides_kw) logger.info("Overriding HF config with %s", hf_overrides_kw)
@ -476,13 +502,8 @@ class ModelConfig:
"attention_chunk_size", None) "attention_chunk_size", None)
self.encoder_config = self._get_encoder_config() self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config( self.hf_image_processor_config = get_hf_image_processor_config(
self.model, hf_token=hf_token, revision=revision) self.model, hf_token=self.hf_token, revision=self.revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)
self.use_async_output_proc = use_async_output_proc
# Set enforce_eager to False if the value is unset.
if self.enforce_eager is None:
self.enforce_eager = False
interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"] interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
sliding_window = getattr(self.hf_text_config, "sliding_window", None) sliding_window = getattr(self.hf_text_config, "sliding_window", None)
@ -515,18 +536,14 @@ class ModelConfig:
self.max_model_len = _get_and_verify_max_len( self.max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config, hf_config=self.hf_text_config,
max_model_len=max_model_len, max_model_len=self.max_model_len,
disable_sliding_window=self.disable_sliding_window, disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window(), sliding_window_len=self.get_hf_config_sliding_window(),
spec_target_max_model_len=spec_target_max_model_len, spec_target_max_model_len=self.spec_target_max_model_len,
encoder_config=self.encoder_config) encoder_config=self.encoder_config)
self.served_model_name = get_served_model_name(model, self.served_model_name = get_served_model_name(self.model,
served_model_name) self.served_model_name)
self.multimodal_config = self._init_multimodal_config( self.multimodal_config = self._init_multimodal_config()
limit_mm_per_prompt=limit_mm_per_prompt,
mm_processor_kwargs=mm_processor_kwargs,
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
)
if not self.skip_tokenizer_init: if not self.skip_tokenizer_init:
self._verify_tokenizer_mode() self._verify_tokenizer_mode()
@ -535,24 +552,19 @@ class ModelConfig:
self.has_noops = self._init_has_noops() self.has_noops = self._init_has_noops()
self.has_inner_state = self._init_has_inner_state() self.has_inner_state = self._init_has_inner_state()
if current_platform.is_neuron(): if (not current_platform.is_neuron() and self.override_neuron_config):
self.override_neuron_config = override_neuron_config raise ValueError(
else: "`override_neuron_config` is only supported on Neuron.")
self.override_neuron_config = None
supported_tasks, task = self._resolve_task(task) supported_tasks, task = self._resolve_task(self.task)
self.supported_tasks = supported_tasks self.supported_tasks = supported_tasks
self.task: Final = task self.task = task
if self.task in ("draft", "generate"): if self.task in ("draft", "generate"):
self.truncation_side = "left" self.truncation_side = "left"
else: else:
self.truncation_side = "right" self.truncation_side = "right"
self.pooler_config = self._init_pooler_config(override_pooler_config) self.pooler_config = self._init_pooler_config()
self.logits_processor_pattern = logits_processor_pattern
self.generation_config = generation_config
self.override_generation_config = override_generation_config or {}
self._verify_quantization() self._verify_quantization()
self._verify_cuda_graph() self._verify_cuda_graph()
@ -591,26 +603,21 @@ class ModelConfig:
model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"]) model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
self.tokenizer = s3_tokenizer.dir self.tokenizer = s3_tokenizer.dir
def _init_multimodal_config( def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
self,
limit_mm_per_prompt: Optional[dict[str, int]],
mm_processor_kwargs: Optional[dict[str, Any]],
disable_mm_preprocessor_cache: bool,
) -> Optional["MultiModalConfig"]:
if self.registry.is_multimodal_model(self.architectures): if self.registry.is_multimodal_model(self.architectures):
return MultiModalConfig( return MultiModalConfig(
limit_per_prompt=limit_mm_per_prompt or {}, limit_per_prompt=self.limit_mm_per_prompt,
mm_processor_kwargs=mm_processor_kwargs or {}, mm_processor_kwargs=self.mm_processor_kwargs,
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache, disable_mm_preprocessor_cache=self.
) disable_mm_preprocessor_cache)
if limit_mm_per_prompt: if self.limit_mm_per_prompt:
raise ValueError("`limit_mm_per_prompt` is only supported for " raise ValueError("`limit_mm_per_prompt` is only supported for "
"multimodal models.") "multimodal models.")
if mm_processor_kwargs: if self.mm_processor_kwargs:
raise ValueError("`mm_processor_kwargs` is only supported for " raise ValueError("`mm_processor_kwargs` is only supported for "
"multimodal models.") "multimodal models.")
if disable_mm_preprocessor_cache: if self.disable_mm_preprocessor_cache:
raise ValueError("`disable_mm_preprocessor_cache` is only " raise ValueError("`disable_mm_preprocessor_cache` is only "
"supported for multimodal models.") "supported for multimodal models.")
@@ -620,31 +627,32 @@ class ModelConfig:
        return get_sentence_transformer_tokenizer_config(
            self.model, self.revision)

-    def _init_pooler_config(
-        self,
-        override_pooler_config: Optional["PoolerConfig"],
-    ) -> Optional["PoolerConfig"]:
+    def _init_pooler_config(self) -> Optional["PoolerConfig"]:

        if self.runner_type == "pooling":
-            user_config = override_pooler_config or PoolerConfig()
+            if isinstance(self.override_pooler_config, dict):
+                self.override_pooler_config = PoolerConfig(
+                    **self.override_pooler_config)
+
+            pooler_config = self.override_pooler_config or PoolerConfig()

            base_config = get_pooling_config(self.model, self.revision)
            if base_config is not None:
                # Only set values that are not overridden by the user
                for k, v in base_config.items():
-                    if getattr(user_config, k) is None:
-                        setattr(user_config, k, v)
+                    if getattr(pooler_config, k) is None:
+                        setattr(pooler_config, k, v)

            if self.is_matryoshka:
-                if user_config.normalize is None:
-                    user_config.normalize = True
-                elif not user_config.normalize:
+                if pooler_config.normalize is None:
+                    pooler_config.normalize = True
+                elif not pooler_config.normalize:
                    raise ValueError(
                        "`normalize` must be enabled (set to True) "
                        "for models that are compatible with "
                        "Matryoshka Representation.")

-            return user_config
+            return pooler_config

        return None
@@ -662,11 +670,11 @@ class ModelConfig:
        return self.registry.model_has_inner_state(self.architectures)

    def _verify_tokenizer_mode(self) -> None:
-        tokenizer_mode = self.tokenizer_mode.lower()
-        if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
+        tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())
+        if tokenizer_mode not in get_args(TokenizerMode):
            raise ValueError(
                f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
-                "either 'auto', 'slow', 'mistral' or 'custom'.")
+                f"one of {get_args(TokenizerMode)}.")
        self.tokenizer_mode = tokenizer_mode
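The validation above is now driven entirely by the `TokenizerMode` Literal, so the accepted values and the error message can never drift apart. A standalone illustration of the same pattern (self-contained; not code from the commit):

```python
from typing import Literal, cast, get_args

TokenizerMode = Literal["auto", "slow", "mistral", "custom"]


def verify_tokenizer_mode(mode: str) -> TokenizerMode:
    # Normalise, then check membership against the Literal's own arguments.
    lowered = cast(TokenizerMode, mode.lower())
    if lowered not in get_args(TokenizerMode):
        raise ValueError(f"Unknown tokenizer mode: {mode}. "
                         f"Must be one of {get_args(TokenizerMode)}.")
    return lowered


print(verify_tokenizer_mode("Mistral"))  # -> mistral
```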
    def _get_preferred_task(
@@ -781,7 +789,8 @@ class ModelConfig:
            "quark", "nvfp4", "bitblas", "gptq_bitblas"
        ]
        if self.quantization is not None:
-            self.quantization = self.quantization.lower()
+            self.quantization = cast(QuantizationMethods,
+                                     self.quantization.lower())

            # Parse quantization method from the HF model config, if available.
            quant_cfg = self._parse_quant_hf_config()
@@ -857,8 +866,6 @@ class ModelConfig:
                        "non-quantized models.", self.quantization)

    def _verify_cuda_graph(self) -> None:
-        if self.max_seq_len_to_capture is None:
-            self.max_seq_len_to_capture = self.max_model_len
        self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
                                          self.max_model_len)

        ROCM_UNSUPPORTED_MODELS = ['mllama']
@@ -1294,7 +1301,7 @@ class ModelConfig:
    @property
    def runner_type(self) -> RunnerType:
-        return _TASK_RUNNER[self.task]
+        return _TASK_RUNNER[cast(_ResolvedTask, self.task)]

    @property
    def is_v1_compatible(self) -> bool:
@@ -2201,7 +2208,7 @@ class SpeculativeConfig:
    according to the log probability settings in SamplingParams."""

    # Draft model configuration
-    quantization: Optional[str] = None
+    quantization: Optional[QuantizationMethods] = None
    """Quantization method that was used to quantize the draft model weights.
    If `None`, we assume the model weights are not quantized. Note that it only
    takes effect when using the draft model-based speculative method."""
@@ -2386,7 +2393,6 @@ class SpeculativeConfig:
                    code_revision=self.code_revision,
                    tokenizer_revision=self.target_model_config.
                    tokenizer_revision,
-                    max_model_len=None,
                    spec_target_max_model_len=self.target_model_config.
                    max_model_len,
                    quantization=self.quantization,
@ -2793,30 +2799,31 @@ class PromptAdapterConfig:
class MultiModalConfig: class MultiModalConfig:
"""Controls the behavior of multimodal models.""" """Controls the behavior of multimodal models."""
limit_per_prompt: dict[str, int] = field(default_factory=dict) limit_per_prompt: dict[str, int] = get_field(ModelConfig,
"limit_mm_per_prompt")
""" """
The maximum number of input items allowed per prompt for each modality. The maximum number of input items allowed per prompt for each modality.
This should be a JSON string that will be parsed into a dictionary. This should be a JSON string that will be parsed into a dictionary.
Defaults to 1 (V0) or 999 (V1) for each modality. Defaults to 1 (V0) or 999 (V1) for each modality.
For example, to allow up to 16 images and 2 videos per prompt: For example, to allow up to 16 images and 2 videos per prompt:
:code:`{"images": 16, "videos": 2}` `{"images": 16, "videos": 2}`
""" """
mm_processor_kwargs: Optional[dict[str, object]] = None mm_processor_kwargs: Optional[dict[str, object]] = None
""" """
Overrides for the multi-modal processor obtained from Overrides for the multi-modal processor obtained from
:meth:`transformers.AutoProcessor.from_pretrained`. `transformers.AutoProcessor.from_pretrained`.
The available overrides depend on the model that is being run. The available overrides depend on the model that is being run.
For example, for Phi-3-Vision: For example, for Phi-3-Vision:
:code:`{"num_crops": 4}`. `{"num_crops": 4}`.
""" """
disable_mm_preprocessor_cache: bool = False disable_mm_preprocessor_cache: bool = False
""" """
If :code:`True`, disable caching of the processed multi-modal inputs. If `True`, disable caching of the processed multi-modal inputs.
""" """
def compute_hash(self) -> str: def compute_hash(self) -> str:
@@ -2907,10 +2914,6 @@ class PoolerConfig:
                            usedforsecurity=False).hexdigest()
        return hash_str

-    @staticmethod
-    def from_json(json_str: str) -> "PoolerConfig":
-        return PoolerConfig(**json.loads(json_str))
-

_STR_DTYPE_TO_TORCH_DTYPE = {
    "half": torch.float16,

View File

@ -20,15 +20,16 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DeviceConfig, DistributedExecutorBackend, DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackend, GuidedDecodingBackendV1, GuidedDecodingBackend, GuidedDecodingBackendV1,
HfOverrides, KVTransferConfig, LoadConfig, LoadFormat, HfOverrides, KVTransferConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ModelImpl, MultiModalConfig, LoRAConfig, ModelConfig, ModelDType, ModelImpl,
ObservabilityConfig, ParallelConfig, PoolerConfig, MultiModalConfig, ObservabilityConfig, ParallelConfig,
PrefixCachingHashAlgo, PromptAdapterConfig, PoolerConfig, PrefixCachingHashAlgo,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig, PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
TaskOption, TokenizerPoolConfig, VllmConfig, SpeculativeConfig, TaskOption, TokenizerMode,
get_attr_docs, get_field) TokenizerPoolConfig, VllmConfig, get_attr_docs,
get_field)
from vllm.executor.executor_base import ExecutorBase from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.plugins import load_general_plugins from vllm.plugins import load_general_plugins
from vllm.reasoning import ReasoningParserManager from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
@@ -183,6 +184,9 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
            kwargs[name]["nargs"] = "+"
        elif contains_type(type_hints, int):
            kwargs[name]["type"] = int
+            # Special case for large integers
+            if name in {"max_model_len"}:
+                kwargs[name]["type"] = human_readable_int
        elif contains_type(type_hints, float):
            kwargs[name]["type"] = float
        elif contains_type(type_hints, dict):
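With this special case, `--max-model-len` accepts the human-readable integers documented on the `max_model_len` field earlier in this commit. A hedged illustration of the expected conversions (the module that defines `human_readable_int` is not shown in this diff, so the import path below is an assumption):

```python
# Assumed import location; the diff only shows the function being referenced.
from vllm.utils import human_readable_int

# Conversions documented on ModelConfig.max_model_len in this commit:
assert human_readable_int("1k") == 1_000     # decimal suffix
assert human_readable_int("1K") == 1_024     # binary suffix
assert human_readable_int("25.6k") == 25_600
```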
@ -212,22 +216,23 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
@dataclass @dataclass
class EngineArgs: class EngineArgs:
"""Arguments for vLLM engine.""" """Arguments for vLLM engine."""
model: str = 'facebook/opt-125m' model: str = ModelConfig.model
served_model_name: Optional[Union[str, List[str]]] = None served_model_name: Optional[Union[
tokenizer: Optional[str] = None str, List[str]]] = ModelConfig.served_model_name
hf_config_path: Optional[str] = None tokenizer: Optional[str] = ModelConfig.tokenizer
task: TaskOption = "auto" hf_config_path: Optional[str] = ModelConfig.hf_config_path
skip_tokenizer_init: bool = False task: TaskOption = ModelConfig.task
tokenizer_mode: str = 'auto' skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
trust_remote_code: bool = False tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
allowed_local_media_path: str = "" trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
download_dir: Optional[str] = LoadConfig.download_dir download_dir: Optional[str] = LoadConfig.download_dir
load_format: str = LoadConfig.load_format load_format: str = LoadConfig.load_format
config_format: ConfigFormat = ConfigFormat.AUTO config_format: str = ModelConfig.config_format
dtype: str = 'auto' dtype: ModelDType = ModelConfig.dtype
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
seed: Optional[int] = None seed: Optional[int] = ModelConfig.seed
max_model_len: Optional[int] = None max_model_len: Optional[int] = ModelConfig.max_model_len
# Note: Specifying a custom executor backend by passing a class # Note: Specifying a custom executor backend by passing a class
# is intended for expert use only. The API may change without # is intended for expert use only. The API may change without
# notice. # notice.
@ -245,8 +250,8 @@ class EngineArgs:
enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
prefix_caching_hash_algo: PrefixCachingHashAlgo = \ prefix_caching_hash_algo: PrefixCachingHashAlgo = \
CacheConfig.prefix_caching_hash_algo CacheConfig.prefix_caching_hash_algo
disable_sliding_window: bool = False disable_sliding_window: bool = ModelConfig.disable_sliding_window
disable_cascade_attn: bool = False disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
use_v2_block_manager: bool = True use_v2_block_manager: bool = True
swap_space: float = CacheConfig.swap_space swap_space: float = CacheConfig.swap_space
cpu_offload_gb: float = CacheConfig.cpu_offload_gb cpu_offload_gb: float = CacheConfig.cpu_offload_gb
@ -258,18 +263,19 @@ class EngineArgs:
long_prefill_token_threshold: int = \ long_prefill_token_threshold: int = \
SchedulerConfig.long_prefill_token_threshold SchedulerConfig.long_prefill_token_threshold
max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
max_logprobs: int = 20 # Default value for OpenAI Chat Completions API max_logprobs: int = ModelConfig.max_logprobs
disable_log_stats: bool = False disable_log_stats: bool = False
revision: Optional[str] = None revision: Optional[str] = ModelConfig.revision
code_revision: Optional[str] = None code_revision: Optional[str] = ModelConfig.code_revision
rope_scaling: Optional[Dict[str, Any]] = None rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
rope_theta: Optional[float] = None rope_theta: Optional[float] = ModelConfig.rope_theta
hf_token: Optional[Union[bool, str]] = None hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
hf_overrides: Optional[HfOverrides] = None hf_overrides: Optional[HfOverrides] = \
tokenizer_revision: Optional[str] = None get_field(ModelConfig, "hf_overrides")
quantization: Optional[str] = None tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
enforce_eager: Optional[bool] = None quantization: Optional[QuantizationMethods] = ModelConfig.quantization
max_seq_len_to_capture: int = 8192 enforce_eager: bool = ModelConfig.enforce_eager
max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
# The following three fields are deprecated and will be removed in a future # The following three fields are deprecated and will be removed in a future
# release. Setting them will have no effect. Please remove them from your # release. Setting them will have no effect. Please remove them from your
@ -280,8 +286,10 @@ class EngineArgs:
get_field(TokenizerPoolConfig, "extra_config") get_field(TokenizerPoolConfig, "extra_config")
limit_mm_per_prompt: dict[str, int] = \ limit_mm_per_prompt: dict[str, int] = \
get_field(MultiModalConfig, "limit_per_prompt") get_field(MultiModalConfig, "limit_per_prompt")
mm_processor_kwargs: Optional[Dict[str, Any]] = None mm_processor_kwargs: Optional[Dict[str, Any]] = \
disable_mm_preprocessor_cache: bool = False MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = \
MultiModalConfig.disable_mm_preprocessor_cache
# LoRA fields # LoRA fields
enable_lora: bool = False enable_lora: bool = False
enable_lora_bias: bool = LoRAConfig.bias_enabled enable_lora_bias: bool = LoRAConfig.bias_enabled
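A consequence of the hunk above is that `EngineArgs` and `ModelConfig` can no longer disagree on defaults: the CLI dataclass now points directly at the config dataclass's class attributes (or pulls dataclass fields via `get_field`). A small sketch of what that buys, using attribute names from the diff; treat it as an illustration rather than a test from the commit:

```python
from vllm.config import ModelConfig
from vllm.engine.arg_utils import EngineArgs

# The EngineArgs defaults are literally the ModelConfig class attributes,
# so changing a default in ModelConfig automatically updates the CLI.
args = EngineArgs(model="facebook/opt-125m")
assert args.max_logprobs == ModelConfig.max_logprobs            # 20
assert args.enforce_eager == ModelConfig.enforce_eager          # False
assert args.max_seq_len_to_capture == ModelConfig.max_seq_len_to_capture
```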
@ -323,7 +331,8 @@ class EngineArgs:
DecodingConfig.disable_any_whitespace DecodingConfig.disable_any_whitespace
guided_decoding_disable_additional_properties: bool = \ guided_decoding_disable_additional_properties: bool = \
DecodingConfig.disable_additional_properties DecodingConfig.disable_additional_properties
logits_processor_pattern: Optional[str] = None logits_processor_pattern: Optional[
str] = ModelConfig.logits_processor_pattern
speculative_config: Optional[Dict[str, Any]] = None speculative_config: Optional[Dict[str, Any]] = None
@ -331,22 +340,25 @@ class EngineArgs:
show_hidden_metrics_for_version: Optional[str] = None show_hidden_metrics_for_version: Optional[str] = None
otlp_traces_endpoint: Optional[str] = None otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False disable_async_output_proc: bool = not ModelConfig.use_async_output_proc
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
override_neuron_config: Optional[Dict[str, Any]] = None override_neuron_config: dict[str, Any] = \
override_pooler_config: Optional[PoolerConfig] = None get_field(ModelConfig, "override_neuron_config")
override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
ModelConfig.override_pooler_config
compilation_config: Optional[CompilationConfig] = None compilation_config: Optional[CompilationConfig] = None
worker_cls: str = ParallelConfig.worker_cls worker_cls: str = ParallelConfig.worker_cls
worker_extension_cls: str = ParallelConfig.worker_extension_cls worker_extension_cls: str = ParallelConfig.worker_extension_cls
kv_transfer_config: Optional[KVTransferConfig] = None kv_transfer_config: Optional[KVTransferConfig] = None
generation_config: Optional[str] = "auto" generation_config: str = ModelConfig.generation_config
override_generation_config: Optional[Dict[str, Any]] = None enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
enable_sleep_mode: bool = False override_generation_config: dict[str, Any] = \
model_impl: str = "auto" get_field(ModelConfig, "override_generation_config")
model_impl: str = ModelConfig.model_impl
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
@ -356,9 +368,6 @@ class EngineArgs:
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
def __post_init__(self): def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model
# support `EngineArgs(compilation_config={...})` # support `EngineArgs(compilation_config={...})`
# without having to manually construct a # without having to manually construct a
# CompilationConfig object # CompilationConfig object
@ -375,80 +384,87 @@ class EngineArgs:
"""Shared CLI arguments for vLLM engine.""" """Shared CLI arguments for vLLM engine."""
# Model arguments # Model arguments
parser.add_argument( model_kwargs = get_kwargs(ModelConfig)
'--model', model_group = parser.add_argument_group(
title="ModelConfig",
description=ModelConfig.__doc__,
)
model_group.add_argument("--model", **model_kwargs["model"])
model_group.add_argument("--task", **model_kwargs["task"])
model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
model_group.add_argument("--tokenizer-mode",
**model_kwargs["tokenizer_mode"])
model_group.add_argument("--trust-remote-code",
**model_kwargs["trust_remote_code"])
model_group.add_argument("--dtype", **model_kwargs["dtype"])
model_group.add_argument("--seed", **model_kwargs["seed"])
model_group.add_argument("--hf-config-path",
**model_kwargs["hf_config_path"])
model_group.add_argument("--allowed-local-media-path",
**model_kwargs["allowed_local_media_path"])
model_group.add_argument("--revision", **model_kwargs["revision"])
model_group.add_argument("--code-revision",
**model_kwargs["code_revision"])
model_group.add_argument("--rope-scaling",
**model_kwargs["rope_scaling"])
model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
model_group.add_argument("--tokenizer-revision",
**model_kwargs["tokenizer_revision"])
model_group.add_argument("--max-model-len",
**model_kwargs["max_model_len"])
model_group.add_argument("--quantization", "-q",
**model_kwargs["quantization"])
model_group.add_argument("--enforce-eager",
**model_kwargs["enforce_eager"])
model_group.add_argument("--max-seq-len-to-capture",
**model_kwargs["max_seq_len_to_capture"])
model_group.add_argument("--max-logprobs",
**model_kwargs["max_logprobs"])
model_group.add_argument("--disable-sliding-window",
**model_kwargs["disable_sliding_window"])
model_group.add_argument("--disable-cascade-attn",
**model_kwargs["disable_cascade_attn"])
model_group.add_argument("--skip-tokenizer-init",
**model_kwargs["skip_tokenizer_init"])
model_group.add_argument("--served-model-name",
**model_kwargs["served_model_name"])
# This one is a special case because it is the
# opposite of ModelConfig.use_async_output_proc
model_group.add_argument(
"--disable-async-output-proc",
action="store_true",
default=EngineArgs.disable_async_output_proc,
help="Disable async output processing. This may result in "
"lower performance.")
model_group.add_argument("--config-format",
choices=[f.value for f in ConfigFormat],
**model_kwargs["config_format"])
# This one is a special case because it can bool
# or str. TODO: Handle this in get_kwargs
model_group.add_argument("--hf-token",
type=str, type=str,
default=EngineArgs.model, nargs="?",
help='Name or path of the huggingface model to use.') const=True,
parser.add_argument( default=model_kwargs["hf_token"]["default"],
'--task', help=model_kwargs["hf_token"]["help"])
default=EngineArgs.task, model_group.add_argument("--hf-overrides",
choices=get_args(TaskOption), **model_kwargs["hf_overrides"])
help='The task to use the model for. Each vLLM instance only ' model_group.add_argument("--override-neuron-config",
'supports one task, even if the same model can be used for ' **model_kwargs["override_neuron_config"])
'multiple tasks. When the model only supports one task, ``"auto"`` ' model_group.add_argument("--override-pooler-config",
'can be used to select it; otherwise, you must specify explicitly ' **model_kwargs["override_pooler_config"])
'which task to use.') model_group.add_argument("--logits-processor-pattern",
parser.add_argument( **model_kwargs["logits_processor_pattern"])
'--tokenizer', model_group.add_argument("--generation-config",
type=optional_type(str), **model_kwargs["generation_config"])
default=EngineArgs.tokenizer, model_group.add_argument("--override-generation-config",
help='Name or path of the huggingface tokenizer to use. ' **model_kwargs["override_generation_config"])
'If unspecified, model name or path will be used.') model_group.add_argument("--enable-sleep-mode",
parser.add_argument( **model_kwargs["enable_sleep_mode"])
"--hf-config-path", model_group.add_argument("--model-impl",
type=optional_type(str), choices=[f.value for f in ModelImpl],
default=EngineArgs.hf_config_path, **model_kwargs["model_impl"])
help='Name or path of the huggingface config to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
'--skip-tokenizer-init',
action='store_true',
help='Skip initialization of tokenizer and detokenizer. '
'Expects valid prompt_token_ids and None for prompt from '
'the input. The generated output will contain token ids.')
parser.add_argument(
'--revision',
type=optional_type(str),
default=None,
help='The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument(
'--code-revision',
type=optional_type(str),
default=None,
help='The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.')
parser.add_argument(
'--tokenizer-revision',
type=optional_type(str),
default=None,
help='Revision of the huggingface tokenizer to use. '
'It can be a branch name, a tag name, or a commit id. '
'If unspecified, will use the default version.')
parser.add_argument(
'--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow', 'mistral', 'custom'],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer. \n* '
'"custom" will use --tokenizer to select the '
'preregistered tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
parser.add_argument(
'--allowed-local-media-path',
type=str,
help="Allowing API requests to read local images or videos "
"from directories specified by the server file system. "
"This is a security risk. "
"Should only be enabled in trusted environments.")
# Model loading arguments # Model loading arguments
load_kwargs = get_kwargs(LoadConfig) load_kwargs = get_kwargs(LoadConfig)
load_group = parser.add_argument_group( load_group = parser.add_argument_group(
@ -465,38 +481,6 @@ class EngineArgs:
load_group.add_argument('--use-tqdm-on-load', load_group.add_argument('--use-tqdm-on-load',
**load_kwargs["use_tqdm_on_load"]) **load_kwargs["use_tqdm_on_load"])
parser.add_argument(
'--config-format',
default=EngineArgs.config_format,
choices=[f.value for f in ConfigFormat],
help='The format of the model config to load.\n\n'
'* "auto" will try to load the config in hf format '
'if available else it will try to load in mistral format ')
parser.add_argument(
'--dtype',
type=str,
default=EngineArgs.dtype,
choices=[
'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
],
help='Data type for model weights and activations.\n\n'
'* "auto" will use FP16 precision for FP32 and FP16 models, and '
'BF16 precision for BF16 models.\n'
'* "half" for FP16. Recommended for AWQ quantization.\n'
'* "float16" is the same as "half".\n'
'* "bfloat16" for a balance between precision and range.\n'
'* "float" is shorthand for FP32 precision.\n'
'* "float32" for FP32 precision.')
parser.add_argument('--max-model-len',
type=human_readable_int,
default=EngineArgs.max_model_len,
help='Model context length. If unspecified, will '
'be automatically derived from the model config. '
'Supports k/m/g/K/M/G in human-readable format.\n'
'Examples:\n'
'- 1k → 1000\n'
'- 1K → 1024\n')
# Guided decoding arguments # Guided decoding arguments
guided_decoding_kwargs = get_kwargs(DecodingConfig) guided_decoding_kwargs = get_kwargs(DecodingConfig)
guided_decoding_group = parser.add_argument_group( guided_decoding_group = parser.add_argument_group(
@ -520,26 +504,6 @@ class EngineArgs:
choices=list(ReasoningParserManager.reasoning_parsers), choices=list(ReasoningParserManager.reasoning_parsers),
**guided_decoding_kwargs["reasoning_backend"]) **guided_decoding_kwargs["reasoning_backend"])
parser.add_argument(
'--logits-processor-pattern',
type=optional_type(str),
default=None,
help='Optional regex pattern specifying valid logits processor '
'qualified names that can be passed with the `logits_processors` '
'extra completion argument. Defaults to None, which allows no '
'processors.')
parser.add_argument(
'--model-impl',
type=str,
default=EngineArgs.model_impl,
choices=[f.value for f in ModelImpl],
help='Which implementation of the model to use.\n\n'
'* "auto" will try to use the vLLM implementation if it exists '
'and fall back to the Transformers implementation if no vLLM '
'implementation is available.\n'
'* "vllm" will use the vLLM model implementation.\n'
'* "transformers" will use the Transformers model '
'implementation.\n')
# Parallel arguments # Parallel arguments
parallel_kwargs = get_kwargs(ParallelConfig) parallel_kwargs = get_kwargs(ParallelConfig)
parallel_group = parser.add_argument_group( parallel_group = parser.add_argument_group(
@ -592,10 +556,6 @@ class EngineArgs:
cache_group.add_argument('--calculate-kv-scales', cache_group.add_argument('--calculate-kv-scales',
**cache_kwargs["calculate_kv_scales"]) **cache_kwargs["calculate_kv_scales"])
parser.add_argument('--disable-sliding-window',
action='store_true',
help='Disables sliding window, '
'capping to sliding window size.')
parser.add_argument('--use-v2-block-manager', parser.add_argument('--use-v2-block-manager',
action='store_true', action='store_true',
default=True, default=True,
@ -605,73 +565,9 @@ class EngineArgs:
'Setting this flag to True or False' 'Setting this flag to True or False'
' has no effect on vLLM behavior.') ' has no effect on vLLM behavior.')
parser.add_argument('--seed',
type=int,
default=EngineArgs.seed,
help='Random seed for operations.')
parser.add_argument(
'--max-logprobs',
type=int,
default=EngineArgs.max_logprobs,
help=('Max number of log probs to return logprobs is specified in'
' SamplingParams.'))
parser.add_argument('--disable-log-stats', parser.add_argument('--disable-log-stats',
action='store_true', action='store_true',
help='Disable logging statistics.') help='Disable logging statistics.')
# Quantization settings.
parser.add_argument('--quantization',
'-q',
type=optional_type(str),
choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.quantization,
help='Method used to quantize the weights. If '
'None, we first check the `quantization_config` '
'attribute in the model config file. If that is '
'None, we assume the model weights are not '
'quantized and use `dtype` to determine the data '
'type of the weights.')
parser.add_argument(
'--rope-scaling',
default=None,
type=json.loads,
help='RoPE scaling configuration in JSON format. '
'For example, ``{"rope_type":"dynamic","factor":2.0}``')
parser.add_argument('--rope-theta',
default=None,
type=float,
help='RoPE theta. Use with `rope_scaling`. In '
'some cases, changing the RoPE theta improves the '
'performance of the scaled model.')
parser.add_argument(
'--hf-token',
type=str,
nargs='?',
const=True,
default=None,
help='The token to use as HTTP bearer authorization'
' for remote files. If `True`, will use the token '
'generated when running `huggingface-cli login` '
'(stored in `~/.huggingface`).')
parser.add_argument('--hf-overrides',
type=json.loads,
default=EngineArgs.hf_overrides,
help='Extra arguments for the HuggingFace config. '
'This should be a JSON string that will be '
'parsed into a dictionary.')
parser.add_argument('--enforce-eager',
action='store_true',
help='Always use eager-mode PyTorch. If False, '
'will use eager mode and CUDA graph in hybrid '
'for maximal performance and flexibility.')
parser.add_argument('--max-seq-len-to-capture',
type=int,
default=EngineArgs.max_seq_len_to_capture,
help='Maximum sequence length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode. '
'Additionally for encoder-decoder models, if the '
'sequence length of the encoder input is larger '
'than this, we fall back to the eager mode.')
# Tokenizer arguments # Tokenizer arguments
tokenizer_kwargs = get_kwargs(TokenizerPoolConfig) tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
@ -775,20 +671,6 @@ class EngineArgs:
"Default to `original/**/*` to avoid repeated loading of llama's " "Default to `original/**/*` to avoid repeated loading of llama's "
"checkpoints.") "checkpoints.")
parser.add_argument(
"--served-model-name",
nargs="+",
type=str,
default=None,
help="The model name(s) used in the API. If multiple "
"names are provided, the server will respond to any "
"of the provided names. The model name in the model "
"field of a response will be the first name in this "
"list. If not specified, the model name will be the "
"same as the ``--model`` argument. Noted that this name(s) "
"will also be used in `model_name` tag content of "
"prometheus metrics, if multiple names provided, metrics "
"tag will take the first one.")
parser.add_argument('--qlora-adapter-name-or-path', parser.add_argument('--qlora-adapter-name-or-path',
type=str, type=str,
default=None, default=None,
@ -822,13 +704,6 @@ class EngineArgs:
"modules. This involves use of possibly costly and or blocking " "modules. This involves use of possibly costly and or blocking "
"operations and hence might have a performance impact.") "operations and hence might have a performance impact.")
parser.add_argument(
'--disable-async-output-proc',
action='store_true',
default=EngineArgs.disable_async_output_proc,
help="Disable async output processing. This may result in "
"lower performance.")
# Scheduler arguments # Scheduler arguments
scheduler_kwargs = get_kwargs(SchedulerConfig) scheduler_kwargs = get_kwargs(SchedulerConfig)
scheduler_group = parser.add_argument_group( scheduler_group = parser.add_argument_group(
@ -871,19 +746,6 @@ class EngineArgs:
parser.add_argument('--scheduler-cls', parser.add_argument('--scheduler-cls',
**scheduler_kwargs["scheduler_cls"]) **scheduler_kwargs["scheduler_cls"])
parser.add_argument(
'--override-neuron-config',
type=json.loads,
default=None,
help="Override or set neuron device configuration. "
"e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
parser.add_argument(
'--override-pooler-config',
type=PoolerConfig.from_json,
default=None,
help="Override or set the pooling method for pooling models. "
"e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
parser.add_argument('--compilation-config', parser.add_argument('--compilation-config',
'-O', '-O',
type=CompilationConfig.from_cli, type=CompilationConfig.from_cli,
@ -920,34 +782,6 @@ class EngineArgs:
help='The worker extension class on top of the worker cls, ' help='The worker extension class on top of the worker cls, '
'it is useful if you just want to add new functions to the worker ' 'it is useful if you just want to add new functions to the worker '
'class without changing the existing functions.') 'class without changing the existing functions.')
parser.add_argument(
"--generation-config",
type=optional_type(str),
default="auto",
help="The folder path to the generation config. "
"Defaults to 'auto', the generation config will be loaded from "
"model path. If set to 'vllm', no generation config is loaded, "
"vLLM defaults will be used. If set to a folder path, the "
"generation config will be loaded from the specified folder path. "
"If `max_new_tokens` is specified in generation config, then "
"it sets a server-wide limit on the number of output tokens "
"for all requests.")
parser.add_argument(
"--override-generation-config",
type=json.loads,
default=None,
help="Overrides or sets generation config in JSON format. "
"e.g. ``{\"temperature\": 0.5}``. If used with "
"--generation-config=auto, the override parameters will be merged "
"with the default config from the model. If generation-config is "
"None, only the override parameters are used.")
parser.add_argument("--enable-sleep-mode",
action="store_true",
default=False,
help="Enable sleep mode for the engine. "
"(only cuda platform is supported)")
parser.add_argument( parser.add_argument(
"--additional-config", "--additional-config",
@ -966,16 +800,6 @@ class EngineArgs:
"If enabled, the model will be able to generate reasoning content." "If enabled, the model will be able to generate reasoning content."
) )
parser.add_argument(
"--disable-cascade-attn",
action="store_true",
default=False,
help="Disable cascade attention for V1. While cascade attention "
"does not change the mathematical correctness, disabling it "
"could be useful for preventing potential numerical issues. "
"Note that even if this is set to False, cascade attention will be "
"only used when the heuristic tells that it's beneficial.")
return parser return parser
@classmethod @classmethod
@ -1002,8 +826,7 @@ class EngineArgs:
model=self.model, model=self.model,
hf_config_path=self.hf_config_path, hf_config_path=self.hf_config_path,
task=self.task, task=self.task,
# We know this is not None because we set it in __post_init__ tokenizer=self.tokenizer,
tokenizer=cast(str, self.tokenizer),
tokenizer_mode=self.tokenizer_mode, tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code, trust_remote_code=self.trust_remote_code,
allowed_local_media_path=self.allowed_local_media_path, allowed_local_media_path=self.allowed_local_media_path,
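The hunks above replace the remaining hand-written parser.add_argument(...) calls for ModelConfig options with kwargs derived from the config dataclass via get_kwargs and registered on a dedicated model_group. A minimal, self-contained sketch of that pattern, using simplified field names and a toy get_kwargs that only handles int and bool fields (not the actual vLLM helper):

import argparse
from dataclasses import dataclass, fields


@dataclass
class DemoModelConfig:
    """Stand-in for a vLLM config dataclass (hypothetical fields)."""
    seed: int = 0
    max_logprobs: int = 20
    enable_sleep_mode: bool = False


def get_kwargs(cls):
    """Build argparse kwargs from dataclass fields (simplified sketch)."""
    kwargs = {}
    for f in fields(cls):
        if f.type is bool:
            kwargs[f.name] = {"action": "store_true", "default": f.default}
        else:
            kwargs[f.name] = {"type": f.type, "default": f.default}
    return kwargs


parser = argparse.ArgumentParser()
model_kwargs = get_kwargs(DemoModelConfig)
model_group = parser.add_argument_group(
    title="DemoModelConfig",
    description="Arguments generated from the config dataclass.")
model_group.add_argument("--seed", **model_kwargs["seed"])
model_group.add_argument("--max-logprobs", **model_kwargs["max_logprobs"])
model_group.add_argument("--enable-sleep-mode",
                         **model_kwargs["enable_sleep_mode"])

print(parser.parse_args(["--seed", "42", "--enable-sleep-mode"]))

Defaults and types then live in one place, the config dataclass, which is the apparent motivation for dropping the duplicated CLI definitions above.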

View File

@ -13,7 +13,7 @@ from typing_extensions import TypeVar, deprecated
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score) BeamSearchSequence, get_beam_search_score)
from vllm.config import CompilationConfig from vllm.config import CompilationConfig, ModelDType, TokenizerMode
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig, from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
TaskOption) TaskOption)
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
@ -32,6 +32,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import ( from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions) GuidedDecodingRequest, LLMGuidedOptions)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput, from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput, PoolingRequestOutput, RequestOutput,
ScoringRequestOutput) ScoringRequestOutput)
@ -163,20 +164,20 @@ class LLM:
self, self,
model: str, model: str,
tokenizer: Optional[str] = None, tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto", tokenizer_mode: TokenizerMode = "auto",
skip_tokenizer_init: bool = False, skip_tokenizer_init: bool = False,
trust_remote_code: bool = False, trust_remote_code: bool = False,
allowed_local_media_path: str = "", allowed_local_media_path: str = "",
tensor_parallel_size: int = 1, tensor_parallel_size: int = 1,
dtype: str = "auto", dtype: ModelDType = "auto",
quantization: Optional[str] = None, quantization: Optional[QuantizationMethods] = None,
revision: Optional[str] = None, revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None, tokenizer_revision: Optional[str] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
gpu_memory_utilization: float = 0.9, gpu_memory_utilization: float = 0.9,
swap_space: float = 4, swap_space: float = 4,
cpu_offload_gb: float = 0, cpu_offload_gb: float = 0,
enforce_eager: Optional[bool] = None, enforce_eager: bool = False,
max_seq_len_to_capture: int = 8192, max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False, disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False, disable_async_output_proc: bool = False,
@ -189,12 +190,7 @@ class LLM:
compilation_config: Optional[Union[int, dict[str, Any]]] = None, compilation_config: Optional[Union[int, dict[str, Any]]] = None,
**kwargs, **kwargs,
) -> None: ) -> None:
''' """LLM constructor."""
LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False.
'''
if "disable_log_stats" not in kwargs: if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True kwargs["disable_log_stats"] = True

View File

@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
@ -186,7 +187,7 @@ class AQLMConfig(QuantizationConfig):
f"out_group_size={self.out_group_size})") f"out_group_size={self.out_group_size})")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "aqlm" return "aqlm"
@classmethod @classmethod

View File

@ -7,6 +7,7 @@ import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter, from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@ -44,7 +45,7 @@ class AWQConfig(QuantizationConfig):
f"zero_point={self.zero_point}, " f"zero_point={self.zero_point}, "
f"modules_to_not_convert={self.modules_to_not_convert})") f"modules_to_not_convert={self.modules_to_not_convert})")
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
return "awq" return "awq"
def get_supported_act_dtypes(self) -> List[torch.dtype]: def get_supported_act_dtypes(self) -> List[torch.dtype]:

View File

@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod, UnquantizedLinearMethod,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.awq import (AWQConfig, from vllm.model_executor.layers.quantization.awq import (AWQConfig,
is_layer_skipped_awq) is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
@ -73,7 +74,7 @@ class AWQMarlinConfig(QuantizationConfig):
f"modules_to_not_convert={self.modules_to_not_convert})") f"modules_to_not_convert={self.modules_to_not_convert})")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "awq_marlin" return "awq_marlin"
@classmethod @classmethod
@ -101,8 +102,8 @@ class AWQMarlinConfig(QuantizationConfig):
modules_to_not_convert, config) modules_to_not_convert, config)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg) can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin" is_valid_user_quant = (user_quant is None or user_quant == "marlin"
or user_quant == "awq_marlin") or user_quant == "awq_marlin")

View File

@ -2,11 +2,16 @@
import inspect import inspect
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
import torch import torch
from torch import nn from torch import nn
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
else:
QuantizationMethods = str
class QuantizeMethodBase(ABC): class QuantizeMethodBase(ABC):
"""Base class for different quantized methods.""" """Base class for different quantized methods."""
@ -66,7 +71,7 @@ class QuantizationConfig(ABC):
self.packed_modules_mapping: Dict[str, List[str]] = dict() self.packed_modules_mapping: Dict[str, List[str]] = dict()
@abstractmethod @abstractmethod
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
"""Name of the quantization method.""" """Name of the quantization method."""
raise NotImplementedError raise NotImplementedError
@ -99,8 +104,8 @@ class QuantizationConfig(ABC):
raise NotImplementedError raise NotImplementedError
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
""" """
Detects if this quantization method can support a given checkpoint Detects if this quantization method can support a given checkpoint
format by overriding the user specified quantization method -- format by overriding the user specified quantization method --
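With the TYPE_CHECKING alias above, quantization backends can annotate get_name and override_quantization_method with QuantizationMethods without importing the registry at runtime. A hedged sketch of the pattern for an out-of-tree config (the class is illustrative and would subclass QuantizationConfig in practice):

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Import only for type checkers; at runtime the alias is plain str,
    # which avoids a circular import with the quantization registry.
    from vllm.model_executor.layers.quantization import QuantizationMethods
else:
    QuantizationMethods = str


class MyQuantConfig:  # hypothetical; real configs subclass QuantizationConfig
    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "gptq"  # must be one of the registered method names

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
        # Return a registered name to take over this checkpoint, or None.
        return None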

View File

@ -5,6 +5,7 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import ( from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
@ -100,7 +101,7 @@ class BitBLASConfig(QuantizationConfig):
f"quant_method={self.quant_method})") f"quant_method={self.quant_method})")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "bitblas" return "bitblas"
@classmethod @classmethod
@ -139,8 +140,8 @@ class BitBLASConfig(QuantizationConfig):
lm_head_quantized) lm_head_quantized)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
# compat: autogptq >=0.8.0 use checkpoint_format: str # compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_bitblas_format: bool # compat: autogptq <=0.7.1 is_bitblas_format: bool
is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas" is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"

View File

@ -7,6 +7,7 @@ import torch
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod, UnquantizedLinearMethod,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
@ -56,7 +57,7 @@ class BitsAndBytesConfig(QuantizationConfig):
f"llm_int8_skip_modules={self.llm_int8_skip_modules})") f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
@classmethod @classmethod
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
return "bitsandbytes" return "bitsandbytes"
@classmethod @classmethod

View File

@ -16,6 +16,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
@ -71,7 +72,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 70 return 70
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
return "compressed-tensors" return "compressed-tensors"
def get_quant_method( def get_quant_method(

View File

@ -7,6 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
@ -41,8 +42,8 @@ class DeepSpeedFPConfig(QuantizationConfig):
f"group_size={self.group_size}") f"group_size={self.group_size}")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "DeepSpeedFP" return "deepspeedfp"
@classmethod @classmethod
def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig": def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":

View File

@ -8,6 +8,7 @@ from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
@ -20,7 +21,7 @@ class ExpertsInt8Config(QuantizationConfig):
super().__init__() super().__init__()
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "experts_int8" return "experts_int8"
@classmethod @classmethod

View File

@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import ( from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@ -38,7 +39,7 @@ class FBGEMMFp8Config(QuantizationConfig):
self.fp8_linear = Fp8LinearOp() self.fp8_linear = Fp8LinearOp()
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "fbgemm_fp8" return "fbgemm_fp8"
@classmethod @classmethod

View File

@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@ -83,7 +84,7 @@ class Fp8Config(QuantizationConfig):
self.weight_block_size = weight_block_size self.weight_block_size = weight_block_size
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "fp8" return "fp8"
@classmethod @classmethod

View File

@ -13,6 +13,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase) FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -31,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
def __repr__(self) -> str: def __repr__(self) -> str:
return ("GGUFConfig()") return ("GGUFConfig()")
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
return "gguf" return "gguf"
def get_supported_act_dtypes(self) -> List[torch.dtype]: def get_supported_act_dtypes(self) -> List[torch.dtype]:

View File

@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.gptq_utils import ( from vllm.model_executor.layers.quantization.utils.gptq_utils import (
@ -79,7 +80,7 @@ class GPTQConfig(QuantizationConfig):
f"dynamic={self.dynamic}") f"dynamic={self.dynamic}")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "gptq" return "gptq"
@classmethod @classmethod

View File

@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
@ -123,7 +124,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
f"quant_method={self.quant_method})") f"quant_method={self.quant_method})")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "gptq_bitblas" return "gptq_bitblas"
@classmethod @classmethod
@ -151,8 +152,8 @@ class GPTQBitBLASConfig(QuantizationConfig):
lm_head_quantized) lm_head_quantized)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg) can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "bitblas" is_valid_user_quant = (user_quant is None or user_quant == "bitblas"

View File

@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearMethodBase, from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs) set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import ( from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
@ -100,7 +101,7 @@ class GPTQMarlinConfig(QuantizationConfig):
f"dynamic={self.dynamic}") f"dynamic={self.dynamic}")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "gptq_marlin" return "gptq_marlin"
@classmethod @classmethod
@ -130,8 +131,8 @@ class GPTQMarlinConfig(QuantizationConfig):
lm_head_quantized, dynamic, config) lm_head_quantized, dynamic, config)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg) can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin" is_valid_user_quant = (user_quant is None or user_quant == "marlin"

View File

@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
@ -85,7 +86,7 @@ class GPTQMarlin24Config(QuantizationConfig):
self.quant_type, self.group_size) self.quant_type, self.group_size)
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "gptq_marlin_24" return "gptq_marlin_24"
@classmethod @classmethod
@ -108,8 +109,8 @@ class GPTQMarlin24Config(QuantizationConfig):
return cls(weight_bits, group_size) return cls(weight_bits, group_size)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
is_marlin_24_format = ( is_marlin_24_format = (
hf_quant_cfg.get("checkpoint_format") == "marlin_24") hf_quant_cfg.get("checkpoint_format") == "marlin_24")

View File

@ -8,6 +8,7 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@ -50,7 +51,7 @@ class HQQMarlinConfig(QuantizationConfig):
f"group_size={self.group_size})") f"group_size={self.group_size})")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "hqq" return "hqq"
@classmethod @classmethod

View File

@ -6,6 +6,7 @@ import torch
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod, from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
is_layer_skipped_awq) is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
@ -58,7 +59,7 @@ class IPEXConfig(QuantizationConfig):
f"group_size={self.group_size})") f"group_size={self.group_size})")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "ipex" return "ipex"
@classmethod @classmethod
@ -97,8 +98,8 @@ class IPEXConfig(QuantizationConfig):
lm_head_quantized) lm_head_quantized)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
if not current_platform.is_cpu() and not current_platform.is_xpu(): if not current_platform.is_cpu() and not current_platform.is_xpu():
return None return None

View File

@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@ -63,7 +64,7 @@ class MarlinConfig(QuantizationConfig):
f"lm_head_quantized={self.lm_head_quantized})") f"lm_head_quantized={self.lm_head_quantized})")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "marlin" return "marlin"
@classmethod @classmethod
@ -87,8 +88,8 @@ class MarlinConfig(QuantizationConfig):
return cls(group_size, lm_head_quantized) return cls(group_size, lm_head_quantized)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
# compat: autogptq >=0.8.0 use checkpoint_format: str # compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_marlin_format: bool # compat: autogptq <=0.7.1 is_marlin_format: bool
is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin" is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"

View File

@ -11,6 +11,7 @@ from vllm._custom_ops import (cutlass_scaled_fp4_mm,
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@ -42,7 +43,7 @@ class ModelOptFp8Config(QuantizationConfig):
" the format is experimental and could change.") " the format is experimental and could change.")
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "modelopt" return "modelopt"
@classmethod @classmethod
@ -184,8 +185,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
self.exclude_modules = exclude_modules self.exclude_modules = exclude_modules
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "modelopt_nvfp4" return "nvfp4"
@classmethod @classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]: def get_supported_act_dtypes(cls) -> List[torch.dtype]:
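Note that the canonical name of the ModelOpt NVFP4 config changes from "modelopt_nvfp4" to "nvfp4" here. Assuming that string is the key exposed through QuantizationMethods, selecting the method explicitly would now look like this sketch (the checkpoint id is a placeholder):

from vllm import LLM

# Hypothetical NVFP4-quantized checkpoint; "nvfp4" is the canonical
# method name after this change.
llm = LLM(model="org/some-nvfp4-checkpoint", quantization="nvfp4")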

View File

@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import ( from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@ -64,7 +65,7 @@ class MoeWNA16Config(QuantizationConfig):
self.modules_to_not_convert = modules_to_not_convert self.modules_to_not_convert = modules_to_not_convert
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "moe_wna16" return "moe_wna16"
@classmethod @classmethod
@ -100,8 +101,8 @@ class MoeWNA16Config(QuantizationConfig):
lm_head_quantized, modules_to_not_convert, config) lm_head_quantized, modules_to_not_convert, config)
@classmethod @classmethod
def override_quantization_method(cls, hf_quant_cfg, def override_quantization_method(
user_quant) -> Optional[str]: cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg) can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
if can_convert and user_quant == "moe_wna16": if can_convert and user_quant == "moe_wna16":
return cls.get_name() return cls.get_name()

View File

@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional
from torch.nn import Module from torch.nn import Module
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
@ -30,7 +31,7 @@ class NeuronQuantConfig(QuantizationConfig):
self.dequant_dtype = dequant_dtype self.dequant_dtype = dequant_dtype
self.quantize_method = quantize_method self.quantize_method = quantize_method
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
return "neuron_quant" return "neuron_quant"
def get_supported_act_dtypes(self) -> List[str]: def get_supported_act_dtypes(self) -> List[str]:

View File

@ -9,6 +9,7 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase) QuantizeMethodBase)
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config, from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
@ -50,7 +51,7 @@ class PTPCFp8Config(Fp8Config):
ignored_layers=ignored_layers) ignored_layers=ignored_layers)
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "ptpc_fp8" return "ptpc_fp8"
@classmethod @classmethod

View File

@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.parameter import (BasevLLMParameter, from vllm.model_executor.parameter import (BasevLLMParameter,
@ -84,7 +85,7 @@ class QQQConfig(QuantizationConfig):
self.weight_bits, self.group_size) self.weight_bits, self.group_size)
@classmethod @classmethod
def get_name(cls) -> str: def get_name(cls) -> QuantizationMethods:
return "qqq" return "qqq"
@classmethod @classmethod

View File

@ -8,6 +8,7 @@ import torch
from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@ -47,7 +48,7 @@ class QuarkConfig(QuantizationConfig):
def get_min_capability(cls) -> int: def get_min_capability(cls) -> int:
return 70 return 70
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
return "quark" return "quark"
def get_quant_method(self, layer: torch.nn.Module, def get_quant_method(self, layer: torch.nn.Module,

View File

@ -6,6 +6,7 @@ import torch.nn.functional as F
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
@ -20,7 +21,7 @@ class TorchAOConfig(QuantizationConfig):
def __repr__(self) -> str: def __repr__(self) -> str:
return f"TorchAOConfig({self.torchao_config})" return f"TorchAOConfig({self.torchao_config})"
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
return "torchao" return "torchao"
def get_supported_act_dtypes(self) -> List[torch.dtype]: def get_supported_act_dtypes(self) -> List[torch.dtype]:

View File

@ -7,6 +7,7 @@ from torch.nn import Module
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.parameter import ModelWeightParameter from vllm.model_executor.parameter import ModelWeightParameter
@ -27,7 +28,7 @@ class Int8TpuConfig(QuantizationConfig):
f"Unsupported activation scheme {activation_scheme}") f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme self.activation_scheme = activation_scheme
def get_name(self) -> str: def get_name(self) -> QuantizationMethods:
return "tpu_int8" return "tpu_int8"
def get_supported_act_dtypes(self) -> List[torch.dtype]: def get_supported_act_dtypes(self) -> List[torch.dtype]:

View File

@ -1496,7 +1496,7 @@ def get_rope(
if key in _ROPE_DICT: if key in _ROPE_DICT:
return _ROPE_DICT[key] return _ROPE_DICT[key]
if rope_scaling is None: if not rope_scaling:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype) is_neox_style, dtype)
else: else:
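The switch from "rope_scaling is None" to "not rope_scaling" treats an empty dict the same as None, so a config carrying rope_scaling: {} now selects the plain RotaryEmbedding instead of the scaled branch. A tiny illustration of the difference:

def uses_plain_rope(rope_scaling) -> bool:
    # Mirrors the new check: None and {} both select the unscaled embedding.
    return not rope_scaling

print(uses_plain_rope(None))                      # True
print(uses_plain_rope({}))                        # True (was False with "is None")
print(uses_plain_rope({"rope_type": "dynamic"}))  # False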

View File

@ -180,7 +180,6 @@ def _get_neuron_config_after_override(default_neuron_config,
NeuronConfig, QuantizationConfig, NeuronConfig, QuantizationConfig,
SparseAttnConfig) SparseAttnConfig)
overridden_neuron_config = overridden_neuron_config or {}
sparse_attn = overridden_neuron_config.pop("sparse_attn", {}) sparse_attn = overridden_neuron_config.pop("sparse_attn", {})
if sparse_attn: if sparse_attn:
overridden_neuron_config["sparse_attn"] = SparseAttnConfig( overridden_neuron_config["sparse_attn"] = SparseAttnConfig(