Improve configs - ModelConfig (#17130)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor 2025-04-30 03:38:22 +01:00 committed by GitHub
parent 2c4f59afc3
commit 13698db634
36 changed files with 490 additions and 648 deletions

View File

@ -738,7 +738,7 @@ class VllmRunner:
- `block_size`: Set to `16` instead of `None` to reduce memory usage.
- `enable_chunked_prefill`: Set to `False` instead of `None` for
test reproducibility.
- `enforce_eager`: Set to `False` instead of `None` to test CUDA graph.
- `enforce_eager`: Set to `False` to test CUDA graph.
"""
def __init__(

View File

@ -8,7 +8,7 @@ from typing import Literal, Optional
import pytest
from vllm.config import PoolerConfig, config
from vllm.config import config
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
get_type, is_not_builtin, is_type,
literal_to_kwargs, nullable_kvs,
@ -222,17 +222,6 @@ def test_prefix_cache_default():
assert not engine_args.enable_prefix_caching
def test_valid_pooling_config():
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([
'--override-pooler-config',
'{"pooling_type": "MEAN"}',
])
engine_args = EngineArgs.from_cli_args(args=args)
assert engine_args.override_pooler_config == PoolerConfig(
pooling_type="MEAN", )
@pytest.mark.parametrize(
("arg"),
[

View File

@ -14,7 +14,7 @@ import torch.nn.functional as F
from vllm.model_executor.layers.linear import LinearBase # noqa: E501
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization import (
get_quantization_config, register_quantization_config)
QuantizationMethods, get_quantization_config, register_quantization_config)
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig)
@ -54,7 +54,7 @@ class CustomQuantConfig(QuantizationConfig):
"""Initialize the quantization config."""
self.num_bits = num_bits
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
"""Name of the quantization method."""
return "custom_quant"

View File

@ -185,7 +185,7 @@ def test_get_pooling_config():
revision=None,
)
pooling_config = model_config._init_pooler_config(None)
pooling_config = model_config._init_pooler_config()
assert pooling_config is not None
assert pooling_config.normalize
@ -205,11 +205,12 @@ def test_get_pooling_config_from_args():
dtype="float16",
revision=None)
override_config = PoolerConfig(pooling_type='CLS', normalize=True)
override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
model_config.override_pooler_config = override_pooler_config
pooling_config = model_config._init_pooler_config(override_config)
pooling_config = model_config._init_pooler_config()
assert pooling_config is not None
assert asdict(pooling_config) == asdict(override_config)
assert asdict(pooling_config) == asdict(override_pooler_config)
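The practical effect for callers: the pooling override is now a field on ModelConfig (dict or PoolerConfig) and `_init_pooler_config()` reads it from `self` instead of taking an argument. A minimal usage sketch, assuming an illustrative embedding model name and network access to fetch its Hugging Face config:

from vllm.config import ModelConfig

model_config = ModelConfig(
    model="intfloat/e5-small",                        # illustrative model
    task="embed",
    dtype="float16",
    override_pooler_config={"pooling_type": "MEAN"},  # dict form is accepted
)
# __post_init__ -> _init_pooler_config() converts the dict into a
# PoolerConfig and fills any unset fields from the model's own pooling
# config, if the checkpoint ships one.
assert model_config.pooler_config is not None
assert model_config.pooler_config.pooling_type == "MEAN"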
@pytest.mark.skipif(current_platform.is_rocm(),

View File

@ -16,9 +16,8 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
replace)
from importlib.util import find_spec
from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
Optional, Protocol, TypeVar, Union, cast, get_args,
get_origin)
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
Protocol, TypeVar, Union, cast, get_args, get_origin)
import torch
from pydantic import BaseModel, Field, PrivateAttr
@ -211,103 +210,190 @@ def get_field(cls: ConfigType, name: str) -> Field:
f"{cls.__name__}.{name} must have a default value or default factory.")
class ModelConfig:
"""Configuration for the model.
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
Args:
model: Name or path of the huggingface model to use.
It is also used as the content for `model_name` tag in metrics
output when `served_model_name` is not specified.
task: The task to use the model for. Each vLLM instance only supports
one task, even if the same model can be used for multiple tasks.
When the model only supports one task, "auto" can be used to select
it; otherwise, you must specify explicitly which task to use.
tokenizer: Name or path of the huggingface tokenizer to use.
tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
available, "slow" will always use the slow tokenizer,
"mistral" will always use the tokenizer from `mistral_common`, and
"custom" will use --tokenizer to select the preregistered tokenizer.
trust_remote_code: Trust remote code (e.g., from HuggingFace) when
downloading the model and tokenizer.
allowed_local_media_path: Allowing API requests to read local images or
videos from directories specified by the server file system.
This is a security risk. Should only be enabled in trusted
environments.
dtype: Data type for model weights and activations. The "auto" option
will use FP16 precision for FP32 and FP16 models, and BF16 precision
for BF16 models.
seed: Random seed for reproducibility.
revision: The specific model version to use. It can be a branch name,
a tag name, or a commit id. If unspecified, will use the default
version.
code_revision: The specific revision to use for the model code on
Hugging Face Hub. It can be a branch name, a tag name, or a
commit id. If unspecified, will use the default version.
tokenizer_revision: The specific tokenizer version to use. It can be a
branch name, a tag name, or a commit id. If unspecified, will use
the default version.
max_model_len: Maximum length of a sequence (including prompt and
output). If None, will be derived from the model.
spec_target_max_model_len: Specify the the maximum length for spec
decoding draft models.
quantization: Quantization method that was used to quantize the model
weights. If None, we assume the model weights are not quantized.
enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
If None, the user did not specify, so default to False.
max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
When a sequence has context length larger than this, we fall back
to eager mode. Additionally for encoder-decoder models, if the
sequence length of the encoder input is larger than this, we fall
back to the eager mode.
max_logprobs: Maximum number of log probabilities. Defaults to 20.
disable_sliding_window: Whether to disable sliding window. If True,
we will disable the sliding window functionality of the model.
If the model does not support sliding window, this argument is
ignored.
skip_tokenizer_init: If true, skip initialization of tokenizer and
detokenizer.
served_model_name: The model name used in metrics tag `model_name`,
matches the model name exposed via the APIs. If multiple model
names provided, the first name will be used. If not specified,
the model name will be the same as `model`.
limit_mm_per_prompt: Maximum number of data items per modality
per prompt. Only applicable for multimodal models.
mm_processor_kwargs: Overrides for the multi-modal processor obtained
from `AutoProcessor.from_pretrained`.
disable_mm_preprocessor_cache: If True, disable caching of the
processed multi-modal inputs.
use_async_output_proc: Whether to use async output processor.
Defaults to True.
config_format: The config format which shall be loaded.
Defaults to 'auto' which defaults to 'hf'.
hf_token: The token to use as HTTP bearer authorization for remote files
. If `True`, will use the token generated when running
`huggingface-cli login` (stored in `~/.huggingface`).
hf_overrides: If a dictionary, contains arguments to be forwarded to the
HuggingFace config. If a callable, it is called to update the
HuggingFace config.
override_neuron_config: Initialize non default neuron config or
override default neuron config that are specific to Neuron devices,
this argument will be used to configure the neuron config that
can not be gathered from the vllm arguments.
override_pooler_config: Initialize non default pooling config or
override default pooling config for the pooling model.
logits_processor_pattern: Optional regex pattern specifying valid
logits processor qualified names that can be passed with the
`logits_processors` extra completion argument. Defaults to None,
which allows no processors.
generation_config: Configuration parameter file for generation.
model_impl: Which implementation of the model to use:
"auto" will try to use the vLLM implementation if it exists and
fall back to the Transformers implementation if no vLLM
implementation is available.
"vllm" will use the vLLM model implementation.
"transformers" will use the Transformers model implementation.
override_generation_config: Override the generation config with the
given config.
"""
@config
@dataclass
class ModelConfig:
"""Configuration for the model."""
model: str = "facebook/opt-125m"
"""Name or path of the Hugging Face model to use. It is also used as the
content for `model_name` tag in metrics output when `served_model_name` is
not specified."""
task: Literal[TaskOption, Literal["draft"]] = "auto"
"""The task to use the model for. Each vLLM instance only supports one
task, even if the same model can be used for multiple tasks. When the model
only supports one task, "auto" can be used to select it; otherwise, you
must specify explicitly which task to use."""
tokenizer: str = None # type: ignore
"""Name or path of the Hugging Face tokenizer to use. If unspecified, model
name or path will be used."""
tokenizer_mode: TokenizerMode = "auto"
"""Tokenizer mode:\n
- "auto" will use the fast tokenizer if available.\n
- "slow" will always use the slow tokenizer.\n
- "mistral" will always use the tokenizer from `mistral_common`.\n
- "custom" will use --tokenizer to select the preregistered tokenizer."""
trust_remote_code: bool = False
"""Trust remote code (e.g., from HuggingFace) when downloading the model
and tokenizer."""
dtype: Union[ModelDType, torch.dtype] = "auto"
"""Data type for model weights and activations:\n
- "auto" will use FP16 precision for FP32 and FP16 models, and BF16
precision for BF16 models.\n
- "half" for FP16. Recommended for AWQ quantization.\n
- "float16" is the same as "half".\n
- "bfloat16" for a balance between precision and range.\n
- "float" is shorthand for FP32 precision.\n
- "float32" for FP32 precision."""
seed: Optional[int] = None
"""Random seed for reproducibility."""
hf_config_path: Optional[str] = None
"""Name or path of the Hugging Face config to use. If unspecified, model
name or path will be used."""
allowed_local_media_path: str = ""
"""Allowing API requests to read local images or videos from directories
specified by the server file system. This is a security risk. Should only
be enabled in trusted environments."""
revision: Optional[str] = None
"""The specific model version to use. It can be a branch name, a tag name,
or a commit id. If unspecified, will use the default version."""
code_revision: Optional[str] = None
"""The specific revision to use for the model code on the Hugging Face Hub.
It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version."""
rope_scaling: dict[str, Any] = field(default_factory=dict)
"""RoPE scaling configuration in JSON format. For example,
`{"rope_type":"dynamic","factor":2.0}`."""
rope_theta: Optional[float] = None
"""RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE
theta improves the performance of the scaled model."""
tokenizer_revision: Optional[str] = None
"""The specific revision to use for the tokenizer on the Hugging Face Hub.
It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version."""
max_model_len: int = None # type: ignore
"""Model context length (prompt and output). If unspecified, will be
automatically derived from the model config.
When passing via `--max-model-len`, supports k/m/g/K/M/G in human-readable
format. Examples:\n
- 1k -> 1000\n
- 1K -> 1024\n
- 25.6k -> 25,600"""
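The suffix parsing is done on the CLI side by `human_readable_int` (wired up in the `get_kwargs` change further down in this commit). A rough, self-contained sketch of the idea, not the actual vLLM implementation:

def parse_human_readable_int(value: str) -> int:
    """Parse '1k' -> 1000, '1K' -> 1024, '25.6k' -> 25600."""
    suffixes = {"k": 1_000, "m": 1_000_000, "g": 1_000_000_000,
                "K": 2**10, "M": 2**20, "G": 2**30}
    if value and value[-1] in suffixes:
        return int(float(value[:-1]) * suffixes[value[-1]])
    return int(value)

assert parse_human_readable_int("1k") == 1_000
assert parse_human_readable_int("1K") == 1_024
assert parse_human_readable_int("25.6k") == 25_600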
spec_target_max_model_len: Optional[int] = None
"""Specify the the maximum length for spec decoding draft models."""
quantization: Optional[QuantizationMethods] = None
"""Method used to quantize the weights. If `None`, we first check the
`quantization_config` attribute in the model config file. If that is
`None`, we assume the model weights are not quantized and use `dtype` to
determine the data type of the weights."""
enforce_eager: bool = False
"""Whether to always use eager-mode PyTorch. If True, we will disable CUDA
graph and always execute the model in eager mode. If False, we will use
CUDA graph and eager execution in hybrid for maximal performance and
flexibility."""
max_seq_len_to_capture: int = 8192
"""Maximum sequence len covered by CUDA graphs. When a sequence has context
length larger than this, we fall back to eager mode. Additionally for
encoder-decoder models, if the sequence length of the encoder input is
larger than this, we fall back to the eager mode."""
max_logprobs: int = 20
"""Maximum number of log probabilities to return when `logprobs` is
specified in `SamplingParams`. The default value comes from the default for the
OpenAI Chat Completions API."""
disable_sliding_window: bool = False
"""Whether to disable sliding window. If True, we will disable the sliding
window functionality of the model, capping to sliding window size. If the
model does not support sliding window, this argument is ignored."""
disable_cascade_attn: bool = False
"""Disable cascade attention for V1. While cascade attention does not
change the mathematical correctness, disabling it could be useful for
preventing potential numerical issues. Note that even if this is set to
False, cascade attention will only be used when the heuristic indicates that
it is beneficial."""
skip_tokenizer_init: bool = False
"""Skip initialization of tokenizer and detokenizer. Expects valid
`prompt_token_ids` and `None` for prompt from the input. The generated
output will contain token ids."""
served_model_name: Optional[Union[str, list[str]]] = None
"""The model name(s) used in the API. If multiple names are provided, the
server will respond to any of the provided names. The model name in the
model field of a response will be the first name in this list. If not
specified, the model name will be the same as the `--model` argument. Note
that this name(s) will also be used in the `model_name` tag content of
Prometheus metrics; if multiple names are provided, the metrics tag will take
the first one."""
limit_mm_per_prompt: dict[str, int] = field(default_factory=dict)
"""Maximum number of data items per modality per prompt. Only applicable
for multimodal models."""
use_async_output_proc: bool = True
"""Whether to use async output processor."""
config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value
"""The format of the model config to load:\n
- "auto" will try to load the config in hf format if available else it
will try to load in mistral format.\n
- "hf" will load the config in hf format.\n
- "mistral" will load the config in mistral format."""
hf_token: Optional[Union[bool, str]] = None
"""The token to use as HTTP bearer authorization for remote files . If
`True`, will use the token generated when running `huggingface-cli login`
(stored in `~/.huggingface`)."""
hf_overrides: HfOverrides = field(default_factory=dict)
"""If a dictionary, contains arguments to be forwarded to the Hugging Face
config. If a callable, it is called to update the HuggingFace config. When
specified via CLI, the argument must be a valid JSON string."""
mm_processor_kwargs: Optional[dict[str, Any]] = None
"""Arguments to be forwarded to the model's processor for multi-modal data,
e.g., image processor. Overrides for the multi-modal processor obtained
from `AutoProcessor.from_pretrained`. The available overrides depend on the
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
When specified via CLI, the argument must be a valid JSON string."""
disable_mm_preprocessor_cache: bool = False
"""If `True`, disable caching of the multi-modal preprocessor/mapper (not
recommended)."""
override_neuron_config: dict[str, Any] = field(default_factory=dict)
"""Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to
configure the neuron config that can not be gathered from the vllm
arguments. e.g. `{"cast_logits_dtype": "bloat16"}`. When specified via CLI,
the argument must be a valid JSON string."""
pooler_config: Optional["PoolerConfig"] = field(init=False)
"""Pooler config which controls the behaviour of output pooling in pooling
models."""
override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
"""Initialize non-default pooling config or override default pooling config
for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`.
When specified via CLI, the argument must be a valid JSON string."""
logits_processor_pattern: Optional[str] = None
"""Optional regex pattern specifying valid logits processor qualified names
that can be passed with the `logits_processors` extra completion argument.
Defaults to `None`, which allows no processors."""
generation_config: str = "auto"
"""The folder path to the generation config. Defaults to `"auto"`, the
generation config will be loaded from model path. If set to `"vllm"`, no
generation config is loaded, vLLM defaults will be used. If set to a folder
path, the generation config will be loaded from the specified folder path.
If `max_new_tokens` is specified in generation config, then it sets a
server-wide limit on the number of output tokens for all requests."""
override_generation_config: dict[str, Any] = field(default_factory=dict)
"""Overrides or sets generation config. e.g. `{"temperature": 0.5}`. If
used with `--generation-config auto`, the override parameters will be
merged with the default config from the model. If used with
`--generation-config vllm`, only the override parameters are used.
When specified via CLI, the argument must be a valid JSON string."""
enable_sleep_mode: bool = False
"""Enable sleep mode for the engine (only cuda platform is supported)."""
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO.value
"""Which implementation of the model to use:\n
- "auto" will try to use the vLLM implementation, if it exists, and fall
back to the Transformers implementation if no vLLM implementation is
available.\n
- "vllm" will use the vLLM model implementation.\n
- "transformers" will use the Transformers model implementation."""
def compute_hash(self) -> str:
"""
@ -342,92 +428,43 @@ class ModelConfig:
assert_hashable(str_factors)
return hashlib.sha256(str(factors).encode()).hexdigest()
def __init__(
self,
model: str,
task: Literal[TaskOption, Literal["draft"]],
tokenizer: str,
tokenizer_mode: str,
trust_remote_code: bool,
dtype: Union[str, torch.dtype],
seed: int,
hf_config_path: Optional[str] = None,
allowed_local_media_path: str = "",
revision: Optional[str] = None,
code_revision: Optional[str] = None,
rope_scaling: Optional[dict[str, Any]] = None,
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
enforce_eager: Optional[bool] = None,
max_seq_len_to_capture: Optional[int] = None,
max_logprobs: int = 20,
disable_sliding_window: bool = False,
disable_cascade_attn: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, list[str]]] = None,
limit_mm_per_prompt: Optional[dict[str, int]] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
disable_mm_preprocessor_cache: bool = False,
use_async_output_proc: bool = True,
config_format: ConfigFormat = ConfigFormat.AUTO,
hf_token: Optional[Union[bool, str]] = None,
hf_overrides: Optional[HfOverrides] = None,
override_neuron_config: Optional[dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None,
generation_config: str = "auto",
enable_sleep_mode: bool = False,
override_generation_config: Optional[dict[str, Any]] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
self.model = maybe_model_redirect(model)
self.tokenizer = maybe_model_redirect(tokenizer)
def __post_init__(self) -> None:
self.model = maybe_model_redirect(self.model)
# The tokenizer is consistent with the model by default.
if self.tokenizer is None:
self.tokenizer = self.model
if self.tokenizer_revision is None:
self.tokenizer_revision = self.revision
self.tokenizer = maybe_model_redirect(self.tokenizer)
self.hf_config_path = hf_config_path
if isinstance(hf_config_path, str):
self.hf_config_path = maybe_model_redirect(hf_config_path)
if isinstance(self.hf_config_path, str):
self.hf_config_path = maybe_model_redirect(self.hf_config_path)
self.tokenizer_mode = tokenizer_mode
self.trust_remote_code = trust_remote_code
self.allowed_local_media_path = allowed_local_media_path
self.seed = seed
self.revision = revision
self.code_revision = code_revision
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
self.model_impl = model_impl
if hf_overrides is None:
hf_overrides = {}
if callable(hf_overrides):
if callable(self.hf_overrides):
hf_overrides_kw = {}
hf_overrides_fn = hf_overrides
hf_overrides_fn = self.hf_overrides
else:
hf_overrides_kw = hf_overrides
hf_overrides_kw = self.hf_overrides
hf_overrides_fn = None
if rope_scaling is not None:
hf_override: dict[str, Any] = {"rope_scaling": rope_scaling}
if self.rope_scaling:
hf_override: dict[str, Any] = {"rope_scaling": self.rope_scaling}
hf_overrides_kw.update(hf_override)
hf_overrides_str = json.dumps(hf_overrides)
hf_overrides_str = json.dumps(hf_overrides_kw)
msg = (
"`--rope-scaling` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_overrides_str}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
if rope_theta is not None:
hf_override = {"rope_theta": rope_theta}
if self.rope_theta is not None:
hf_override = {"rope_theta": self.rope_theta}
hf_overrides_kw.update(hf_override)
hf_overrides_str = json.dumps(hf_overrides)
hf_overrides_str = json.dumps(hf_overrides_kw)
msg = (
"`--rope-theta` will be removed in a future release. "
f"'Please instead use `--hf-overrides '{hf_overrides_str}'`")
warnings.warn(DeprecationWarning(msg), stacklevel=2)
self.maybe_pull_model_tokenizer_for_s3(model, tokenizer)
self.maybe_pull_model_tokenizer_for_s3(self.model, self.tokenizer)
if (backend := envs.VLLM_ATTENTION_BACKEND
) and backend == "FLASHINFER" and find_spec("flashinfer") is None:
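Per the deprecation warnings above, the standalone RoPE arguments fold into `hf_overrides`. A hedged sketch of the suggested replacement, using an illustrative RoPE-based model name:

from vllm.config import ModelConfig

config = ModelConfig(
    model="Qwen/Qwen2.5-0.5B-Instruct",   # illustrative model
    hf_overrides={
        # Same keys the deprecated --rope-scaling/--rope-theta flags used
        # to inject into the Hugging Face config.
        "rope_scaling": {"rope_type": "dynamic", "factor": 2.0},
        "rope_theta": 1_000_000.0,
    },
)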
@ -437,20 +474,6 @@ class ModelConfig:
"https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile " # noqa: E501
"for instructions on how to install it.")
# The tokenizer version is consistent with the model version by default.
if tokenizer_revision is None:
self.tokenizer_revision = revision
else:
self.tokenizer_revision = tokenizer_revision
self.quantization = quantization
self.enforce_eager = enforce_eager
self.max_seq_len_to_capture = max_seq_len_to_capture
self.max_logprobs = max_logprobs
self.disable_sliding_window = disable_sliding_window
self.disable_cascade_attn = disable_cascade_attn
self.skip_tokenizer_init = skip_tokenizer_init
self.enable_sleep_mode = enable_sleep_mode
from vllm.platforms import current_platform
if (self.enable_sleep_mode
@ -458,9 +481,12 @@ class ModelConfig:
raise ValueError(
"Sleep mode is not supported on current platform.")
if isinstance(self.config_format, str):
self.config_format = ConfigFormat(self.config_format)
hf_config = get_config(self.hf_config_path or self.model,
trust_remote_code, revision, code_revision,
config_format)
self.trust_remote_code, self.revision,
self.code_revision, self.config_format)
if hf_overrides_kw:
logger.info("Overriding HF config with %s", hf_overrides_kw)
@ -476,13 +502,8 @@ class ModelConfig:
"attention_chunk_size", None)
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, hf_token=hf_token, revision=revision)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.use_async_output_proc = use_async_output_proc
# Set enforce_eager to False if the value is unset.
if self.enforce_eager is None:
self.enforce_eager = False
self.model, hf_token=self.hf_token, revision=self.revision)
self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)
interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
sliding_window = getattr(self.hf_text_config, "sliding_window", None)
@ -515,18 +536,14 @@ class ModelConfig:
self.max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config,
max_model_len=max_model_len,
max_model_len=self.max_model_len,
disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window(),
spec_target_max_model_len=spec_target_max_model_len,
spec_target_max_model_len=self.spec_target_max_model_len,
encoder_config=self.encoder_config)
self.served_model_name = get_served_model_name(model,
served_model_name)
self.multimodal_config = self._init_multimodal_config(
limit_mm_per_prompt=limit_mm_per_prompt,
mm_processor_kwargs=mm_processor_kwargs,
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
)
self.served_model_name = get_served_model_name(self.model,
self.served_model_name)
self.multimodal_config = self._init_multimodal_config()
if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
@ -535,24 +552,19 @@ class ModelConfig:
self.has_noops = self._init_has_noops()
self.has_inner_state = self._init_has_inner_state()
if current_platform.is_neuron():
self.override_neuron_config = override_neuron_config
else:
self.override_neuron_config = None
if (not current_platform.is_neuron() and self.override_neuron_config):
raise ValueError(
"`override_neuron_config` is only supported on Neuron.")
supported_tasks, task = self._resolve_task(task)
supported_tasks, task = self._resolve_task(self.task)
self.supported_tasks = supported_tasks
self.task: Final = task
self.task = task
if self.task in ("draft", "generate"):
self.truncation_side = "left"
else:
self.truncation_side = "right"
self.pooler_config = self._init_pooler_config(override_pooler_config)
self.logits_processor_pattern = logits_processor_pattern
self.generation_config = generation_config
self.override_generation_config = override_generation_config or {}
self.pooler_config = self._init_pooler_config()
self._verify_quantization()
self._verify_cuda_graph()
@ -591,26 +603,21 @@ class ModelConfig:
model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
self.tokenizer = s3_tokenizer.dir
def _init_multimodal_config(
self,
limit_mm_per_prompt: Optional[dict[str, int]],
mm_processor_kwargs: Optional[dict[str, Any]],
disable_mm_preprocessor_cache: bool,
) -> Optional["MultiModalConfig"]:
def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
if self.registry.is_multimodal_model(self.architectures):
return MultiModalConfig(
limit_per_prompt=limit_mm_per_prompt or {},
mm_processor_kwargs=mm_processor_kwargs or {},
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
)
limit_per_prompt=self.limit_mm_per_prompt,
mm_processor_kwargs=self.mm_processor_kwargs,
disable_mm_preprocessor_cache=self.
disable_mm_preprocessor_cache)
if limit_mm_per_prompt:
if self.limit_mm_per_prompt:
raise ValueError("`limit_mm_per_prompt` is only supported for "
"multimodal models.")
if mm_processor_kwargs:
if self.mm_processor_kwargs:
raise ValueError("`mm_processor_kwargs` is only supported for "
"multimodal models.")
if disable_mm_preprocessor_cache:
if self.disable_mm_preprocessor_cache:
raise ValueError("`disable_mm_preprocessor_cache` is only "
"supported for multimodal models.")
@ -620,31 +627,32 @@ class ModelConfig:
return get_sentence_transformer_tokenizer_config(
self.model, self.revision)
def _init_pooler_config(
self,
override_pooler_config: Optional["PoolerConfig"],
) -> Optional["PoolerConfig"]:
def _init_pooler_config(self) -> Optional["PoolerConfig"]:
if self.runner_type == "pooling":
user_config = override_pooler_config or PoolerConfig()
if isinstance(self.override_pooler_config, dict):
self.override_pooler_config = PoolerConfig(
**self.override_pooler_config)
pooler_config = self.override_pooler_config or PoolerConfig()
base_config = get_pooling_config(self.model, self.revision)
if base_config is not None:
# Only set values that are not overridden by the user
for k, v in base_config.items():
if getattr(user_config, k) is None:
setattr(user_config, k, v)
if getattr(pooler_config, k) is None:
setattr(pooler_config, k, v)
if self.is_matryoshka:
if user_config.normalize is None:
user_config.normalize = True
elif not user_config.normalize:
if pooler_config.normalize is None:
pooler_config.normalize = True
elif not pooler_config.normalize:
raise ValueError(
"`normalize` must be enabled (set to True) "
"for models that are compatible with "
"Matryoshka Representation.")
return user_config
return pooler_config
return None
@ -662,11 +670,11 @@ class ModelConfig:
return self.registry.model_has_inner_state(self.architectures)
def _verify_tokenizer_mode(self) -> None:
tokenizer_mode = self.tokenizer_mode.lower()
if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())
if tokenizer_mode not in get_args(TokenizerMode):
raise ValueError(
f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
"either 'auto', 'slow', 'mistral' or 'custom'.")
f"one of {get_args(TokenizerMode)}.")
self.tokenizer_mode = tokenizer_mode
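This Literal-plus-`get_args` pattern replaces hard-coded choice lists in several places in this commit; a self-contained illustration of why the error message and the type stay in sync automatically:

from typing import Literal, get_args

TokenizerMode = Literal["auto", "slow", "mistral", "custom"]

def verify_tokenizer_mode(mode: str) -> TokenizerMode:
    lowered = mode.lower()
    if lowered not in get_args(TokenizerMode):
        # Allowed values come straight from the Literal, so adding a new
        # mode updates both the annotation and this message.
        raise ValueError(f"Unknown tokenizer mode: {mode}. Must be "
                         f"one of {get_args(TokenizerMode)}.")
    return lowered  # type: ignore[return-value]

assert verify_tokenizer_mode("AUTO") == "auto"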
def _get_preferred_task(
@ -781,7 +789,8 @@ class ModelConfig:
"quark", "nvfp4", "bitblas", "gptq_bitblas"
]
if self.quantization is not None:
self.quantization = self.quantization.lower()
self.quantization = cast(QuantizationMethods,
self.quantization.lower())
# Parse quantization method from the HF model config, if available.
quant_cfg = self._parse_quant_hf_config()
@ -857,8 +866,6 @@ class ModelConfig:
"non-quantized models.", self.quantization)
def _verify_cuda_graph(self) -> None:
if self.max_seq_len_to_capture is None:
self.max_seq_len_to_capture = self.max_model_len
self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
self.max_model_len)
ROCM_UNSUPPORTED_MODELS = ['mllama']
@ -1294,7 +1301,7 @@ class ModelConfig:
@property
def runner_type(self) -> RunnerType:
return _TASK_RUNNER[self.task]
return _TASK_RUNNER[cast(_ResolvedTask, self.task)]
@property
def is_v1_compatible(self) -> bool:
@ -2201,7 +2208,7 @@ class SpeculativeConfig:
according to the log probability settings in SamplingParams."""
# Draft model configuration
quantization: Optional[str] = None
quantization: Optional[QuantizationMethods] = None
"""Quantization method that was used to quantize the draft model weights.
If `None`, we assume the model weights are not quantized. Note that it only
takes effect when using the draft model-based speculative method."""
@ -2386,7 +2393,6 @@ class SpeculativeConfig:
code_revision=self.code_revision,
tokenizer_revision=self.target_model_config.
tokenizer_revision,
max_model_len=None,
spec_target_max_model_len=self.target_model_config.
max_model_len,
quantization=self.quantization,
@ -2793,30 +2799,31 @@ class PromptAdapterConfig:
class MultiModalConfig:
"""Controls the behavior of multimodal models."""
limit_per_prompt: dict[str, int] = field(default_factory=dict)
limit_per_prompt: dict[str, int] = get_field(ModelConfig,
"limit_mm_per_prompt")
"""
The maximum number of input items allowed per prompt for each modality.
This should be a JSON string that will be parsed into a dictionary.
Defaults to 1 (V0) or 999 (V1) for each modality.
For example, to allow up to 16 images and 2 videos per prompt:
:code:`{"images": 16, "videos": 2}`
`{"images": 16, "videos": 2}`
"""
mm_processor_kwargs: Optional[dict[str, object]] = None
"""
Overrides for the multi-modal processor obtained from
:meth:`transformers.AutoProcessor.from_pretrained`.
`transformers.AutoProcessor.from_pretrained`.
The available overrides depend on the model that is being run.
For example, for Phi-3-Vision:
:code:`{"num_crops": 4}`.
`{"num_crops": 4}`.
"""
disable_mm_preprocessor_cache: bool = False
"""
If :code:`True`, disable caching of the processed multi-modal inputs.
If `True`, disable caching of the processed multi-modal inputs.
"""
def compute_hash(self) -> str:
@ -2907,10 +2914,6 @@ class PoolerConfig:
usedforsecurity=False).hexdigest()
return hash_str
@staticmethod
def from_json(json_str: str) -> "PoolerConfig":
return PoolerConfig(**json.loads(json_str))
_STR_DTYPE_TO_TORCH_DTYPE = {
"half": torch.float16,

View File

@ -20,15 +20,16 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackend, GuidedDecodingBackendV1,
HfOverrides, KVTransferConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PrefixCachingHashAlgo, PromptAdapterConfig,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerPoolConfig, VllmConfig,
get_attr_docs, get_field)
LoRAConfig, ModelConfig, ModelDType, ModelImpl,
MultiModalConfig, ObservabilityConfig, ParallelConfig,
PoolerConfig, PrefixCachingHashAlgo,
PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
SpeculativeConfig, TaskOption, TokenizerMode,
TokenizerPoolConfig, VllmConfig, get_attr_docs,
get_field)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.plugins import load_general_plugins
from vllm.reasoning import ReasoningParserManager
from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
@ -183,6 +184,9 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs[name]["nargs"] = "+"
elif contains_type(type_hints, int):
kwargs[name]["type"] = int
# Special case for large integers
if name in {"max_model_len"}:
kwargs[name]["type"] = human_readable_int
elif contains_type(type_hints, float):
kwargs[name]["type"] = float
elif contains_type(type_hints, dict):
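For readers unfamiliar with this helper: `get_kwargs` walks a config dataclass and turns each field's type hint, default, and attribute docstring into the `**kwargs` for a matching `parser.add_argument` call; the hunk above special-cases `max_model_len` so its int argument accepts the human-readable suffixes documented on the field. A much-simplified sketch of the idea, not the real implementation:

import argparse
from dataclasses import MISSING, dataclass, fields
from typing import get_type_hints

def simple_get_kwargs(cls) -> dict[str, dict]:
    """Toy stand-in: build one argparse kwargs dict per dataclass field."""
    hints = get_type_hints(cls)
    out = {}
    for f in fields(cls):
        default = (f.default_factory() if f.default_factory is not MISSING
                   else f.default)
        out[f.name] = {"type": hints[f.name], "default": default,
                       "help": f"(default: {default!r})"}
    return out

@dataclass
class TinyConfig:
    max_model_len: int = 8192

parser = argparse.ArgumentParser()
parser.add_argument("--max-model-len",
                    **simple_get_kwargs(TinyConfig)["max_model_len"])
print(parser.parse_args(["--max-model-len", "4096"]).max_model_len)  # 4096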
@ -212,22 +216,23 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""
model: str = 'facebook/opt-125m'
served_model_name: Optional[Union[str, List[str]]] = None
tokenizer: Optional[str] = None
hf_config_path: Optional[str] = None
task: TaskOption = "auto"
skip_tokenizer_init: bool = False
tokenizer_mode: str = 'auto'
trust_remote_code: bool = False
allowed_local_media_path: str = ""
model: str = ModelConfig.model
served_model_name: Optional[Union[
str, List[str]]] = ModelConfig.served_model_name
tokenizer: Optional[str] = ModelConfig.tokenizer
hf_config_path: Optional[str] = ModelConfig.hf_config_path
task: TaskOption = ModelConfig.task
skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
download_dir: Optional[str] = LoadConfig.download_dir
load_format: str = LoadConfig.load_format
config_format: ConfigFormat = ConfigFormat.AUTO
dtype: str = 'auto'
config_format: str = ModelConfig.config_format
dtype: ModelDType = ModelConfig.dtype
kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
seed: Optional[int] = None
max_model_len: Optional[int] = None
seed: Optional[int] = ModelConfig.seed
max_model_len: Optional[int] = ModelConfig.max_model_len
# Note: Specifying a custom executor backend by passing a class
# is intended for expert use only. The API may change without
# notice.
@ -245,8 +250,8 @@ class EngineArgs:
enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
prefix_caching_hash_algo: PrefixCachingHashAlgo = \
CacheConfig.prefix_caching_hash_algo
disable_sliding_window: bool = False
disable_cascade_attn: bool = False
disable_sliding_window: bool = ModelConfig.disable_sliding_window
disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
use_v2_block_manager: bool = True
swap_space: float = CacheConfig.swap_space
cpu_offload_gb: float = CacheConfig.cpu_offload_gb
@ -258,18 +263,19 @@ class EngineArgs:
long_prefill_token_threshold: int = \
SchedulerConfig.long_prefill_token_threshold
max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
max_logprobs: int = 20 # Default value for OpenAI Chat Completions API
max_logprobs: int = ModelConfig.max_logprobs
disable_log_stats: bool = False
revision: Optional[str] = None
code_revision: Optional[str] = None
rope_scaling: Optional[Dict[str, Any]] = None
rope_theta: Optional[float] = None
hf_token: Optional[Union[bool, str]] = None
hf_overrides: Optional[HfOverrides] = None
tokenizer_revision: Optional[str] = None
quantization: Optional[str] = None
enforce_eager: Optional[bool] = None
max_seq_len_to_capture: int = 8192
revision: Optional[str] = ModelConfig.revision
code_revision: Optional[str] = ModelConfig.code_revision
rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
rope_theta: Optional[float] = ModelConfig.rope_theta
hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
hf_overrides: Optional[HfOverrides] = \
get_field(ModelConfig, "hf_overrides")
tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
quantization: Optional[QuantizationMethods] = ModelConfig.quantization
enforce_eager: bool = ModelConfig.enforce_eager
max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
# The following three fields are deprecated and will be removed in a future
# release. Setting them will have no effect. Please remove them from your
@ -280,8 +286,10 @@ class EngineArgs:
get_field(TokenizerPoolConfig, "extra_config")
limit_mm_per_prompt: dict[str, int] = \
get_field(MultiModalConfig, "limit_per_prompt")
mm_processor_kwargs: Optional[Dict[str, Any]] = None
disable_mm_preprocessor_cache: bool = False
mm_processor_kwargs: Optional[Dict[str, Any]] = \
MultiModalConfig.mm_processor_kwargs
disable_mm_preprocessor_cache: bool = \
MultiModalConfig.disable_mm_preprocessor_cache
# LoRA fields
enable_lora: bool = False
enable_lora_bias: bool = LoRAConfig.bias_enabled
@ -323,7 +331,8 @@ class EngineArgs:
DecodingConfig.disable_any_whitespace
guided_decoding_disable_additional_properties: bool = \
DecodingConfig.disable_additional_properties
logits_processor_pattern: Optional[str] = None
logits_processor_pattern: Optional[
str] = ModelConfig.logits_processor_pattern
speculative_config: Optional[Dict[str, Any]] = None
@ -331,22 +340,25 @@ class EngineArgs:
show_hidden_metrics_for_version: Optional[str] = None
otlp_traces_endpoint: Optional[str] = None
collect_detailed_traces: Optional[str] = None
disable_async_output_proc: bool = False
disable_async_output_proc: bool = not ModelConfig.use_async_output_proc
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
override_neuron_config: Optional[Dict[str, Any]] = None
override_pooler_config: Optional[PoolerConfig] = None
override_neuron_config: dict[str, Any] = \
get_field(ModelConfig, "override_neuron_config")
override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
ModelConfig.override_pooler_config
compilation_config: Optional[CompilationConfig] = None
worker_cls: str = ParallelConfig.worker_cls
worker_extension_cls: str = ParallelConfig.worker_extension_cls
kv_transfer_config: Optional[KVTransferConfig] = None
generation_config: Optional[str] = "auto"
override_generation_config: Optional[Dict[str, Any]] = None
enable_sleep_mode: bool = False
model_impl: str = "auto"
generation_config: str = ModelConfig.generation_config
enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
override_generation_config: dict[str, Any] = \
get_field(ModelConfig, "override_generation_config")
model_impl: str = ModelConfig.model_impl
calculate_kv_scales: bool = CacheConfig.calculate_kv_scales
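The reason `get_field` appears for the dict-valued defaults is that dataclasses forbid mutable defaults; the helper (shown earlier in the config.py hunks) returns a `field(...)` whose default or default factory is copied from the source config, so `EngineArgs` mirrors `ModelConfig` without sharing a mutable object. A stripped-down sketch of the pattern using only the standard library:

from dataclasses import MISSING, Field, dataclass, field, fields

def get_field_sketch(cls, name: str) -> Field:
    """Copy a field's default (or default factory) from another dataclass."""
    for f in fields(cls):
        if f.name != name:
            continue
        if f.default_factory is not MISSING:
            return field(default_factory=f.default_factory)
        if f.default is not MISSING:
            return field(default=f.default)
        raise ValueError(f"{name} has no default")
    raise AttributeError(name)

@dataclass
class Source:
    overrides: dict = field(default_factory=dict)

@dataclass
class Mirror:
    overrides: dict = get_field_sketch(Source, "overrides")

# Each instance gets its own fresh dict, and the default stays in sync
# with Source if Source's default factory ever changes.
assert Mirror().overrides == {} and Mirror().overrides is not Mirror().overrides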
@ -356,9 +368,6 @@ class EngineArgs:
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
def __post_init__(self):
if not self.tokenizer:
self.tokenizer = self.model
# support `EngineArgs(compilation_config={...})`
# without having to manually construct a
# CompilationConfig object
@ -375,80 +384,87 @@ class EngineArgs:
"""Shared CLI arguments for vLLM engine."""
# Model arguments
parser.add_argument(
'--model',
type=str,
default=EngineArgs.model,
help='Name or path of the huggingface model to use.')
parser.add_argument(
'--task',
default=EngineArgs.task,
choices=get_args(TaskOption),
help='The task to use the model for. Each vLLM instance only '
'supports one task, even if the same model can be used for '
'multiple tasks. When the model only supports one task, ``"auto"`` '
'can be used to select it; otherwise, you must specify explicitly '
'which task to use.')
parser.add_argument(
'--tokenizer',
type=optional_type(str),
default=EngineArgs.tokenizer,
help='Name or path of the huggingface tokenizer to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
"--hf-config-path",
type=optional_type(str),
default=EngineArgs.hf_config_path,
help='Name or path of the huggingface config to use. '
'If unspecified, model name or path will be used.')
parser.add_argument(
'--skip-tokenizer-init',
action='store_true',
help='Skip initialization of tokenizer and detokenizer. '
'Expects valid prompt_token_ids and None for prompt from '
'the input. The generated output will contain token ids.')
parser.add_argument(
'--revision',
type=optional_type(str),
default=None,
help='The specific model version to use. It can be a branch '
'name, a tag name, or a commit id. If unspecified, will use '
'the default version.')
parser.add_argument(
'--code-revision',
type=optional_type(str),
default=None,
help='The specific revision to use for the model code on '
'Hugging Face Hub. It can be a branch name, a tag name, or a '
'commit id. If unspecified, will use the default version.')
parser.add_argument(
'--tokenizer-revision',
type=optional_type(str),
default=None,
help='Revision of the huggingface tokenizer to use. '
'It can be a branch name, a tag name, or a commit id. '
'If unspecified, will use the default version.')
parser.add_argument(
'--tokenizer-mode',
type=str,
default=EngineArgs.tokenizer_mode,
choices=['auto', 'slow', 'mistral', 'custom'],
help='The tokenizer mode.\n\n* "auto" will use the '
'fast tokenizer if available.\n* "slow" will '
'always use the slow tokenizer. \n* '
'"mistral" will always use the `mistral_common` tokenizer. \n* '
'"custom" will use --tokenizer to select the '
'preregistered tokenizer.')
parser.add_argument('--trust-remote-code',
action='store_true',
help='Trust remote code from huggingface.')
parser.add_argument(
'--allowed-local-media-path',
type=str,
help="Allowing API requests to read local images or videos "
"from directories specified by the server file system. "
"This is a security risk. "
"Should only be enabled in trusted environments.")
model_kwargs = get_kwargs(ModelConfig)
model_group = parser.add_argument_group(
title="ModelConfig",
description=ModelConfig.__doc__,
)
model_group.add_argument("--model", **model_kwargs["model"])
model_group.add_argument("--task", **model_kwargs["task"])
model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
model_group.add_argument("--tokenizer-mode",
**model_kwargs["tokenizer_mode"])
model_group.add_argument("--trust-remote-code",
**model_kwargs["trust_remote_code"])
model_group.add_argument("--dtype", **model_kwargs["dtype"])
model_group.add_argument("--seed", **model_kwargs["seed"])
model_group.add_argument("--hf-config-path",
**model_kwargs["hf_config_path"])
model_group.add_argument("--allowed-local-media-path",
**model_kwargs["allowed_local_media_path"])
model_group.add_argument("--revision", **model_kwargs["revision"])
model_group.add_argument("--code-revision",
**model_kwargs["code_revision"])
model_group.add_argument("--rope-scaling",
**model_kwargs["rope_scaling"])
model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
model_group.add_argument("--tokenizer-revision",
**model_kwargs["tokenizer_revision"])
model_group.add_argument("--max-model-len",
**model_kwargs["max_model_len"])
model_group.add_argument("--quantization", "-q",
**model_kwargs["quantization"])
model_group.add_argument("--enforce-eager",
**model_kwargs["enforce_eager"])
model_group.add_argument("--max-seq-len-to-capture",
**model_kwargs["max_seq_len_to_capture"])
model_group.add_argument("--max-logprobs",
**model_kwargs["max_logprobs"])
model_group.add_argument("--disable-sliding-window",
**model_kwargs["disable_sliding_window"])
model_group.add_argument("--disable-cascade-attn",
**model_kwargs["disable_cascade_attn"])
model_group.add_argument("--skip-tokenizer-init",
**model_kwargs["skip_tokenizer_init"])
model_group.add_argument("--served-model-name",
**model_kwargs["served_model_name"])
# This one is a special case because it is the
# opposite of ModelConfig.use_async_output_proc
model_group.add_argument(
"--disable-async-output-proc",
action="store_true",
default=EngineArgs.disable_async_output_proc,
help="Disable async output processing. This may result in "
"lower performance.")
model_group.add_argument("--config-format",
choices=[f.value for f in ConfigFormat],
**model_kwargs["config_format"])
# This one is a special case because it can be a
# bool or a str. TODO: Handle this in get_kwargs
model_group.add_argument("--hf-token",
type=str,
nargs="?",
const=True,
default=model_kwargs["hf_token"]["default"],
help=model_kwargs["hf_token"]["help"])
model_group.add_argument("--hf-overrides",
**model_kwargs["hf_overrides"])
model_group.add_argument("--override-neuron-config",
**model_kwargs["override_neuron_config"])
model_group.add_argument("--override-pooler-config",
**model_kwargs["override_pooler_config"])
model_group.add_argument("--logits-processor-pattern",
**model_kwargs["logits_processor_pattern"])
model_group.add_argument("--generation-config",
**model_kwargs["generation_config"])
model_group.add_argument("--override-generation-config",
**model_kwargs["override_generation_config"])
model_group.add_argument("--enable-sleep-mode",
**model_kwargs["enable_sleep_mode"])
model_group.add_argument("--model-impl",
choices=[f.value for f in ModelImpl],
**model_kwargs["model_impl"])
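End-to-end these grouped flags behave as before; a quick parser-level sanity check in the spirit of the removed --override-pooler-config test, with illustrative values (the exact parsed type of JSON-valued flags may differ from the old test, which is why only the human-readable int is asserted here):

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([
    "--model", "facebook/opt-125m",
    "--max-model-len", "25.6k",   # human_readable_int -> 25600
    "--override-pooler-config", '{"pooling_type": "MEAN"}',
])
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.max_model_len == 25_600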
# Model loading arguments
load_kwargs = get_kwargs(LoadConfig)
load_group = parser.add_argument_group(
@ -465,38 +481,6 @@ class EngineArgs:
load_group.add_argument('--use-tqdm-on-load',
**load_kwargs["use_tqdm_on_load"])
parser.add_argument(
'--config-format',
default=EngineArgs.config_format,
choices=[f.value for f in ConfigFormat],
help='The format of the model config to load.\n\n'
'* "auto" will try to load the config in hf format '
'if available else it will try to load in mistral format ')
parser.add_argument(
'--dtype',
type=str,
default=EngineArgs.dtype,
choices=[
'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
],
help='Data type for model weights and activations.\n\n'
'* "auto" will use FP16 precision for FP32 and FP16 models, and '
'BF16 precision for BF16 models.\n'
'* "half" for FP16. Recommended for AWQ quantization.\n'
'* "float16" is the same as "half".\n'
'* "bfloat16" for a balance between precision and range.\n'
'* "float" is shorthand for FP32 precision.\n'
'* "float32" for FP32 precision.')
parser.add_argument('--max-model-len',
type=human_readable_int,
default=EngineArgs.max_model_len,
help='Model context length. If unspecified, will '
'be automatically derived from the model config. '
'Supports k/m/g/K/M/G in human-readable format.\n'
'Examples:\n'
'- 1k → 1000\n'
'- 1K → 1024\n')
# Guided decoding arguments
guided_decoding_kwargs = get_kwargs(DecodingConfig)
guided_decoding_group = parser.add_argument_group(
@ -520,26 +504,6 @@ class EngineArgs:
choices=list(ReasoningParserManager.reasoning_parsers),
**guided_decoding_kwargs["reasoning_backend"])
parser.add_argument(
'--logits-processor-pattern',
type=optional_type(str),
default=None,
help='Optional regex pattern specifying valid logits processor '
'qualified names that can be passed with the `logits_processors` '
'extra completion argument. Defaults to None, which allows no '
'processors.')
parser.add_argument(
'--model-impl',
type=str,
default=EngineArgs.model_impl,
choices=[f.value for f in ModelImpl],
help='Which implementation of the model to use.\n\n'
'* "auto" will try to use the vLLM implementation if it exists '
'and fall back to the Transformers implementation if no vLLM '
'implementation is available.\n'
'* "vllm" will use the vLLM model implementation.\n'
'* "transformers" will use the Transformers model '
'implementation.\n')
# Parallel arguments
parallel_kwargs = get_kwargs(ParallelConfig)
parallel_group = parser.add_argument_group(
@ -592,10 +556,6 @@ class EngineArgs:
cache_group.add_argument('--calculate-kv-scales',
**cache_kwargs["calculate_kv_scales"])
parser.add_argument('--disable-sliding-window',
action='store_true',
help='Disables sliding window, '
'capping to sliding window size.')
parser.add_argument('--use-v2-block-manager',
action='store_true',
default=True,
@ -605,73 +565,9 @@ class EngineArgs:
'Setting this flag to True or False'
' has no effect on vLLM behavior.')
parser.add_argument('--seed',
type=int,
default=EngineArgs.seed,
help='Random seed for operations.')
parser.add_argument(
'--max-logprobs',
type=int,
default=EngineArgs.max_logprobs,
help=('Max number of log probs to return logprobs is specified in'
' SamplingParams.'))
parser.add_argument('--disable-log-stats',
action='store_true',
help='Disable logging statistics.')
# Quantization settings.
parser.add_argument('--quantization',
'-q',
type=optional_type(str),
choices=[*QUANTIZATION_METHODS, None],
default=EngineArgs.quantization,
help='Method used to quantize the weights. If '
'None, we first check the `quantization_config` '
'attribute in the model config file. If that is '
'None, we assume the model weights are not '
'quantized and use `dtype` to determine the data '
'type of the weights.')
parser.add_argument(
'--rope-scaling',
default=None,
type=json.loads,
help='RoPE scaling configuration in JSON format. '
'For example, ``{"rope_type":"dynamic","factor":2.0}``')
parser.add_argument('--rope-theta',
default=None,
type=float,
help='RoPE theta. Use with `rope_scaling`. In '
'some cases, changing the RoPE theta improves the '
'performance of the scaled model.')
parser.add_argument(
'--hf-token',
type=str,
nargs='?',
const=True,
default=None,
help='The token to use as HTTP bearer authorization'
' for remote files. If `True`, will use the token '
'generated when running `huggingface-cli login` '
'(stored in `~/.huggingface`).')
parser.add_argument('--hf-overrides',
type=json.loads,
default=EngineArgs.hf_overrides,
help='Extra arguments for the HuggingFace config. '
'This should be a JSON string that will be '
'parsed into a dictionary.')
parser.add_argument('--enforce-eager',
action='store_true',
help='Always use eager-mode PyTorch. If False, '
'will use eager mode and CUDA graph in hybrid '
'for maximal performance and flexibility.')
parser.add_argument('--max-seq-len-to-capture',
type=int,
default=EngineArgs.max_seq_len_to_capture,
help='Maximum sequence length covered by CUDA '
'graphs. When a sequence has context length '
'larger than this, we fall back to eager mode. '
'Additionally for encoder-decoder models, if the '
'sequence length of the encoder input is larger '
'than this, we fall back to the eager mode.')
# Tokenizer arguments
tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
@ -775,20 +671,6 @@ class EngineArgs:
"Default to `original/**/*` to avoid repeated loading of llama's "
"checkpoints.")
parser.add_argument(
"--served-model-name",
nargs="+",
type=str,
default=None,
help="The model name(s) used in the API. If multiple "
"names are provided, the server will respond to any "
"of the provided names. The model name in the model "
"field of a response will be the first name in this "
"list. If not specified, the model name will be the "
"same as the ``--model`` argument. Noted that this name(s) "
"will also be used in `model_name` tag content of "
"prometheus metrics, if multiple names provided, metrics "
"tag will take the first one.")
parser.add_argument('--qlora-adapter-name-or-path',
type=str,
default=None,
@ -822,13 +704,6 @@ class EngineArgs:
"modules. This involves use of possibly costly and or blocking "
"operations and hence might have a performance impact.")
parser.add_argument(
'--disable-async-output-proc',
action='store_true',
default=EngineArgs.disable_async_output_proc,
help="Disable async output processing. This may result in "
"lower performance.")
# Scheduler arguments
scheduler_kwargs = get_kwargs(SchedulerConfig)
scheduler_group = parser.add_argument_group(
@ -871,19 +746,6 @@ class EngineArgs:
parser.add_argument('--scheduler-cls',
**scheduler_kwargs["scheduler_cls"])
parser.add_argument(
'--override-neuron-config',
type=json.loads,
default=None,
help="Override or set neuron device configuration. "
"e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
parser.add_argument(
'--override-pooler-config',
type=PoolerConfig.from_json,
default=None,
help="Override or set the pooling method for pooling models. "
"e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
parser.add_argument('--compilation-config',
'-O',
type=CompilationConfig.from_cli,
@ -920,34 +782,6 @@ class EngineArgs:
help='The worker extension class on top of the worker cls, '
'it is useful if you just want to add new functions to the worker '
'class without changing the existing functions.')
parser.add_argument(
"--generation-config",
type=optional_type(str),
default="auto",
help="The folder path to the generation config. "
"Defaults to 'auto', the generation config will be loaded from "
"model path. If set to 'vllm', no generation config is loaded, "
"vLLM defaults will be used. If set to a folder path, the "
"generation config will be loaded from the specified folder path. "
"If `max_new_tokens` is specified in generation config, then "
"it sets a server-wide limit on the number of output tokens "
"for all requests.")
parser.add_argument(
"--override-generation-config",
type=json.loads,
default=None,
help="Overrides or sets generation config in JSON format. "
"e.g. ``{\"temperature\": 0.5}``. If used with "
"--generation-config=auto, the override parameters will be merged "
"with the default config from the model. If generation-config is "
"None, only the override parameters are used.")
parser.add_argument("--enable-sleep-mode",
action="store_true",
default=False,
help="Enable sleep mode for the engine. "
"(only cuda platform is supported)")
parser.add_argument(
"--additional-config",
@ -966,16 +800,6 @@ class EngineArgs:
"If enabled, the model will be able to generate reasoning content."
)
parser.add_argument(
"--disable-cascade-attn",
action="store_true",
default=False,
help="Disable cascade attention for V1. While cascade attention "
"does not change the mathematical correctness, disabling it "
"could be useful for preventing potential numerical issues. "
"Note that even if this is set to False, cascade attention will be "
"only used when the heuristic tells that it's beneficial.")
return parser
@classmethod
@ -1002,8 +826,7 @@ class EngineArgs:
model=self.model,
hf_config_path=self.hf_config_path,
task=self.task,
# We know this is not None because we set it in __post_init__
tokenizer=cast(str, self.tokenizer),
tokenizer=self.tokenizer,
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code,
allowed_local_media_path=self.allowed_local_media_path,

View File

@ -13,7 +13,7 @@ from typing_extensions import TypeVar, deprecated
from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.config import CompilationConfig
from vllm.config import CompilationConfig, ModelDType, TokenizerMode
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
TaskOption)
from vllm.engine.llm_engine import LLMEngine
@ -32,6 +32,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput,
ScoringRequestOutput)
@ -163,20 +164,20 @@ class LLM:
self,
model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
tokenizer_mode: TokenizerMode = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
allowed_local_media_path: str = "",
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
dtype: ModelDType = "auto",
quantization: Optional[QuantizationMethods] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: Optional[int] = None,
gpu_memory_utilization: float = 0.9,
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: Optional[bool] = None,
enforce_eager: bool = False,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
@ -189,12 +190,7 @@ class LLM:
compilation_config: Optional[Union[int, dict[str, Any]]] = None,
**kwargs,
) -> None:
'''
LLM constructor.
Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False.
'''
"""LLM constructor."""
if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True

View File

@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
@ -186,7 +187,7 @@ class AQLMConfig(QuantizationConfig):
f"out_group_size={self.out_group_size})")
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "aqlm"
@classmethod

View File

@ -7,6 +7,7 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@ -44,7 +45,7 @@ class AWQConfig(QuantizationConfig):
f"zero_point={self.zero_point}, "
f"modules_to_not_convert={self.modules_to_not_convert})")
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "awq"
def get_supported_act_dtypes(self) -> List[torch.dtype]:

View File

@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
@ -73,7 +74,7 @@ class AWQMarlinConfig(QuantizationConfig):
f"modules_to_not_convert={self.modules_to_not_convert})")
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "awq_marlin"
@classmethod
@ -101,8 +102,8 @@ class AWQMarlinConfig(QuantizationConfig):
modules_to_not_convert, config)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin"
or user_quant == "awq_marlin")

View File

@ -2,11 +2,16 @@
import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
import torch
from torch import nn
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
else:
QuantizationMethods = str
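A hedged illustration of the TYPE_CHECKING pattern used above: the Literal alias is only imported for static analysis, while at runtime the annotation degrades to plain str, avoiding a circular import. The names below are invented for the example:
from typing import TYPE_CHECKING, Literal
if TYPE_CHECKING:
    ExampleMethods = Literal["awq", "gptq", "fp8"]  # hypothetical subset
else:
    ExampleMethods = str  # runtime fallback; no import cycle is introduced
def pick_method() -> "ExampleMethods":
    return "gptq"  # type checkers verify membership in the Literal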
class QuantizeMethodBase(ABC):
"""Base class for different quantized methods."""
@ -66,7 +71,7 @@ class QuantizationConfig(ABC):
self.packed_modules_mapping: Dict[str, List[str]] = dict()
@abstractmethod
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
"""Name of the quantization method."""
raise NotImplementedError
@ -99,8 +104,8 @@ class QuantizationConfig(ABC):
raise NotImplementedError
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
"""
Detects if this quantization method can support a given checkpoint
format by overriding the user specified quantization method --

View File

@ -5,6 +5,7 @@ import torch
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
@ -100,7 +101,7 @@ class BitBLASConfig(QuantizationConfig):
f"quant_method={self.quant_method})")
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "bitblas"
@classmethod
@ -139,8 +140,8 @@ class BitBLASConfig(QuantizationConfig):
lm_head_quantized)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_bitblas_format: bool
is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"

View File

@ -7,6 +7,7 @@ import torch
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import direct_register_custom_op
@ -56,7 +57,7 @@ class BitsAndBytesConfig(QuantizationConfig):
f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
@classmethod
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "bitsandbytes"
@classmethod

View File

@ -16,6 +16,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
@ -71,7 +72,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_min_capability(cls) -> int:
return 70
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "compressed-tensors"
def get_quant_method(

View File

@ -7,6 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
@ -41,8 +42,8 @@ class DeepSpeedFPConfig(QuantizationConfig):
f"group_size={self.group_size}")
@classmethod
def get_name(cls) -> str:
return "DeepSpeedFP"
def get_name(cls) -> QuantizationMethods:
return "deepspeedfp"
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":

View File

@ -8,6 +8,7 @@ from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
@ -20,7 +21,7 @@ class ExpertsInt8Config(QuantizationConfig):
super().__init__()
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "experts_int8"
@classmethod

View File

@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@ -38,7 +39,7 @@ class FBGEMMFp8Config(QuantizationConfig):
self.fp8_linear = Fp8LinearOp()
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "fbgemm_fp8"
@classmethod

View File

@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@ -83,7 +84,7 @@ class Fp8Config(QuantizationConfig):
self.weight_block_size = weight_block_size
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "fp8"
@classmethod

View File

@ -13,6 +13,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -31,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
def __repr__(self) -> str:
return ("GGUFConfig()")
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "gguf"
def get_supported_act_dtypes(self) -> List[torch.dtype]:

View File

@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
@ -79,7 +80,7 @@ class GPTQConfig(QuantizationConfig):
f"dynamic={self.dynamic}")
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "gptq"
@classmethod

View File

@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
@ -123,7 +124,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
f"quant_method={self.quant_method})")
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "gptq_bitblas"
@classmethod
@ -151,8 +152,8 @@ class GPTQBitBLASConfig(QuantizationConfig):
lm_head_quantized)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "bitblas"

View File

@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
@ -100,7 +101,7 @@ class GPTQMarlinConfig(QuantizationConfig):
f"dynamic={self.dynamic}")
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "gptq_marlin"
@classmethod
@ -130,8 +131,8 @@ class GPTQMarlinConfig(QuantizationConfig):
lm_head_quantized, dynamic, config)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin"

View File

@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (BasevLLMParameter,
@ -85,7 +86,7 @@ class GPTQMarlin24Config(QuantizationConfig):
self.quant_type, self.group_size)
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "gptq_marlin_24"
@classmethod
@ -108,8 +109,8 @@ class GPTQMarlin24Config(QuantizationConfig):
return cls(weight_bits, group_size)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
is_marlin_24_format = (
hf_quant_cfg.get("checkpoint_format") == "marlin_24")

View File

@ -8,6 +8,7 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@ -50,7 +51,7 @@ class HQQMarlinConfig(QuantizationConfig):
f"group_size={self.group_size})")
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "hqq"
@classmethod

View File

@ -6,6 +6,7 @@ import torch
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
@ -58,7 +59,7 @@ class IPEXConfig(QuantizationConfig):
f"group_size={self.group_size})")
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "ipex"
@classmethod
@ -97,8 +98,8 @@ class IPEXConfig(QuantizationConfig):
lm_head_quantized)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
if not current_platform.is_cpu() and not current_platform.is_xpu():
return None

View File

@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@ -63,7 +64,7 @@ class MarlinConfig(QuantizationConfig):
f"lm_head_quantized={self.lm_head_quantized})")
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "marlin"
@classmethod
@ -87,8 +88,8 @@ class MarlinConfig(QuantizationConfig):
return cls(group_size, lm_head_quantized)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_marlin_format: bool
is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"

View File

@ -11,6 +11,7 @@ from vllm._custom_ops import (cutlass_scaled_fp4_mm,
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@ -42,7 +43,7 @@ class ModelOptFp8Config(QuantizationConfig):
" the format is experimental and could change.")
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "modelopt"
@classmethod
@ -184,8 +185,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
self.exclude_modules = exclude_modules
@classmethod
def get_name(cls) -> str:
return "modelopt_nvfp4"
def get_name(cls) -> QuantizationMethods:
return "nvfp4"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:

View File

@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@ -64,7 +65,7 @@ class MoeWNA16Config(QuantizationConfig):
self.modules_to_not_convert = modules_to_not_convert
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "moe_wna16"
@classmethod
@ -100,8 +101,8 @@ class MoeWNA16Config(QuantizationConfig):
lm_head_quantized, modules_to_not_convert, config)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
if can_convert and user_quant == "moe_wna16":
return cls.get_name()

View File

@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional
from torch.nn import Module
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@ -30,7 +31,7 @@ class NeuronQuantConfig(QuantizationConfig):
self.dequant_dtype = dequant_dtype
self.quantize_method = quantize_method
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "neuron_quant"
def get_supported_act_dtypes(self) -> List[str]:

View File

@ -9,6 +9,7 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase)
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
@ -50,7 +51,7 @@ class PTPCFp8Config(Fp8Config):
ignored_layers=ignored_layers)
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "ptpc_fp8"
@classmethod

View File

@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (BasevLLMParameter,
@ -84,7 +85,7 @@ class QQQConfig(QuantizationConfig):
self.weight_bits, self.group_size)
@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "qqq"
@classmethod

View File

@ -8,6 +8,7 @@ import torch
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@ -47,7 +48,7 @@ class QuarkConfig(QuantizationConfig):
def get_min_capability(cls) -> int:
return 70
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "quark"
def get_quant_method(self, layer: torch.nn.Module,

View File

@ -6,6 +6,7 @@ import torch.nn.functional as F
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
@ -20,7 +21,7 @@ class TorchAOConfig(QuantizationConfig):
def __repr__(self) -> str:
return f"TorchAOConfig({self.torchao_config})"
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "torchao"
def get_supported_act_dtypes(self) -> List[torch.dtype]:

View File

@ -7,6 +7,7 @@ from torch.nn import Module
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import ModelWeightParameter
@ -27,7 +28,7 @@ class Int8TpuConfig(QuantizationConfig):
f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "tpu_int8"
def get_supported_act_dtypes(self) -> List[torch.dtype]:

View File

@ -1496,7 +1496,7 @@ def get_rope(
if key in _ROPE_DICT:
return _ROPE_DICT[key]
if rope_scaling is None:
if not rope_scaling:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
else:
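The switch from an is-None check to a truthiness check means an empty rope_scaling dict now also takes the plain RotaryEmbedding path; a tiny sketch of the distinction, with illustrative values:
def uses_plain_rope(rope_scaling) -> bool:
    # `not rope_scaling` is True for both None and {}, whereas the old
    # `rope_scaling is None` check treated {} as a scaling config.
    return not rope_scaling
assert uses_plain_rope(None)
assert uses_plain_rope({})
assert not uses_plain_rope({"rope_type": "linear", "factor": 2.0})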

View File

@ -180,7 +180,6 @@ def _get_neuron_config_after_override(default_neuron_config,
NeuronConfig, QuantizationConfig,
SparseAttnConfig)
overridden_neuron_config = overridden_neuron_config or {}
sparse_attn = overridden_neuron_config.pop("sparse_attn", {})
if sparse_attn:
overridden_neuron_config["sparse_attn"] = SparseAttnConfig(