Improve configs - ModelConfig (#17130)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

parent 2c4f59afc3
commit 13698db634
@@ -738,7 +738,7 @@ class VllmRunner:
     - `block_size`: Set to `16` instead of `None` to reduce memory usage.
     - `enable_chunked_prefill`: Set to `False` instead of `None` for
       test reproducibility.
-    - `enforce_eager`: Set to `False` instead of `None` to test CUDA graph.
+    - `enforce_eager`: Set to `False` to test CUDA graph.
     """

     def __init__(
@@ -8,7 +8,7 @@ from typing import Literal, Optional

 import pytest

-from vllm.config import PoolerConfig, config
+from vllm.config import config
 from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
                                    get_type, is_not_builtin, is_type,
                                    literal_to_kwargs, nullable_kvs,
@@ -222,17 +222,6 @@ def test_prefix_cache_default():
     assert not engine_args.enable_prefix_caching


-def test_valid_pooling_config():
-    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
-    args = parser.parse_args([
-        '--override-pooler-config',
-        '{"pooling_type": "MEAN"}',
-    ])
-    engine_args = EngineArgs.from_cli_args(args=args)
-    assert engine_args.override_pooler_config == PoolerConfig(
-        pooling_type="MEAN", )
-
-
 @pytest.mark.parametrize(
     ("arg"),
     [
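The removed test above exercised `--override-pooler-config` taking a JSON string on the command line. As a rough illustration of that flow (a standalone sketch with made-up class names, not vLLM's actual parser or config classes), a JSON CLI value can be coerced into a config dataclass like this:

```python
# Standalone sketch: turn a JSON CLI value into a config dataclass, mirroring
# what the removed test exercised. DemoPoolerConfig is illustrative only.
import argparse
import json
from dataclasses import dataclass
from typing import Optional


@dataclass
class DemoPoolerConfig:
    pooling_type: Optional[str] = None
    normalize: Optional[bool] = None


parser = argparse.ArgumentParser()
# json.loads turns '{"pooling_type": "MEAN"}' into a dict of keyword arguments.
parser.add_argument("--override-pooler-config", type=json.loads, default=None)

args = parser.parse_args(["--override-pooler-config", '{"pooling_type": "MEAN"}'])
config = DemoPoolerConfig(**args.override_pooler_config)
assert config == DemoPoolerConfig(pooling_type="MEAN")
```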
@@ -14,7 +14,7 @@ import torch.nn.functional as F
 from vllm.model_executor.layers.linear import LinearBase  # noqa: E501
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization import (
-    get_quantization_config, register_quantization_config)
+    QuantizationMethods, get_quantization_config, register_quantization_config)
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig)
@@ -54,7 +54,7 @@ class CustomQuantConfig(QuantizationConfig):
         """Initialize the quantization config."""
         self.num_bits = num_bits

-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         """Name of the quantization method."""
         return "custom_quant"

@@ -185,7 +185,7 @@ def test_get_pooling_config():
         revision=None,
     )

-    pooling_config = model_config._init_pooler_config(None)
+    pooling_config = model_config._init_pooler_config()
     assert pooling_config is not None

     assert pooling_config.normalize
@@ -205,11 +205,12 @@ def test_get_pooling_config_from_args():
         dtype="float16",
         revision=None)

-    override_config = PoolerConfig(pooling_type='CLS', normalize=True)
+    override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
+    model_config.override_pooler_config = override_pooler_config

-    pooling_config = model_config._init_pooler_config(override_config)
+    pooling_config = model_config._init_pooler_config()
     assert pooling_config is not None
-    assert asdict(pooling_config) == asdict(override_config)
+    assert asdict(pooling_config) == asdict(override_pooler_config)


 @pytest.mark.skipif(current_platform.is_rocm(),
vllm/config.py
@@ -16,9 +16,8 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
                          replace)
 from importlib.util import find_spec
 from pathlib import Path
-from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
-                    Optional, Protocol, TypeVar, Union, cast, get_args,
-                    get_origin)
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
+                    Protocol, TypeVar, Union, cast, get_args, get_origin)

 import torch
 from pydantic import BaseModel, Field, PrivateAttr
@@ -211,103 +210,190 @@ def get_field(cls: ConfigType, name: str) -> Field:
         f"{cls.__name__}.{name} must have a default value or default factory.")


-class ModelConfig:
-    """Configuration for the model.
+TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
+ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]

-    Args:
-        model: Name or path of the huggingface model to use.
-            It is also used as the content for `model_name` tag in metrics
-            output when `served_model_name` is not specified.
-        task: The task to use the model for. Each vLLM instance only supports
-            one task, even if the same model can be used for multiple tasks.
-            When the model only supports one task, "auto" can be used to select
-            it; otherwise, you must specify explicitly which task to use.
-        tokenizer: Name or path of the huggingface tokenizer to use.
-        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
-            available, "slow" will always use the slow tokenizer,
-            "mistral" will always use the tokenizer from `mistral_common`, and
-            "custom" will use --tokenizer to select the preregistered tokenizer.
-        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
-            downloading the model and tokenizer.
-        allowed_local_media_path: Allowing API requests to read local images or
-            videos from directories specified by the server file system.
-            This is a security risk. Should only be enabled in trusted
-            environments.
-        dtype: Data type for model weights and activations. The "auto" option
-            will use FP16 precision for FP32 and FP16 models, and BF16 precision
-            for BF16 models.
-        seed: Random seed for reproducibility.
-        revision: The specific model version to use. It can be a branch name,
-            a tag name, or a commit id. If unspecified, will use the default
-            version.
-        code_revision: The specific revision to use for the model code on
-            Hugging Face Hub. It can be a branch name, a tag name, or a
-            commit id. If unspecified, will use the default version.
-        tokenizer_revision: The specific tokenizer version to use. It can be a
-            branch name, a tag name, or a commit id. If unspecified, will use
-            the default version.
-        max_model_len: Maximum length of a sequence (including prompt and
-            output). If None, will be derived from the model.
-        spec_target_max_model_len: Specify the the maximum length for spec
-            decoding draft models.
-        quantization: Quantization method that was used to quantize the model
-            weights. If None, we assume the model weights are not quantized.
-        enforce_eager: Whether to enforce eager execution. If True, we will
-            disable CUDA graph and always execute the model in eager mode.
-            If False, we will use CUDA graph and eager execution in hybrid.
-            If None, the user did not specify, so default to False.
-        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
-            When a sequence has context length larger than this, we fall back
-            to eager mode. Additionally for encoder-decoder models, if the
-            sequence length of the encoder input is larger than this, we fall
-            back to the eager mode.
-        max_logprobs: Maximum number of log probabilities. Defaults to 20.
-        disable_sliding_window: Whether to disable sliding window. If True,
-            we will disable the sliding window functionality of the model.
-            If the model does not support sliding window, this argument is
-            ignored.
-        skip_tokenizer_init: If true, skip initialization of tokenizer and
-            detokenizer.
-        served_model_name: The model name used in metrics tag `model_name`,
-            matches the model name exposed via the APIs. If multiple model
-            names provided, the first name will be used. If not specified,
-            the model name will be the same as `model`.
-        limit_mm_per_prompt: Maximum number of data items per modality
-            per prompt. Only applicable for multimodal models.
-        mm_processor_kwargs: Overrides for the multi-modal processor obtained
-            from `AutoProcessor.from_pretrained`.
-        disable_mm_preprocessor_cache: If True, disable caching of the
-            processed multi-modal inputs.
-        use_async_output_proc: Whether to use async output processor.
-            Defaults to True.
-        config_format: The config format which shall be loaded.
-            Defaults to 'auto' which defaults to 'hf'.
-        hf_token: The token to use as HTTP bearer authorization for remote files
-            . If `True`, will use the token generated when running
-            `huggingface-cli login` (stored in `~/.huggingface`).
-        hf_overrides: If a dictionary, contains arguments to be forwarded to the
-            HuggingFace config. If a callable, it is called to update the
-            HuggingFace config.
-        override_neuron_config: Initialize non default neuron config or
-            override default neuron config that are specific to Neuron devices,
-            this argument will be used to configure the neuron config that
-            can not be gathered from the vllm arguments.
-        override_pooler_config: Initialize non default pooling config or
-            override default pooling config for the pooling model.
-        logits_processor_pattern: Optional regex pattern specifying valid
-            logits processor qualified names that can be passed with the
-            `logits_processors` extra completion argument. Defaults to None,
-            which allows no processors.
-        generation_config: Configuration parameter file for generation.
-        model_impl: Which implementation of the model to use:
-            "auto" will try to use the vLLM implementation if it exists and
-            fall back to the Transformers implementation if no vLLM
-            implementation is available.
-            "vllm" will use the vLLM model implementation.
-            "transformers" will use the Transformers model implementation.
-        override_generation_config: Override the generation config with the
-            given config.
-    """

+@config
+@dataclass
+class ModelConfig:
+    """Configuration for the model."""
+
+    model: str = "facebook/opt-125m"
+    """Name or path of the Hugging Face model to use. It is also used as the
+    content for `model_name` tag in metrics output when `served_model_name` is
+    not specified."""
+    task: Literal[TaskOption, Literal["draft"]] = "auto"
+    """The task to use the model for. Each vLLM instance only supports one
+    task, even if the same model can be used for multiple tasks. When the model
+    only supports one task, "auto" can be used to select it; otherwise, you
+    must specify explicitly which task to use."""
+    tokenizer: str = None  # type: ignore
+    """Name or path of the Hugging Face tokenizer to use. If unspecified, model
+    name or path will be used."""
+    tokenizer_mode: TokenizerMode = "auto"
+    """Tokenizer mode:\n
+    - "auto" will use the fast tokenizer if available.\n
+    - "slow" will always use the slow tokenizer.\n
+    - "mistral" will always use the tokenizer from `mistral_common`.\n
+    - "custom" will use --tokenizer to select the preregistered tokenizer."""
+    trust_remote_code: bool = False
+    """Trust remote code (e.g., from HuggingFace) when downloading the model
+    and tokenizer."""
+    dtype: Union[ModelDType, torch.dtype] = "auto"
+    """Data type for model weights and activations:\n
+    - "auto" will use FP16 precision for FP32 and FP16 models, and BF16
+    precision for BF16 models.\n
+    - "half" for FP16. Recommended for AWQ quantization.\n
+    - "float16" is the same as "half".\n
+    - "bfloat16" for a balance between precision and range.\n
+    - "float" is shorthand for FP32 precision.\n
+    - "float32" for FP32 precision."""
+    seed: Optional[int] = None
+    """Random seed for reproducibility."""
+    hf_config_path: Optional[str] = None
+    """Name or path of the Hugging Face config to use. If unspecified, model
+    name or path will be used."""
+    allowed_local_media_path: str = ""
+    """Allowing API requests to read local images or videos from directories
+    specified by the server file system. This is a security risk. Should only
+    be enabled in trusted environments."""
+    revision: Optional[str] = None
+    """The specific model version to use. It can be a branch name, a tag name,
+    or a commit id. If unspecified, will use the default version."""
+    code_revision: Optional[str] = None
+    """The specific revision to use for the model code on the Hugging Face Hub.
+    It can be a branch name, a tag name, or a commit id. If unspecified, will
+    use the default version."""
+    rope_scaling: dict[str, Any] = field(default_factory=dict)
+    """RoPE scaling configuration in JSON format. For example,
+    `{"rope_type":"dynamic","factor":2.0}`."""
+    rope_theta: Optional[float] = None
+    """RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE
+    theta improves the performance of the scaled model."""
+    tokenizer_revision: Optional[str] = None
+    """The specific revision to use for the tokenizer on the Hugging Face Hub.
+    It can be a branch name, a tag name, or a commit id. If unspecified, will
+    use the default version."""
+    max_model_len: int = None  # type: ignore
+    """Model context length (prompt and output). If unspecified, will be
+    automatically derived from the model config.
+
+    When passing via `--max-model-len`, supports k/m/g/K/M/G in human-readable
+    format. Examples:\n
+    - 1k -> 1000\n
+    - 1K -> 1024\n
+    - 25.6k -> 25,600"""
+    spec_target_max_model_len: Optional[int] = None
+    """Specify the the maximum length for spec decoding draft models."""
+    quantization: Optional[QuantizationMethods] = None
+    """Method used to quantize the weights. If `None`, we first check the
+    `quantization_config` attribute in the model config file. If that is
+    `None`, we assume the model weights are not quantized and use `dtype` to
+    determine the data type of the weights."""
+    enforce_eager: bool = False
+    """Whether to always use eager-mode PyTorch. If True, we will disable CUDA
+    graph and always execute the model in eager mode. If False, we will use
+    CUDA graph and eager execution in hybrid for maximal performance and
+    flexibility."""
+    max_seq_len_to_capture: int = 8192
+    """Maximum sequence len covered by CUDA graphs. When a sequence has context
+    length larger than this, we fall back to eager mode. Additionally for
+    encoder-decoder models, if the sequence length of the encoder input is
+    larger than this, we fall back to the eager mode."""
+    max_logprobs: int = 20
+    """Maximum number of log probabilities to return when `logprobs` is
+    specified in `SamplingParams`. The default value comes the default for the
+    OpenAI Chat Completions API."""
+    disable_sliding_window: bool = False
+    """Whether to disable sliding window. If True, we will disable the sliding
+    window functionality of the model, capping to sliding window size. If the
+    model does not support sliding window, this argument is ignored."""
+    disable_cascade_attn: bool = False
+    """Disable cascade attention for V1. While cascade attention does not
+    change the mathematical correctness, disabling it could be useful for
+    preventing potential numerical issues. Note that even if this is set to
+    False, cascade attention will be only used when the heuristic tells that
+    it's beneficial."""
+    skip_tokenizer_init: bool = False
+    """Skip initialization of tokenizer and detokenizer. Expects valid
+    `prompt_token_ids` and `None` for prompt from the input. The generated
+    output will contain token ids."""
+    served_model_name: Optional[Union[str, list[str]]] = None
+    """The model name(s) used in the API. If multiple names are provided, the
+    server will respond to any of the provided names. The model name in the
+    model field of a response will be the first name in this list. If not
+    specified, the model name will be the same as the `--model` argument. Noted
+    that this name(s) will also be used in `model_name` tag content of
+    prometheus metrics, if multiple names provided, metrics tag will take the
+    first one."""
+    limit_mm_per_prompt: dict[str, int] = field(default_factory=dict)
+    """Maximum number of data items per modality per prompt. Only applicable
+    for multimodal models."""
+    use_async_output_proc: bool = True
+    """Whether to use async output processor."""
+    config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value
+    """The format of the model config to load:\n
+    - "auto" will try to load the config in hf format if available else it
+    will try to load in mistral format.\n
+    - "hf" will load the config in hf format.\n
+    - "mistral" will load the config in mistral format."""
+    hf_token: Optional[Union[bool, str]] = None
+    """The token to use as HTTP bearer authorization for remote files . If
+    `True`, will use the token generated when running `huggingface-cli login`
+    (stored in `~/.huggingface`)."""
+    hf_overrides: HfOverrides = field(default_factory=dict)
+    """If a dictionary, contains arguments to be forwarded to the Hugging Face
+    config. If a callable, it is called to update the HuggingFace config. When
+    specified via CLI, the argument must be a valid JSON string."""
+    mm_processor_kwargs: Optional[dict[str, Any]] = None
+    """Arguments to be forwarded to the model's processor for multi-modal data,
+    e.g., image processor. Overrides for the multi-modal processor obtained
+    from `AutoProcessor.from_pretrained`. The available overrides depend on the
+    model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
+    When specified via CLI, the argument must be a valid JSON string."""
+    disable_mm_preprocessor_cache: bool = False
+    """If `True`, disable caching of the multi-modal preprocessor/mapper (not
+    recommended)."""
+    override_neuron_config: dict[str, Any] = field(default_factory=dict)
+    """Initialize non-default neuron config or override default neuron config
+    that are specific to Neuron devices, this argument will be used to
+    configure the neuron config that can not be gathered from the vllm
+    arguments. e.g. `{"cast_logits_dtype": "bloat16"}`. When specified via CLI,
+    the argument must be a valid JSON string."""
+    pooler_config: Optional["PoolerConfig"] = field(init=False)
+    """Pooler config which controls the behaviour of output pooling in pooling
+    models."""
+    override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
+    """Initialize non-default pooling config or override default pooling config
+    for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`.
+    When specified via CLI, the argument must be a valid JSON string."""
+    logits_processor_pattern: Optional[str] = None
+    """Optional regex pattern specifying valid logits processor qualified names
+    that can be passed with the `logits_processors` extra completion argument.
+    Defaults to `None`, which allows no processors."""
+    generation_config: str = "auto"
+    """The folder path to the generation config. Defaults to `"auto"`, the
+    generation config will be loaded from model path. If set to `"vllm"`, no
+    generation config is loaded, vLLM defaults will be used. If set to a folder
+    path, the generation config will be loaded from the specified folder path.
+    If `max_new_tokens` is specified in generation config, then it sets a
+    server-wide limit on the number of output tokens for all requests."""
+    override_generation_config: dict[str, Any] = field(default_factory=dict)
+    """Overrides or sets generation config. e.g. `{"temperature": 0.5}`. If
+    used with `--generation-config auto`, the override parameters will be
+    merged with the default config from the model. If used with
+    `--generation-config vllm`, only the override parameters are used.
+    When specified via CLI, the argument must be a valid JSON string."""
+    enable_sleep_mode: bool = False
+    """Enable sleep mode for the engine (only cuda platform is supported)."""
+    model_impl: Union[str, ModelImpl] = ModelImpl.AUTO.value
+    """Which implementation of the model to use:\n
+    - "auto" will try to use the vLLM implementation, if it exists, and fall
+    back to the Transformers implementation if no vLLM implementation is
+    available.\n
+    - "vllm" will use the vLLM model implementation.\n
+    - "transformers" will use the Transformers model implementation."""

     def compute_hash(self) -> str:
         """
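The hunk above replaces the old `Args:`-style class docstring with a `@dataclass` whose fields carry their documentation in adjacent string literals and whose mutable defaults go through `field(default_factory=...)`. A minimal, self-contained sketch of that pattern (illustrative names only, not vLLM's definitions):

```python
# Sketch of the field-plus-docstring config style shown above.
from dataclasses import dataclass, field
from typing import Any, Literal, Optional

DemoTokenizerMode = Literal["auto", "slow", "mistral", "custom"]


@dataclass
class DemoModelConfig:
    model: str = "facebook/opt-125m"
    """Name or path of the Hugging Face model to use."""
    tokenizer_mode: DemoTokenizerMode = "auto"
    """Tokenizer mode; constrained to the Literal alias above."""
    rope_scaling: dict[str, Any] = field(default_factory=dict)
    """Mutable default: must use default_factory, not a shared dict literal."""
    seed: Optional[int] = None
    """Random seed for reproducibility."""


cfg = DemoModelConfig(tokenizer_mode="slow")
assert cfg.rope_scaling == {} and cfg.model.startswith("facebook/")
```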
@@ -342,92 +428,43 @@ class ModelConfig:
         assert_hashable(str_factors)
         return hashlib.sha256(str(factors).encode()).hexdigest()

-    def __init__(
-        self,
-        model: str,
-        task: Literal[TaskOption, Literal["draft"]],
-        tokenizer: str,
-        tokenizer_mode: str,
-        trust_remote_code: bool,
-        dtype: Union[str, torch.dtype],
-        seed: int,
-        hf_config_path: Optional[str] = None,
-        allowed_local_media_path: str = "",
-        revision: Optional[str] = None,
-        code_revision: Optional[str] = None,
-        rope_scaling: Optional[dict[str, Any]] = None,
-        rope_theta: Optional[float] = None,
-        tokenizer_revision: Optional[str] = None,
-        max_model_len: Optional[int] = None,
-        spec_target_max_model_len: Optional[int] = None,
-        quantization: Optional[str] = None,
-        enforce_eager: Optional[bool] = None,
-        max_seq_len_to_capture: Optional[int] = None,
-        max_logprobs: int = 20,
-        disable_sliding_window: bool = False,
-        disable_cascade_attn: bool = False,
-        skip_tokenizer_init: bool = False,
-        served_model_name: Optional[Union[str, list[str]]] = None,
-        limit_mm_per_prompt: Optional[dict[str, int]] = None,
-        mm_processor_kwargs: Optional[dict[str, Any]] = None,
-        disable_mm_preprocessor_cache: bool = False,
-        use_async_output_proc: bool = True,
-        config_format: ConfigFormat = ConfigFormat.AUTO,
-        hf_token: Optional[Union[bool, str]] = None,
-        hf_overrides: Optional[HfOverrides] = None,
-        override_neuron_config: Optional[dict[str, Any]] = None,
-        override_pooler_config: Optional["PoolerConfig"] = None,
-        logits_processor_pattern: Optional[str] = None,
-        generation_config: str = "auto",
-        enable_sleep_mode: bool = False,
-        override_generation_config: Optional[dict[str, Any]] = None,
-        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
-    ) -> None:
-        self.model = maybe_model_redirect(model)
-        self.tokenizer = maybe_model_redirect(tokenizer)
+    def __post_init__(self) -> None:
+        self.model = maybe_model_redirect(self.model)
+        # The tokenizer is consistent with the model by default.
+        if self.tokenizer is None:
+            self.tokenizer = self.model
+        if self.tokenizer_revision is None:
+            self.tokenizer_revision = self.revision
+        self.tokenizer = maybe_model_redirect(self.tokenizer)

-        self.hf_config_path = hf_config_path
-        if isinstance(hf_config_path, str):
-            self.hf_config_path = maybe_model_redirect(hf_config_path)
+        if isinstance(self.hf_config_path, str):
+            self.hf_config_path = maybe_model_redirect(self.hf_config_path)

-        self.tokenizer_mode = tokenizer_mode
-        self.trust_remote_code = trust_remote_code
-        self.allowed_local_media_path = allowed_local_media_path
-        self.seed = seed
-        self.revision = revision
-        self.code_revision = code_revision
-        self.rope_scaling = rope_scaling
-        self.rope_theta = rope_theta
-        self.model_impl = model_impl
-
-        if hf_overrides is None:
-            hf_overrides = {}
-
-        if callable(hf_overrides):
+        if callable(self.hf_overrides):
             hf_overrides_kw = {}
-            hf_overrides_fn = hf_overrides
+            hf_overrides_fn = self.hf_overrides
         else:
-            hf_overrides_kw = hf_overrides
+            hf_overrides_kw = self.hf_overrides
             hf_overrides_fn = None

-        if rope_scaling is not None:
-            hf_override: dict[str, Any] = {"rope_scaling": rope_scaling}
+        if self.rope_scaling:
+            hf_override: dict[str, Any] = {"rope_scaling": self.rope_scaling}
             hf_overrides_kw.update(hf_override)
-            hf_overrides_str = json.dumps(hf_overrides)
+            hf_overrides_str = json.dumps(hf_overrides_kw)
             msg = (
                 "`--rope-scaling` will be removed in a future release. "
                 f"'Please instead use `--hf-overrides '{hf_overrides_str}'`")
             warnings.warn(DeprecationWarning(msg), stacklevel=2)
-        if rope_theta is not None:
-            hf_override = {"rope_theta": rope_theta}
+        if self.rope_theta is not None:
+            hf_override = {"rope_theta": self.rope_theta}
             hf_overrides_kw.update(hf_override)
-            hf_overrides_str = json.dumps(hf_overrides)
+            hf_overrides_str = json.dumps(hf_overrides_kw)
             msg = (
                 "`--rope-theta` will be removed in a future release. "
                 f"'Please instead use `--hf-overrides '{hf_overrides_str}'`")
             warnings.warn(DeprecationWarning(msg), stacklevel=2)

-        self.maybe_pull_model_tokenizer_for_s3(model, tokenizer)
+        self.maybe_pull_model_tokenizer_for_s3(self.model, self.tokenizer)

         if (backend := envs.VLLM_ATTENTION_BACKEND
             ) and backend == "FLASHINFER" and find_spec("flashinfer") is None:
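With the hand-written `__init__` gone, defaults that depend on other fields are resolved in `__post_init__`, as in the hunk above. A small standalone sketch of that idea, using hypothetical names:

```python
# Sketch: optional fields defaulting to the value of another field are
# resolved after the dataclass is constructed. DemoConfig is illustrative.
from dataclasses import dataclass
from typing import Optional


@dataclass
class DemoConfig:
    model: str
    tokenizer: Optional[str] = None
    revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None

    def __post_init__(self) -> None:
        # The tokenizer (and its revision) track the model unless overridden.
        if self.tokenizer is None:
            self.tokenizer = self.model
        if self.tokenizer_revision is None:
            self.tokenizer_revision = self.revision


cfg = DemoConfig(model="facebook/opt-125m", revision="main")
assert cfg.tokenizer == "facebook/opt-125m" and cfg.tokenizer_revision == "main"
```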
@@ -437,20 +474,6 @@ class ModelConfig:
                 "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile "  # noqa: E501
                 "for instructions on how to install it.")

-        # The tokenizer version is consistent with the model version by default.
-        if tokenizer_revision is None:
-            self.tokenizer_revision = revision
-        else:
-            self.tokenizer_revision = tokenizer_revision
-        self.quantization = quantization
-        self.enforce_eager = enforce_eager
-        self.max_seq_len_to_capture = max_seq_len_to_capture
-        self.max_logprobs = max_logprobs
-        self.disable_sliding_window = disable_sliding_window
-        self.disable_cascade_attn = disable_cascade_attn
-        self.skip_tokenizer_init = skip_tokenizer_init
-        self.enable_sleep_mode = enable_sleep_mode
-
         from vllm.platforms import current_platform

         if (self.enable_sleep_mode
@@ -458,9 +481,12 @@
             raise ValueError(
                 "Sleep mode is not supported on current platform.")

+        if isinstance(self.config_format, str):
+            self.config_format = ConfigFormat(self.config_format)
+
         hf_config = get_config(self.hf_config_path or self.model,
-                               trust_remote_code, revision, code_revision,
-                               config_format)
+                               self.trust_remote_code, self.revision,
+                               self.code_revision, self.config_format)

         if hf_overrides_kw:
             logger.info("Overriding HF config with %s", hf_overrides_kw)
@@ -476,13 +502,8 @@ class ModelConfig:
             "attention_chunk_size", None)
         self.encoder_config = self._get_encoder_config()
         self.hf_image_processor_config = get_hf_image_processor_config(
-            self.model, hf_token=hf_token, revision=revision)
-        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
-        self.use_async_output_proc = use_async_output_proc
-
-        # Set enforce_eager to False if the value is unset.
-        if self.enforce_eager is None:
-            self.enforce_eager = False
+            self.model, hf_token=self.hf_token, revision=self.revision)
+        self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)

         interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
         sliding_window = getattr(self.hf_text_config, "sliding_window", None)
@@ -515,18 +536,14 @@

         self.max_model_len = _get_and_verify_max_len(
             hf_config=self.hf_text_config,
-            max_model_len=max_model_len,
+            max_model_len=self.max_model_len,
             disable_sliding_window=self.disable_sliding_window,
             sliding_window_len=self.get_hf_config_sliding_window(),
-            spec_target_max_model_len=spec_target_max_model_len,
+            spec_target_max_model_len=self.spec_target_max_model_len,
             encoder_config=self.encoder_config)
-        self.served_model_name = get_served_model_name(model,
-                                                       served_model_name)
-        self.multimodal_config = self._init_multimodal_config(
-            limit_mm_per_prompt=limit_mm_per_prompt,
-            mm_processor_kwargs=mm_processor_kwargs,
-            disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
-        )
+        self.served_model_name = get_served_model_name(self.model,
+                                                       self.served_model_name)
+        self.multimodal_config = self._init_multimodal_config()
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()

@@ -535,24 +552,19 @@
         self.has_noops = self._init_has_noops()
         self.has_inner_state = self._init_has_inner_state()

-        if current_platform.is_neuron():
-            self.override_neuron_config = override_neuron_config
-        else:
-            self.override_neuron_config = None
+        if (not current_platform.is_neuron() and self.override_neuron_config):
+            raise ValueError(
+                "`override_neuron_config` is only supported on Neuron.")

-        supported_tasks, task = self._resolve_task(task)
+        supported_tasks, task = self._resolve_task(self.task)
         self.supported_tasks = supported_tasks
-        self.task: Final = task
+        self.task = task
         if self.task in ("draft", "generate"):
             self.truncation_side = "left"
         else:
             self.truncation_side = "right"

-        self.pooler_config = self._init_pooler_config(override_pooler_config)
-        self.logits_processor_pattern = logits_processor_pattern
-
-        self.generation_config = generation_config
-        self.override_generation_config = override_generation_config or {}
+        self.pooler_config = self._init_pooler_config()

         self._verify_quantization()
         self._verify_cuda_graph()
@@ -591,26 +603,21 @@
             model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
         self.tokenizer = s3_tokenizer.dir

-    def _init_multimodal_config(
-        self,
-        limit_mm_per_prompt: Optional[dict[str, int]],
-        mm_processor_kwargs: Optional[dict[str, Any]],
-        disable_mm_preprocessor_cache: bool,
-    ) -> Optional["MultiModalConfig"]:
+    def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
         if self.registry.is_multimodal_model(self.architectures):
             return MultiModalConfig(
-                limit_per_prompt=limit_mm_per_prompt or {},
-                mm_processor_kwargs=mm_processor_kwargs or {},
-                disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
-            )
+                limit_per_prompt=self.limit_mm_per_prompt,
+                mm_processor_kwargs=self.mm_processor_kwargs,
+                disable_mm_preprocessor_cache=self.
+                disable_mm_preprocessor_cache)

-        if limit_mm_per_prompt:
+        if self.limit_mm_per_prompt:
             raise ValueError("`limit_mm_per_prompt` is only supported for "
                              "multimodal models.")
-        if mm_processor_kwargs:
+        if self.mm_processor_kwargs:
             raise ValueError("`mm_processor_kwargs` is only supported for "
                              "multimodal models.")
-        if disable_mm_preprocessor_cache:
+        if self.disable_mm_preprocessor_cache:
             raise ValueError("`disable_mm_preprocessor_cache` is only "
                              "supported for multimodal models.")

@@ -620,31 +627,32 @@
         return get_sentence_transformer_tokenizer_config(
             self.model, self.revision)

-    def _init_pooler_config(
-        self,
-        override_pooler_config: Optional["PoolerConfig"],
-    ) -> Optional["PoolerConfig"]:
+    def _init_pooler_config(self) -> Optional["PoolerConfig"]:

         if self.runner_type == "pooling":
-            user_config = override_pooler_config or PoolerConfig()
+            if isinstance(self.override_pooler_config, dict):
+                self.override_pooler_config = PoolerConfig(
+                    **self.override_pooler_config)
+
+            pooler_config = self.override_pooler_config or PoolerConfig()

             base_config = get_pooling_config(self.model, self.revision)
             if base_config is not None:
                 # Only set values that are not overridden by the user
                 for k, v in base_config.items():
-                    if getattr(user_config, k) is None:
-                        setattr(user_config, k, v)
+                    if getattr(pooler_config, k) is None:
+                        setattr(pooler_config, k, v)

             if self.is_matryoshka:
-                if user_config.normalize is None:
-                    user_config.normalize = True
-                elif not user_config.normalize:
+                if pooler_config.normalize is None:
+                    pooler_config.normalize = True
+                elif not pooler_config.normalize:
                     raise ValueError(
                         "`normalize` must be enabled (set to True) "
                         "for models that are compatible with "
                         "Matryoshka Representation.")

-            return user_config
+            return pooler_config

         return None

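`_init_pooler_config` now reads the override from the instance, coerces a dict into `PoolerConfig`, and only fills fields the user left unset from the model's base pooling config. A standalone sketch of that merge logic (illustrative classes, not vLLM's):

```python
# Sketch of the override-merging logic above: dict override -> config object,
# then unset (None) fields are filled from a base configuration.
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class DemoPoolerConfig:
    pooling_type: Optional[str] = None
    normalize: Optional[bool] = None


def init_pooler_config(override: Optional[Union[dict, DemoPoolerConfig]],
                       base: dict) -> DemoPoolerConfig:
    if isinstance(override, dict):
        override = DemoPoolerConfig(**override)
    pooler_config = override or DemoPoolerConfig()
    for k, v in base.items():
        if getattr(pooler_config, k) is None:  # the user's value wins
            setattr(pooler_config, k, v)
    return pooler_config


cfg = init_pooler_config({"pooling_type": "CLS"},
                         {"pooling_type": "MEAN", "normalize": True})
assert cfg.pooling_type == "CLS" and cfg.normalize is True
```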
@@ -662,11 +670,11 @@
         return self.registry.model_has_inner_state(self.architectures)

     def _verify_tokenizer_mode(self) -> None:
-        tokenizer_mode = self.tokenizer_mode.lower()
-        if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
+        tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())
+        if tokenizer_mode not in get_args(TokenizerMode):
             raise ValueError(
                 f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
-                "either 'auto', 'slow', 'mistral' or 'custom'.")
+                f"one of {get_args(TokenizerMode)}.")
         self.tokenizer_mode = tokenizer_mode

     def _get_preferred_task(
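The validation above leans on the `TokenizerMode` Literal alias so the allowed values live in one place. A minimal sketch of the same idiom:

```python
# Sketch: a Literal alias as the single source of truth, with get_args()
# providing the runtime choice list for validation and error messages.
from typing import Literal, get_args

DemoTokenizerMode = Literal["auto", "slow", "mistral", "custom"]


def verify_tokenizer_mode(value: str) -> DemoTokenizerMode:
    mode = value.lower()
    if mode not in get_args(DemoTokenizerMode):
        raise ValueError(f"Unknown tokenizer mode: {value}. Must be one of "
                         f"{get_args(DemoTokenizerMode)}.")
    return mode  # type: ignore[return-value]


assert verify_tokenizer_mode("AUTO") == "auto"
```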
@@ -781,7 +789,8 @@
             "quark", "nvfp4", "bitblas", "gptq_bitblas"
         ]
         if self.quantization is not None:
-            self.quantization = self.quantization.lower()
+            self.quantization = cast(QuantizationMethods,
+                                     self.quantization.lower())

         # Parse quantization method from the HF model config, if available.
         quant_cfg = self._parse_quant_hf_config()
@@ -857,8 +866,6 @@
                         "non-quantized models.", self.quantization)

     def _verify_cuda_graph(self) -> None:
-        if self.max_seq_len_to_capture is None:
-            self.max_seq_len_to_capture = self.max_model_len
         self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
                                           self.max_model_len)
         ROCM_UNSUPPORTED_MODELS = ['mllama']
@@ -1294,7 +1301,7 @@

     @property
     def runner_type(self) -> RunnerType:
-        return _TASK_RUNNER[self.task]
+        return _TASK_RUNNER[cast(_ResolvedTask, self.task)]

     @property
     def is_v1_compatible(self) -> bool:
@@ -2201,7 +2208,7 @@ class SpeculativeConfig:
     according to the log probability settings in SamplingParams."""

     # Draft model configuration
-    quantization: Optional[str] = None
+    quantization: Optional[QuantizationMethods] = None
     """Quantization method that was used to quantize the draft model weights.
     If `None`, we assume the model weights are not quantized. Note that it only
     takes effect when using the draft model-based speculative method."""
@@ -2386,7 +2393,6 @@ class SpeculativeConfig:
                     code_revision=self.code_revision,
                     tokenizer_revision=self.target_model_config.
                     tokenizer_revision,
-                    max_model_len=None,
                     spec_target_max_model_len=self.target_model_config.
                     max_model_len,
                     quantization=self.quantization,
@@ -2793,30 +2799,31 @@ class PromptAdapterConfig:
 class MultiModalConfig:
     """Controls the behavior of multimodal models."""

-    limit_per_prompt: dict[str, int] = field(default_factory=dict)
+    limit_per_prompt: dict[str, int] = get_field(ModelConfig,
+                                                 "limit_mm_per_prompt")
     """
     The maximum number of input items allowed per prompt for each modality.
     This should be a JSON string that will be parsed into a dictionary.
     Defaults to 1 (V0) or 999 (V1) for each modality.

     For example, to allow up to 16 images and 2 videos per prompt:
-    :code:`{"images": 16, "videos": 2}`
+    `{"images": 16, "videos": 2}`
     """

     mm_processor_kwargs: Optional[dict[str, object]] = None
     """
     Overrides for the multi-modal processor obtained from
-    :meth:`transformers.AutoProcessor.from_pretrained`.
+    `transformers.AutoProcessor.from_pretrained`.

     The available overrides depend on the model that is being run.

     For example, for Phi-3-Vision:
-    :code:`{"num_crops": 4}`.
+    `{"num_crops": 4}`.
     """

     disable_mm_preprocessor_cache: bool = False
     """
-    If :code:`True`, disable caching of the processed multi-modal inputs.
+    If `True`, disable caching of the processed multi-modal inputs.
     """

     def compute_hash(self) -> str:
@@ -2907,10 +2914,6 @@ class PoolerConfig:
                             usedforsecurity=False).hexdigest()
         return hash_str

-    @staticmethod
-    def from_json(json_str: str) -> "PoolerConfig":
-        return PoolerConfig(**json.loads(json_str))
-

 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
@@ -20,15 +20,16 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          DeviceConfig, DistributedExecutorBackend,
                          GuidedDecodingBackend, GuidedDecodingBackendV1,
                          HfOverrides, KVTransferConfig, LoadConfig, LoadFormat,
-                         LoRAConfig, ModelConfig, ModelImpl, MultiModalConfig,
-                         ObservabilityConfig, ParallelConfig, PoolerConfig,
-                         PrefixCachingHashAlgo, PromptAdapterConfig,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerPoolConfig, VllmConfig,
-                         get_attr_docs, get_field)
+                         LoRAConfig, ModelConfig, ModelDType, ModelImpl,
+                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
+                         PoolerConfig, PrefixCachingHashAlgo,
+                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
+                         SpeculativeConfig, TaskOption, TokenizerMode,
+                         TokenizerPoolConfig, VllmConfig, get_attr_docs,
+                         get_field)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.plugins import load_general_plugins
 from vllm.reasoning import ReasoningParserManager
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
@@ -183,6 +184,9 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
             kwargs[name]["nargs"] = "+"
         elif contains_type(type_hints, int):
             kwargs[name]["type"] = int
+            # Special case for large integers
+            if name in {"max_model_len"}:
+                kwargs[name]["type"] = human_readable_int
         elif contains_type(type_hints, float):
             kwargs[name]["type"] = float
         elif contains_type(type_hints, dict):
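The docstring for `max_model_len` and the `human_readable_int` special case above imply suffix handling where lowercase is decimal and uppercase is binary (1k -> 1000, 1K -> 1024, 25.6k -> 25600). A rough re-implementation of that conversion, offered only as an illustration of the documented behaviour, not as vLLM's actual helper:

```python
# Sketch of a human-readable integer parser matching the documented examples.
def parse_human_readable_int(value: str) -> int:
    decimal = {"k": 10**3, "m": 10**6, "g": 10**9}
    binary = {"K": 2**10, "M": 2**20, "G": 2**30}
    suffix = value[-1]
    if suffix in decimal:
        return int(float(value[:-1]) * decimal[suffix])
    if suffix in binary:
        return int(float(value[:-1]) * binary[suffix])
    return int(value)


assert parse_human_readable_int("1k") == 1000
assert parse_human_readable_int("1K") == 1024
assert parse_human_readable_int("25.6k") == 25600
```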
@@ -212,22 +216,23 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
 @dataclass
 class EngineArgs:
     """Arguments for vLLM engine."""
-    model: str = 'facebook/opt-125m'
-    served_model_name: Optional[Union[str, List[str]]] = None
-    tokenizer: Optional[str] = None
-    hf_config_path: Optional[str] = None
-    task: TaskOption = "auto"
-    skip_tokenizer_init: bool = False
-    tokenizer_mode: str = 'auto'
-    trust_remote_code: bool = False
-    allowed_local_media_path: str = ""
+    model: str = ModelConfig.model
+    served_model_name: Optional[Union[
+        str, List[str]]] = ModelConfig.served_model_name
+    tokenizer: Optional[str] = ModelConfig.tokenizer
+    hf_config_path: Optional[str] = ModelConfig.hf_config_path
+    task: TaskOption = ModelConfig.task
+    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
+    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
+    trust_remote_code: bool = ModelConfig.trust_remote_code
+    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
     download_dir: Optional[str] = LoadConfig.download_dir
     load_format: str = LoadConfig.load_format
-    config_format: ConfigFormat = ConfigFormat.AUTO
-    dtype: str = 'auto'
+    config_format: str = ModelConfig.config_format
+    dtype: ModelDType = ModelConfig.dtype
     kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
-    seed: Optional[int] = None
-    max_model_len: Optional[int] = None
+    seed: Optional[int] = ModelConfig.seed
+    max_model_len: Optional[int] = ModelConfig.max_model_len
     # Note: Specifying a custom executor backend by passing a class
     # is intended for expert use only. The API may change without
     # notice.
@@ -245,8 +250,8 @@
     enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
     prefix_caching_hash_algo: PrefixCachingHashAlgo = \
         CacheConfig.prefix_caching_hash_algo
-    disable_sliding_window: bool = False
-    disable_cascade_attn: bool = False
+    disable_sliding_window: bool = ModelConfig.disable_sliding_window
+    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
     use_v2_block_manager: bool = True
     swap_space: float = CacheConfig.swap_space
     cpu_offload_gb: float = CacheConfig.cpu_offload_gb
@@ -258,18 +263,19 @@
     long_prefill_token_threshold: int = \
         SchedulerConfig.long_prefill_token_threshold
     max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
-    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
+    max_logprobs: int = ModelConfig.max_logprobs
     disable_log_stats: bool = False
-    revision: Optional[str] = None
-    code_revision: Optional[str] = None
-    rope_scaling: Optional[Dict[str, Any]] = None
-    rope_theta: Optional[float] = None
-    hf_token: Optional[Union[bool, str]] = None
-    hf_overrides: Optional[HfOverrides] = None
-    tokenizer_revision: Optional[str] = None
-    quantization: Optional[str] = None
-    enforce_eager: Optional[bool] = None
-    max_seq_len_to_capture: int = 8192
+    revision: Optional[str] = ModelConfig.revision
+    code_revision: Optional[str] = ModelConfig.code_revision
+    rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
+    rope_theta: Optional[float] = ModelConfig.rope_theta
+    hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
+    hf_overrides: Optional[HfOverrides] = \
+        get_field(ModelConfig, "hf_overrides")
+    tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
+    quantization: Optional[QuantizationMethods] = ModelConfig.quantization
+    enforce_eager: bool = ModelConfig.enforce_eager
+    max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
     disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
     # The following three fields are deprecated and will be removed in a future
     # release. Setting them will have no effect. Please remove them from your
@@ -280,8 +286,10 @@
         get_field(TokenizerPoolConfig, "extra_config")
     limit_mm_per_prompt: dict[str, int] = \
         get_field(MultiModalConfig, "limit_per_prompt")
-    mm_processor_kwargs: Optional[Dict[str, Any]] = None
-    disable_mm_preprocessor_cache: bool = False
+    mm_processor_kwargs: Optional[Dict[str, Any]] = \
+        MultiModalConfig.mm_processor_kwargs
+    disable_mm_preprocessor_cache: bool = \
+        MultiModalConfig.disable_mm_preprocessor_cache
     # LoRA fields
     enable_lora: bool = False
     enable_lora_bias: bool = LoRAConfig.bias_enabled
@@ -323,7 +331,8 @@
         DecodingConfig.disable_any_whitespace
     guided_decoding_disable_additional_properties: bool = \
         DecodingConfig.disable_additional_properties
-    logits_processor_pattern: Optional[str] = None
+    logits_processor_pattern: Optional[
+        str] = ModelConfig.logits_processor_pattern

     speculative_config: Optional[Dict[str, Any]] = None
@@ -331,22 +340,25 @@ class EngineArgs:
     show_hidden_metrics_for_version: Optional[str] = None
     otlp_traces_endpoint: Optional[str] = None
     collect_detailed_traces: Optional[str] = None
-    disable_async_output_proc: bool = False
+    disable_async_output_proc: bool = not ModelConfig.use_async_output_proc
     scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
     scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls

-    override_neuron_config: Optional[Dict[str, Any]] = None
-    override_pooler_config: Optional[PoolerConfig] = None
+    override_neuron_config: dict[str, Any] = \
+        get_field(ModelConfig, "override_neuron_config")
+    override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
+        ModelConfig.override_pooler_config
     compilation_config: Optional[CompilationConfig] = None
     worker_cls: str = ParallelConfig.worker_cls
     worker_extension_cls: str = ParallelConfig.worker_extension_cls

     kv_transfer_config: Optional[KVTransferConfig] = None

-    generation_config: Optional[str] = "auto"
-    override_generation_config: Optional[Dict[str, Any]] = None
-    enable_sleep_mode: bool = False
-    model_impl: str = "auto"
+    generation_config: str = ModelConfig.generation_config
+    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
+    override_generation_config: dict[str, Any] = \
+        get_field(ModelConfig, "override_generation_config")
+    model_impl: str = ModelConfig.model_impl

     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales

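The `EngineArgs` defaults above now point back at `ModelConfig`, with `get_field` used wherever the source field has a `default_factory` (and therefore no plain class attribute to read). A compact sketch of why both forms are needed, using illustrative classes rather than the real ones:

```python
# Sketch: defaults sourced from another dataclass. Simple defaults are class
# attributes; default_factory fields must be copied as Field objects.
from dataclasses import MISSING, dataclass, field, fields


@dataclass
class SourceConfig:
    model: str = "facebook/opt-125m"
    rope_scaling: dict = field(default_factory=dict)


def get_field(cls, name: str):
    for f in fields(cls):
        if f.name != name:
            continue
        if f.default_factory is not MISSING:
            return field(default_factory=f.default_factory)
        if f.default is not MISSING:
            return field(default=f.default)
        raise ValueError(f"{cls.__name__}.{name} has no default.")
    raise ValueError(f"{cls.__name__} has no field {name!r}.")


@dataclass
class DemoArgs:
    model: str = SourceConfig.model  # plain default, read off the class
    rope_scaling: dict = get_field(SourceConfig, "rope_scaling")  # factory


assert DemoArgs().model == SourceConfig().model
assert DemoArgs().rope_scaling == {}
```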
@@ -356,9 +368,6 @@
     use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load

     def __post_init__(self):
-        if not self.tokenizer:
-            self.tokenizer = self.model
-
         # support `EngineArgs(compilation_config={...})`
         # without having to manually construct a
         # CompilationConfig object
@@ -375,80 +384,87 @@
         """Shared CLI arguments for vLLM engine."""

         # Model arguments
-        parser.add_argument(
-            '--model',
-            type=str,
-            default=EngineArgs.model,
-            help='Name or path of the huggingface model to use.')
-        parser.add_argument(
-            '--task',
-            default=EngineArgs.task,
-            choices=get_args(TaskOption),
-            help='The task to use the model for. Each vLLM instance only '
-            'supports one task, even if the same model can be used for '
-            'multiple tasks. When the model only supports one task, ``"auto"`` '
-            'can be used to select it; otherwise, you must specify explicitly '
-            'which task to use.')
-        parser.add_argument(
-            '--tokenizer',
-            type=optional_type(str),
-            default=EngineArgs.tokenizer,
-            help='Name or path of the huggingface tokenizer to use. '
-            'If unspecified, model name or path will be used.')
-        parser.add_argument(
-            "--hf-config-path",
-            type=optional_type(str),
-            default=EngineArgs.hf_config_path,
-            help='Name or path of the huggingface config to use. '
-            'If unspecified, model name or path will be used.')
-        parser.add_argument(
-            '--skip-tokenizer-init',
-            action='store_true',
-            help='Skip initialization of tokenizer and detokenizer. '
-            'Expects valid prompt_token_ids and None for prompt from '
-            'the input. The generated output will contain token ids.')
-        parser.add_argument(
-            '--revision',
-            type=optional_type(str),
-            default=None,
-            help='The specific model version to use. It can be a branch '
-            'name, a tag name, or a commit id. If unspecified, will use '
-            'the default version.')
-        parser.add_argument(
-            '--code-revision',
-            type=optional_type(str),
-            default=None,
-            help='The specific revision to use for the model code on '
-            'Hugging Face Hub. It can be a branch name, a tag name, or a '
-            'commit id. If unspecified, will use the default version.')
-        parser.add_argument(
-            '--tokenizer-revision',
-            type=optional_type(str),
-            default=None,
-            help='Revision of the huggingface tokenizer to use. '
-            'It can be a branch name, a tag name, or a commit id. '
-            'If unspecified, will use the default version.')
-        parser.add_argument(
-            '--tokenizer-mode',
-            type=str,
-            default=EngineArgs.tokenizer_mode,
-            choices=['auto', 'slow', 'mistral', 'custom'],
-            help='The tokenizer mode.\n\n* "auto" will use the '
-            'fast tokenizer if available.\n* "slow" will '
-            'always use the slow tokenizer. \n* '
-            '"mistral" will always use the `mistral_common` tokenizer. \n* '
-            '"custom" will use --tokenizer to select the '
-            'preregistered tokenizer.')
-        parser.add_argument('--trust-remote-code',
-                            action='store_true',
-                            help='Trust remote code from huggingface.')
-        parser.add_argument(
-            '--allowed-local-media-path',
-            type=str,
-            help="Allowing API requests to read local images or videos "
-            "from directories specified by the server file system. "
-            "This is a security risk. "
-            "Should only be enabled in trusted environments.")
+        model_kwargs = get_kwargs(ModelConfig)
+        model_group = parser.add_argument_group(
+            title="ModelConfig",
+            description=ModelConfig.__doc__,
+        )
+        model_group.add_argument("--model", **model_kwargs["model"])
+        model_group.add_argument("--task", **model_kwargs["task"])
+        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
+        model_group.add_argument("--tokenizer-mode",
+                                 **model_kwargs["tokenizer_mode"])
+        model_group.add_argument("--trust-remote-code",
+                                 **model_kwargs["trust_remote_code"])
+        model_group.add_argument("--dtype", **model_kwargs["dtype"])
+        model_group.add_argument("--seed", **model_kwargs["seed"])
+        model_group.add_argument("--hf-config-path",
+                                 **model_kwargs["hf_config_path"])
+        model_group.add_argument("--allowed-local-media-path",
+                                 **model_kwargs["allowed_local_media_path"])
+        model_group.add_argument("--revision", **model_kwargs["revision"])
+        model_group.add_argument("--code-revision",
+                                 **model_kwargs["code_revision"])
+        model_group.add_argument("--rope-scaling",
+                                 **model_kwargs["rope_scaling"])
+        model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
+        model_group.add_argument("--tokenizer-revision",
+                                 **model_kwargs["tokenizer_revision"])
+        model_group.add_argument("--max-model-len",
+                                 **model_kwargs["max_model_len"])
+        model_group.add_argument("--quantization", "-q",
+                                 **model_kwargs["quantization"])
+        model_group.add_argument("--enforce-eager",
+                                 **model_kwargs["enforce_eager"])
+        model_group.add_argument("--max-seq-len-to-capture",
+                                 **model_kwargs["max_seq_len_to_capture"])
+        model_group.add_argument("--max-logprobs",
+                                 **model_kwargs["max_logprobs"])
+        model_group.add_argument("--disable-sliding-window",
+                                 **model_kwargs["disable_sliding_window"])
+        model_group.add_argument("--disable-cascade-attn",
+                                 **model_kwargs["disable_cascade_attn"])
+        model_group.add_argument("--skip-tokenizer-init",
+                                 **model_kwargs["skip_tokenizer_init"])
+        model_group.add_argument("--served-model-name",
+                                 **model_kwargs["served_model_name"])
+        # This one is a special case because it is the
+        # opposite of ModelConfig.use_async_output_proc
+        model_group.add_argument(
+            "--disable-async-output-proc",
+            action="store_true",
+            default=EngineArgs.disable_async_output_proc,
+            help="Disable async output processing. This may result in "
+            "lower performance.")
+        model_group.add_argument("--config-format",
+                                 choices=[f.value for f in ConfigFormat],
+                                 **model_kwargs["config_format"])
+        # This one is a special case because it can bool
+        # or str. TODO: Handle this in get_kwargs
+        model_group.add_argument("--hf-token",
+                                 type=str,
+                                 nargs="?",
+                                 const=True,
+                                 default=model_kwargs["hf_token"]["default"],
+                                 help=model_kwargs["hf_token"]["help"])
+        model_group.add_argument("--hf-overrides",
+                                 **model_kwargs["hf_overrides"])
+        model_group.add_argument("--override-neuron-config",
+                                 **model_kwargs["override_neuron_config"])
+        model_group.add_argument("--override-pooler-config",
+                                 **model_kwargs["override_pooler_config"])
+        model_group.add_argument("--logits-processor-pattern",
+                                 **model_kwargs["logits_processor_pattern"])
+        model_group.add_argument("--generation-config",
+                                 **model_kwargs["generation_config"])
+        model_group.add_argument("--override-generation-config",
+                                 **model_kwargs["override_generation_config"])
+        model_group.add_argument("--enable-sleep-mode",
+                                 **model_kwargs["enable_sleep_mode"])
+        model_group.add_argument("--model-impl",
+                                 choices=[f.value for f in ModelImpl],
+                                 **model_kwargs["model_impl"])

         # Model loading arguments
         load_kwargs = get_kwargs(LoadConfig)
         load_group = parser.add_argument_group(
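The hand-written `parser.add_argument` blocks are replaced above by argument groups whose kwargs come from `get_kwargs(ModelConfig)`. A heavily simplified sketch of that idea follows; it is an assumption-laden toy that ignores the Literal, JSON, and bool-or-str special cases the real helper handles, and uses made-up class names:

```python
# Sketch: derive each --flag's default and help text from a config dataclass
# instead of spelling them out by hand.
import argparse
from dataclasses import dataclass, fields


@dataclass
class DemoModelConfig:
    """Configuration for the model."""
    model: str = "facebook/opt-125m"
    max_logprobs: int = 20


def get_kwargs(cls) -> dict:
    return {
        f.name: {"type": f.type if callable(f.type) else str,
                 "default": f.default,
                 "help": f"{cls.__name__}.{f.name} (default: {f.default})"}
        for f in fields(cls)
    }


parser = argparse.ArgumentParser()
group = parser.add_argument_group(title="DemoModelConfig",
                                  description=DemoModelConfig.__doc__)
kwargs = get_kwargs(DemoModelConfig)
group.add_argument("--model", **kwargs["model"])
group.add_argument("--max-logprobs", **kwargs["max_logprobs"])

args = parser.parse_args(["--max-logprobs", "5"])
assert args.model == "facebook/opt-125m" and args.max_logprobs == 5
```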
@@ -465,38 +481,6 @@
         load_group.add_argument('--use-tqdm-on-load',
                                 **load_kwargs["use_tqdm_on_load"])

-        parser.add_argument(
-            '--config-format',
-            default=EngineArgs.config_format,
-            choices=[f.value for f in ConfigFormat],
-            help='The format of the model config to load.\n\n'
-            '* "auto" will try to load the config in hf format '
-            'if available else it will try to load in mistral format ')
-        parser.add_argument(
-            '--dtype',
-            type=str,
-            default=EngineArgs.dtype,
-            choices=[
-                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
-            ],
-            help='Data type for model weights and activations.\n\n'
-            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
-            'BF16 precision for BF16 models.\n'
-            '* "half" for FP16. Recommended for AWQ quantization.\n'
-            '* "float16" is the same as "half".\n'
-            '* "bfloat16" for a balance between precision and range.\n'
-            '* "float" is shorthand for FP32 precision.\n'
-            '* "float32" for FP32 precision.')
-        parser.add_argument('--max-model-len',
-                            type=human_readable_int,
-                            default=EngineArgs.max_model_len,
-                            help='Model context length. If unspecified, will '
-                            'be automatically derived from the model config. '
-                            'Supports k/m/g/K/M/G in human-readable format.\n'
-                            'Examples:\n'
-                            '- 1k → 1000\n'
-                            '- 1K → 1024\n')
-
         # Guided decoding arguments
         guided_decoding_kwargs = get_kwargs(DecodingConfig)
         guided_decoding_group = parser.add_argument_group(
@@ -520,26 +504,6 @@
             choices=list(ReasoningParserManager.reasoning_parsers),
             **guided_decoding_kwargs["reasoning_backend"])

-        parser.add_argument(
-            '--logits-processor-pattern',
-            type=optional_type(str),
-            default=None,
-            help='Optional regex pattern specifying valid logits processor '
-            'qualified names that can be passed with the `logits_processors` '
-            'extra completion argument. Defaults to None, which allows no '
-            'processors.')
-        parser.add_argument(
-            '--model-impl',
-            type=str,
-            default=EngineArgs.model_impl,
-            choices=[f.value for f in ModelImpl],
-            help='Which implementation of the model to use.\n\n'
-            '* "auto" will try to use the vLLM implementation if it exists '
-            'and fall back to the Transformers implementation if no vLLM '
-            'implementation is available.\n'
-            '* "vllm" will use the vLLM model implementation.\n'
-            '* "transformers" will use the Transformers model '
-            'implementation.\n')
         # Parallel arguments
         parallel_kwargs = get_kwargs(ParallelConfig)
         parallel_group = parser.add_argument_group(
@ -592,10 +556,6 @@ class EngineArgs:
|
||||
cache_group.add_argument('--calculate-kv-scales',
|
||||
**cache_kwargs["calculate_kv_scales"])
|
||||
|
||||
parser.add_argument('--disable-sliding-window',
|
||||
action='store_true',
|
||||
help='Disables sliding window, '
|
||||
'capping to sliding window size.')
|
||||
parser.add_argument('--use-v2-block-manager',
|
||||
action='store_true',
|
||||
default=True,
|
||||
@ -605,73 +565,9 @@ class EngineArgs:
|
||||
'Setting this flag to True or False'
|
||||
' has no effect on vLLM behavior.')
|
||||
|
||||
parser.add_argument('--seed',
|
||||
type=int,
|
||||
default=EngineArgs.seed,
|
||||
help='Random seed for operations.')
|
||||
parser.add_argument(
|
||||
'--max-logprobs',
|
||||
type=int,
|
||||
default=EngineArgs.max_logprobs,
|
||||
help=('Max number of log probs to return logprobs is specified in'
|
||||
' SamplingParams.'))
|
||||
parser.add_argument('--disable-log-stats',
|
||||
action='store_true',
|
||||
help='Disable logging statistics.')
|
||||
# Quantization settings.
|
||||
parser.add_argument('--quantization',
|
||||
'-q',
|
||||
type=optional_type(str),
|
||||
choices=[*QUANTIZATION_METHODS, None],
|
||||
default=EngineArgs.quantization,
|
||||
help='Method used to quantize the weights. If '
|
||||
'None, we first check the `quantization_config` '
|
||||
'attribute in the model config file. If that is '
|
||||
'None, we assume the model weights are not '
|
||||
'quantized and use `dtype` to determine the data '
|
||||
'type of the weights.')
|
||||
parser.add_argument(
|
||||
'--rope-scaling',
|
||||
default=None,
|
||||
type=json.loads,
|
||||
help='RoPE scaling configuration in JSON format. '
|
||||
'For example, ``{"rope_type":"dynamic","factor":2.0}``')
|
||||
parser.add_argument('--rope-theta',
|
||||
default=None,
|
||||
type=float,
|
||||
help='RoPE theta. Use with `rope_scaling`. In '
|
||||
'some cases, changing the RoPE theta improves the '
|
||||
'performance of the scaled model.')
|
||||
parser.add_argument(
|
||||
'--hf-token',
|
||||
type=str,
|
||||
nargs='?',
|
||||
const=True,
|
||||
default=None,
|
||||
help='The token to use as HTTP bearer authorization'
|
||||
' for remote files. If `True`, will use the token '
|
||||
'generated when running `huggingface-cli login` '
|
||||
'(stored in `~/.huggingface`).')
|
||||
parser.add_argument('--hf-overrides',
|
||||
type=json.loads,
|
||||
default=EngineArgs.hf_overrides,
|
||||
help='Extra arguments for the HuggingFace config. '
|
||||
'This should be a JSON string that will be '
|
||||
'parsed into a dictionary.')
|
||||
parser.add_argument('--enforce-eager',
|
||||
action='store_true',
|
||||
help='Always use eager-mode PyTorch. If False, '
|
||||
'will use eager mode and CUDA graph in hybrid '
|
||||
'for maximal performance and flexibility.')
|
||||
parser.add_argument('--max-seq-len-to-capture',
|
||||
type=int,
|
||||
default=EngineArgs.max_seq_len_to_capture,
|
||||
help='Maximum sequence length covered by CUDA '
|
||||
'graphs. When a sequence has context length '
|
||||
'larger than this, we fall back to eager mode. '
|
||||
'Additionally for encoder-decoder models, if the '
|
||||
'sequence length of the encoder input is larger '
|
||||
'than this, we fall back to the eager mode.')
|
||||
|
||||
# Tokenizer arguments
|
||||
tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
|
||||
@ -775,20 +671,6 @@ class EngineArgs:
|
||||
"Default to `original/**/*` to avoid repeated loading of llama's "
|
||||
"checkpoints.")
|
||||
|
||||
parser.add_argument(
|
||||
"--served-model-name",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The model name(s) used in the API. If multiple "
|
||||
"names are provided, the server will respond to any "
|
||||
"of the provided names. The model name in the model "
|
||||
"field of a response will be the first name in this "
|
||||
"list. If not specified, the model name will be the "
|
||||
"same as the ``--model`` argument. Noted that this name(s) "
|
||||
"will also be used in `model_name` tag content of "
|
||||
"prometheus metrics, if multiple names provided, metrics "
|
||||
"tag will take the first one.")
|
||||
parser.add_argument('--qlora-adapter-name-or-path',
|
||||
type=str,
|
||||
default=None,
|
||||
@ -822,13 +704,6 @@ class EngineArgs:
|
||||
"modules. This involves use of possibly costly and or blocking "
|
||||
"operations and hence might have a performance impact.")
|
||||
|
||||
parser.add_argument(
|
||||
'--disable-async-output-proc',
|
||||
action='store_true',
|
||||
default=EngineArgs.disable_async_output_proc,
|
||||
help="Disable async output processing. This may result in "
|
||||
"lower performance.")
|
||||
|
||||
# Scheduler arguments
|
||||
scheduler_kwargs = get_kwargs(SchedulerConfig)
|
||||
scheduler_group = parser.add_argument_group(
|
||||
@ -871,19 +746,6 @@ class EngineArgs:
|
||||
parser.add_argument('--scheduler-cls',
|
||||
**scheduler_kwargs["scheduler_cls"])
|
||||
|
||||
parser.add_argument(
|
||||
'--override-neuron-config',
|
||||
type=json.loads,
|
||||
default=None,
|
||||
help="Override or set neuron device configuration. "
|
||||
"e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
|
||||
parser.add_argument(
|
||||
'--override-pooler-config',
|
||||
type=PoolerConfig.from_json,
|
||||
default=None,
|
||||
help="Override or set the pooling method for pooling models. "
|
||||
"e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
|
||||
|
||||
parser.add_argument('--compilation-config',
|
||||
'-O',
|
||||
type=CompilationConfig.from_cli,
|
||||
@ -920,34 +782,6 @@ class EngineArgs:
|
||||
help='The worker extension class on top of the worker cls, '
|
||||
'it is useful if you just want to add new functions to the worker '
|
||||
'class without changing the existing functions.')
|
||||
parser.add_argument(
|
||||
"--generation-config",
|
||||
type=optional_type(str),
|
||||
default="auto",
|
||||
help="The folder path to the generation config. "
|
||||
"Defaults to 'auto', the generation config will be loaded from "
|
||||
"model path. If set to 'vllm', no generation config is loaded, "
|
||||
"vLLM defaults will be used. If set to a folder path, the "
|
||||
"generation config will be loaded from the specified folder path. "
|
||||
"If `max_new_tokens` is specified in generation config, then "
|
||||
"it sets a server-wide limit on the number of output tokens "
|
||||
"for all requests.")
|
||||
|
||||
parser.add_argument(
|
||||
"--override-generation-config",
|
||||
type=json.loads,
|
||||
default=None,
|
||||
help="Overrides or sets generation config in JSON format. "
|
||||
"e.g. ``{\"temperature\": 0.5}``. If used with "
|
||||
"--generation-config=auto, the override parameters will be merged "
|
||||
"with the default config from the model. If generation-config is "
|
||||
"None, only the override parameters are used.")
|
||||
|
||||
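The `--override-generation-config` help above describes merge semantics: with `--generation-config=auto`, the override values win over the model's own generation defaults, and untouched keys are kept. A small sketch of that merge; the default values below are purely illustrative, not taken from any real model.

import json

# Hypothetical defaults, as they might come from a model's generation config.
model_generation_defaults = {"temperature": 1.0, "top_p": 0.9}
cli_override = json.loads('{"temperature": 0.5}')

# Override keys replace defaults; keys absent from the override are preserved.
effective = {**model_generation_defaults, **cli_override}
assert effective == {"temperature": 0.5, "top_p": 0.9}
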
parser.add_argument("--enable-sleep-mode",
action="store_true",
default=False,
help="Enable sleep mode for the engine. "
"(only cuda platform is supported)")

parser.add_argument(
"--additional-config",
@ -966,16 +800,6 @@ class EngineArgs:
"If enabled, the model will be able to generate reasoning content."
)

parser.add_argument(
"--disable-cascade-attn",
action="store_true",
default=False,
help="Disable cascade attention for V1. While cascade attention "
"does not change the mathematical correctness, disabling it "
"could be useful for preventing potential numerical issues. "
"Note that even if this is set to False, cascade attention will be "
"only used when the heuristic tells that it's beneficial.")

return parser

@classmethod
@ -1002,8 +826,7 @@ class EngineArgs:
model=self.model,
hf_config_path=self.hf_config_path,
task=self.task,
# We know this is not None because we set it in __post_init__
tokenizer=cast(str, self.tokenizer),
tokenizer=self.tokenizer,
tokenizer_mode=self.tokenizer_mode,
trust_remote_code=self.trust_remote_code,
allowed_local_media_path=self.allowed_local_media_path,

@ -13,7 +13,7 @@ from typing_extensions import TypeVar, deprecated

from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
BeamSearchSequence, get_beam_search_score)
from vllm.config import CompilationConfig
from vllm.config import CompilationConfig, ModelDType, TokenizerMode
from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
TaskOption)
from vllm.engine.llm_engine import LLMEngine
@ -32,6 +32,7 @@ from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
GuidedDecodingRequest, LLMGuidedOptions)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
PoolingRequestOutput, RequestOutput,
ScoringRequestOutput)
@ -163,20 +164,20 @@ class LLM:
self,
model: str,
tokenizer: Optional[str] = None,
tokenizer_mode: str = "auto",
tokenizer_mode: TokenizerMode = "auto",
skip_tokenizer_init: bool = False,
trust_remote_code: bool = False,
allowed_local_media_path: str = "",
tensor_parallel_size: int = 1,
dtype: str = "auto",
quantization: Optional[str] = None,
dtype: ModelDType = "auto",
quantization: Optional[QuantizationMethods] = None,
revision: Optional[str] = None,
tokenizer_revision: Optional[str] = None,
seed: Optional[int] = None,
gpu_memory_utilization: float = 0.9,
swap_space: float = 4,
cpu_offload_gb: float = 0,
enforce_eager: Optional[bool] = None,
enforce_eager: bool = False,
max_seq_len_to_capture: int = 8192,
disable_custom_all_reduce: bool = False,
disable_async_output_proc: bool = False,
@ -189,12 +190,7 @@ class LLM:
compilation_config: Optional[Union[int, dict[str, Any]]] = None,
**kwargs,
) -> None:
'''
LLM constructor.

Note: if enforce_eager is unset (enforce_eager is None)
it defaults to False.
'''
"""LLM constructor."""

if "disable_log_stats" not in kwargs:
kwargs["disable_log_stats"] = True

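The constructor hunk above tightens several `LLM` parameters from bare `str` to typed aliases (`TokenizerMode`, `ModelDType`, `QuantizationMethods`) and changes `enforce_eager` from `Optional[bool] = None` to `bool = False`. A hedged usage sketch follows; the model name is only a placeholder, and the keyword values are ordinary string literals accepted by those aliases.

from vllm import LLM

llm = LLM(
    model="facebook/opt-125m",   # placeholder model
    tokenizer_mode="auto",       # TokenizerMode literal
    dtype="float16",             # ModelDType literal
    enforce_eager=False,         # explicit; also the new default, no longer None
)
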
@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
@ -186,7 +187,7 @@ class AQLMConfig(QuantizationConfig):
f"out_group_size={self.out_group_size})")

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "aqlm"

@classmethod

@ -7,6 +7,7 @@ import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@ -44,7 +45,7 @@ class AWQConfig(QuantizationConfig):
f"zero_point={self.zero_point}, "
f"modules_to_not_convert={self.modules_to_not_convert})")

def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "awq"

def get_supported_act_dtypes(self) -> List[torch.dtype]:

@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.awq import (AWQConfig,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
@ -73,7 +74,7 @@ class AWQMarlinConfig(QuantizationConfig):
f"modules_to_not_convert={self.modules_to_not_convert})")

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "awq_marlin"

@classmethod
@ -101,8 +102,8 @@ class AWQMarlinConfig(QuantizationConfig):
modules_to_not_convert, config)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin"
or user_quant == "awq_marlin")

@ -2,11 +2,16 @@

import inspect
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

import torch
from torch import nn

if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
else:
QuantizationMethods = str


class QuantizeMethodBase(ABC):
"""Base class for different quantized methods."""
@ -66,7 +71,7 @@ class QuantizationConfig(ABC):
self.packed_modules_mapping: Dict[str, List[str]] = dict()

@abstractmethod
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
"""Name of the quantization method."""
raise NotImplementedError

@ -99,8 +104,8 @@ class QuantizationConfig(ABC):
raise NotImplementedError

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
"""
Detects if this quantization method can support a given checkpoint
format by overriding the user specified quantization method --

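The `base_config.py` hunk above uses the standard `TYPE_CHECKING` guard so that `QuantizationMethods` is a real Literal alias for type checkers but degrades to plain `str` at runtime, avoiding a circular import with the quantization package. Below is a generic sketch of the same pattern with illustrative module and type names, not vLLM's code.

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy, pyright); never at runtime.
    from mypackage.registry import MethodName  # hypothetical module
else:
    # Runtime fallback keeps annotations resolvable without the import.
    MethodName = str


def get_name() -> MethodName:
    return "example_method"
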
@ -5,6 +5,7 @@ import torch

from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
@ -100,7 +101,7 @@ class BitBLASConfig(QuantizationConfig):
f"quant_method={self.quant_method})")

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "bitblas"

@classmethod
@ -139,8 +140,8 @@ class BitBLASConfig(QuantizationConfig):
lm_head_quantized)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_bitblas_format: bool
is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"

@ -7,6 +7,7 @@ import torch
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.utils import direct_register_custom_op
@ -56,7 +57,7 @@ class BitsAndBytesConfig(QuantizationConfig):
f"llm_int8_skip_modules={self.llm_int8_skip_modules})")

@classmethod
def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "bitsandbytes"

@classmethod

@ -16,6 +16,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501
@ -71,7 +72,7 @@ class CompressedTensorsConfig(QuantizationConfig):
def get_min_capability(cls) -> int:
return 70

def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "compressed-tensors"

def get_quant_method(

@ -7,6 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
@ -41,8 +42,8 @@ class DeepSpeedFPConfig(QuantizationConfig):
f"group_size={self.group_size}")

@classmethod
def get_name(cls) -> str:
return "DeepSpeedFP"
def get_name(cls) -> QuantizationMethods:
return "deepspeedfp"

@classmethod
def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":

@ -8,6 +8,7 @@ from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
@ -20,7 +21,7 @@ class ExpertsInt8Config(QuantizationConfig):
super().__init__()

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "experts_int8"

@classmethod

@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@ -38,7 +39,7 @@ class FBGEMMFp8Config(QuantizationConfig):
self.fp8_linear = Fp8LinearOp()

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "fbgemm_fp8"

@classmethod

@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@ -83,7 +84,7 @@ class Fp8Config(QuantizationConfig):
self.weight_block_size = weight_block_size

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "fp8"

@classmethod

@ -13,6 +13,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
FusedMoEMethodBase)
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.vocab_parallel_embedding import (
@ -31,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
def __repr__(self) -> str:
return ("GGUFConfig()")

def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "gguf"

def get_supported_act_dtypes(self) -> List[torch.dtype]:

@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter

from vllm import _custom_ops as ops
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.utils.gptq_utils import (
@ -79,7 +80,7 @@ class GPTQConfig(QuantizationConfig):
f"dynamic={self.dynamic}")

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "gptq"

@classmethod

@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
@ -123,7 +124,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
f"quant_method={self.quant_method})")

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "gptq_bitblas"

@classmethod
@ -151,8 +152,8 @@ class GPTQBitBLASConfig(QuantizationConfig):
lm_head_quantized)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)

is_valid_user_quant = (user_quant is None or user_quant == "bitblas"

@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
@ -100,7 +101,7 @@ class GPTQMarlinConfig(QuantizationConfig):
f"dynamic={self.dynamic}")

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "gptq_marlin"

@classmethod
@ -130,8 +131,8 @@ class GPTQMarlinConfig(QuantizationConfig):
lm_head_quantized, dynamic, config)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

is_valid_user_quant = (user_quant is None or user_quant == "marlin"

@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (BasevLLMParameter,
@ -85,7 +86,7 @@ class GPTQMarlin24Config(QuantizationConfig):
self.quant_type, self.group_size)

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "gptq_marlin_24"

@classmethod
@ -108,8 +109,8 @@ class GPTQMarlin24Config(QuantizationConfig):
return cls(weight_bits, group_size)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
is_marlin_24_format = (
hf_quant_cfg.get("checkpoint_format") == "marlin_24")


@ -8,6 +8,7 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@ -50,7 +51,7 @@ class HQQMarlinConfig(QuantizationConfig):
f"group_size={self.group_size})")

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "hqq"

@classmethod

@ -6,6 +6,7 @@ import torch

from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
is_layer_skipped_awq)
from vllm.model_executor.layers.quantization.base_config import (
@ -58,7 +59,7 @@ class IPEXConfig(QuantizationConfig):
f"group_size={self.group_size})")

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "ipex"

@classmethod
@ -97,8 +98,8 @@ class IPEXConfig(QuantizationConfig):
lm_head_quantized)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
if not current_platform.is_cpu() and not current_platform.is_xpu():
return None


@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@ -63,7 +64,7 @@ class MarlinConfig(QuantizationConfig):
f"lm_head_quantized={self.lm_head_quantized})")

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "marlin"

@classmethod
@ -87,8 +88,8 @@ class MarlinConfig(QuantizationConfig):
return cls(group_size, lm_head_quantized)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_marlin_format: bool
is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"

@ -11,6 +11,7 @@ from vllm._custom_ops import (cutlass_scaled_fp4_mm,
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@ -42,7 +43,7 @@ class ModelOptFp8Config(QuantizationConfig):
" the format is experimental and could change.")

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "modelopt"

@classmethod
@ -184,8 +185,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
self.exclude_modules = exclude_modules

@classmethod
def get_name(cls) -> str:
return "modelopt_nvfp4"
def get_name(cls) -> QuantizationMethods:
return "nvfp4"

@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:

@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@ -64,7 +65,7 @@ class MoeWNA16Config(QuantizationConfig):
self.modules_to_not_convert = modules_to_not_convert

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "moe_wna16"

@classmethod
@ -100,8 +101,8 @@ class MoeWNA16Config(QuantizationConfig):
lm_head_quantized, modules_to_not_convert, config)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
if can_convert and user_quant == "moe_wna16":
return cls.get_name()

@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional

from torch.nn import Module

from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)

@ -30,7 +31,7 @@ class NeuronQuantConfig(QuantizationConfig):
self.dequant_dtype = dequant_dtype
self.quantize_method = quantize_method

def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "neuron_quant"

def get_supported_act_dtypes(self) -> List[str]:

@ -9,6 +9,7 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizeMethodBase)
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
@ -50,7 +51,7 @@ class PTPCFp8Config(Fp8Config):
ignored_layers=ignored_layers)

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "ptpc_fp8"

@classmethod

@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import (BasevLLMParameter,
@ -84,7 +85,7 @@ class QQQConfig(QuantizationConfig):
self.weight_bits, self.group_size)

@classmethod
def get_name(cls) -> str:
def get_name(cls) -> QuantizationMethods:
return "qqq"

@classmethod

@ -8,6 +8,7 @@ import torch
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@ -47,7 +48,7 @@ class QuarkConfig(QuantizationConfig):
def get_min_capability(cls) -> int:
return 70

def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "quark"

def get_quant_method(self, layer: torch.nn.Module,

@ -6,6 +6,7 @@ import torch.nn.functional as F
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
@ -20,7 +21,7 @@ class TorchAOConfig(QuantizationConfig):
def __repr__(self) -> str:
return f"TorchAOConfig({self.torchao_config})"

def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "torchao"

def get_supported_act_dtypes(self) -> List[torch.dtype]:

@ -7,6 +7,7 @@ from torch.nn import Module
from torch.nn.parameter import Parameter

from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.parameter import ModelWeightParameter
@ -27,7 +28,7 @@ class Int8TpuConfig(QuantizationConfig):
f"Unsupported activation scheme {activation_scheme}")
self.activation_scheme = activation_scheme

def get_name(self) -> str:
def get_name(self) -> QuantizationMethods:
return "tpu_int8"

def get_supported_act_dtypes(self) -> List[torch.dtype]:

@ -1496,7 +1496,7 @@ def get_rope(
if key in _ROPE_DICT:
return _ROPE_DICT[key]

if rope_scaling is None:
if not rope_scaling:
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
is_neox_style, dtype)
else:

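The `get_rope` hunk above widens the check from `rope_scaling is None` to `not rope_scaling`, so an empty scaling dict now takes the same default-embedding path as `None`. A tiny illustration of the truthiness involved, independent of vLLM:

for rope_scaling in (None, {}, {"rope_type": "dynamic", "factor": 2.0}):
    # Both None and {} are falsy, so `not rope_scaling` selects the unscaled
    # rotary embedding; only a populated dict selects the scaled path.
    path = "default" if not rope_scaling else "scaled"
    print(rope_scaling, "->", path)
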
@ -180,7 +180,6 @@ def _get_neuron_config_after_override(default_neuron_config,
NeuronConfig, QuantizationConfig,
SparseAttnConfig)

overridden_neuron_config = overridden_neuron_config or {}
sparse_attn = overridden_neuron_config.pop("sparse_attn", {})
if sparse_attn:
overridden_neuron_config["sparse_attn"] = SparseAttnConfig(