Improve configs - ModelConfig (#17130)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>

parent 2c4f59afc3
commit 13698db634
@@ -738,7 +738,7 @@ class VllmRunner:
     - `block_size`: Set to `16` instead of `None` to reduce memory usage.
     - `enable_chunked_prefill`: Set to `False` instead of `None` for
       test reproducibility.
-    - `enforce_eager`: Set to `False` instead of `None` to test CUDA graph.
+    - `enforce_eager`: Set to `False` to test CUDA graph.
    """

    def __init__(
@@ -8,7 +8,7 @@ from typing import Literal, Optional

 import pytest

-from vllm.config import PoolerConfig, config
+from vllm.config import config
 from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
                                    get_type, is_not_builtin, is_type,
                                    literal_to_kwargs, nullable_kvs,
@@ -222,17 +222,6 @@ def test_prefix_cache_default():
     assert not engine_args.enable_prefix_caching


-def test_valid_pooling_config():
-    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
-    args = parser.parse_args([
-        '--override-pooler-config',
-        '{"pooling_type": "MEAN"}',
-    ])
-    engine_args = EngineArgs.from_cli_args(args=args)
-    assert engine_args.override_pooler_config == PoolerConfig(
-        pooling_type="MEAN", )
-
-
 @pytest.mark.parametrize(
     ("arg"),
     [
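Although `test_valid_pooling_config` is removed here, the CLI path it exercised still exists once the config-driven arguments later in this diff are in place. A hedged usage sketch (the `vllm.utils.FlexibleArgumentParser` import location and the exact parsed type of the override are assumptions, not part of this diff):

```python
# Sketch of the CLI flow the deleted test covered: dict-valued flags such as
# --override-pooler-config are passed as JSON strings on the command line.
from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser  # assumed import location

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args([
    "--model", "facebook/opt-125m",
    "--override-pooler-config", '{"pooling_type": "MEAN"}',
])
engine_args = EngineArgs.from_cli_args(args)
# engine_args.override_pooler_config now carries the parsed override
# (a dict or PoolerConfig, depending on how the CLI type is wired).
print(engine_args.override_pooler_config)
```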
@@ -14,7 +14,7 @@ import torch.nn.functional as F
 from vllm.model_executor.layers.linear import LinearBase  # noqa: E501
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization import (
-    get_quantization_config, register_quantization_config)
+    QuantizationMethods, get_quantization_config, register_quantization_config)
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig)

@@ -54,7 +54,7 @@ class CustomQuantConfig(QuantizationConfig):
         """Initialize the quantization config."""
         self.num_bits = num_bits

-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         """Name of the quantization method."""
         return "custom_quant"

@@ -185,7 +185,7 @@ def test_get_pooling_config():
         revision=None,
     )

-    pooling_config = model_config._init_pooler_config(None)
+    pooling_config = model_config._init_pooler_config()
     assert pooling_config is not None

     assert pooling_config.normalize
@@ -205,11 +205,12 @@ def test_get_pooling_config_from_args():
         dtype="float16",
         revision=None)

-    override_config = PoolerConfig(pooling_type='CLS', normalize=True)
+    override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
+    model_config.override_pooler_config = override_pooler_config

-    pooling_config = model_config._init_pooler_config(override_config)
+    pooling_config = model_config._init_pooler_config()
     assert pooling_config is not None
-    assert asdict(pooling_config) == asdict(override_config)
+    assert asdict(pooling_config) == asdict(override_pooler_config)


 @pytest.mark.skipif(current_platform.is_rocm(),
vllm/config.py
@@ -16,9 +16,8 @@ from dataclasses import (MISSING, dataclass, field, fields, is_dataclass,
                          replace)
 from importlib.util import find_spec
 from pathlib import Path
-from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Final, Literal,
-                    Optional, Protocol, TypeVar, Union, cast, get_args,
-                    get_origin)
+from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
+                    Protocol, TypeVar, Union, cast, get_args, get_origin)

 import torch
 from pydantic import BaseModel, Field, PrivateAttr
@@ -211,103 +210,190 @@ def get_field(cls: ConfigType, name: str) -> Field:
         f"{cls.__name__}.{name} must have a default value or default factory.")


-class ModelConfig:
-    """Configuration for the model.
-
-    Args:
-        model: Name or path of the huggingface model to use.
-            It is also used as the content for `model_name` tag in metrics
-            output when `served_model_name` is not specified.
-        task: The task to use the model for. Each vLLM instance only supports
-            one task, even if the same model can be used for multiple tasks.
-            When the model only supports one task, "auto" can be used to select
-            it; otherwise, you must specify explicitly which task to use.
-        tokenizer: Name or path of the huggingface tokenizer to use.
-        tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if
-            available, "slow" will always use the slow tokenizer,
-            "mistral" will always use the tokenizer from `mistral_common`, and
-            "custom" will use --tokenizer to select the preregistered tokenizer.
-        trust_remote_code: Trust remote code (e.g., from HuggingFace) when
-            downloading the model and tokenizer.
-        allowed_local_media_path: Allowing API requests to read local images or
-            videos from directories specified by the server file system.
-            This is a security risk. Should only be enabled in trusted
-            environments.
-        dtype: Data type for model weights and activations. The "auto" option
-            will use FP16 precision for FP32 and FP16 models, and BF16 precision
-            for BF16 models.
-        seed: Random seed for reproducibility.
-        revision: The specific model version to use. It can be a branch name,
-            a tag name, or a commit id. If unspecified, will use the default
-            version.
-        code_revision: The specific revision to use for the model code on
-            Hugging Face Hub. It can be a branch name, a tag name, or a
-            commit id. If unspecified, will use the default version.
-        tokenizer_revision: The specific tokenizer version to use. It can be a
-            branch name, a tag name, or a commit id. If unspecified, will use
-            the default version.
-        max_model_len: Maximum length of a sequence (including prompt and
-            output). If None, will be derived from the model.
-        spec_target_max_model_len: Specify the the maximum length for spec
-            decoding draft models.
-        quantization: Quantization method that was used to quantize the model
-            weights. If None, we assume the model weights are not quantized.
-        enforce_eager: Whether to enforce eager execution. If True, we will
-            disable CUDA graph and always execute the model in eager mode.
-            If False, we will use CUDA graph and eager execution in hybrid.
-            If None, the user did not specify, so default to False.
-        max_seq_len_to_capture: Maximum sequence len covered by CUDA graphs.
-            When a sequence has context length larger than this, we fall back
-            to eager mode. Additionally for encoder-decoder models, if the
-            sequence length of the encoder input is larger than this, we fall
-            back to the eager mode.
-        max_logprobs: Maximum number of log probabilities. Defaults to 20.
-        disable_sliding_window: Whether to disable sliding window. If True,
-            we will disable the sliding window functionality of the model.
-            If the model does not support sliding window, this argument is
-            ignored.
-        skip_tokenizer_init: If true, skip initialization of tokenizer and
-            detokenizer.
-        served_model_name: The model name used in metrics tag `model_name`,
-            matches the model name exposed via the APIs. If multiple model
-            names provided, the first name will be used. If not specified,
-            the model name will be the same as `model`.
-        limit_mm_per_prompt: Maximum number of data items per modality
-            per prompt. Only applicable for multimodal models.
-        mm_processor_kwargs: Overrides for the multi-modal processor obtained
-            from `AutoProcessor.from_pretrained`.
-        disable_mm_preprocessor_cache: If True, disable caching of the
-            processed multi-modal inputs.
-        use_async_output_proc: Whether to use async output processor.
-            Defaults to True.
-        config_format: The config format which shall be loaded.
-            Defaults to 'auto' which defaults to 'hf'.
-        hf_token: The token to use as HTTP bearer authorization for remote files
-            . If `True`, will use the token generated when running
-            `huggingface-cli login` (stored in `~/.huggingface`).
-        hf_overrides: If a dictionary, contains arguments to be forwarded to the
-            HuggingFace config. If a callable, it is called to update the
-            HuggingFace config.
-        override_neuron_config: Initialize non default neuron config or
-            override default neuron config that are specific to Neuron devices,
-            this argument will be used to configure the neuron config that
-            can not be gathered from the vllm arguments.
-        override_pooler_config: Initialize non default pooling config or
-            override default pooling config for the pooling model.
-        logits_processor_pattern: Optional regex pattern specifying valid
-            logits processor qualified names that can be passed with the
-            `logits_processors` extra completion argument. Defaults to None,
-            which allows no processors.
-        generation_config: Configuration parameter file for generation.
-        model_impl: Which implementation of the model to use:
-            "auto" will try to use the vLLM implementation if it exists and
-            fall back to the Transformers implementation if no vLLM
-            implementation is available.
-            "vllm" will use the vLLM model implementation.
-            "transformers" will use the Transformers model implementation.
-        override_generation_config: Override the generation config with the
-            given config.
-    """
+TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
+ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
+
+
+@config
+@dataclass
+class ModelConfig:
+    """Configuration for the model."""
+
+    model: str = "facebook/opt-125m"
+    """Name or path of the Hugging Face model to use. It is also used as the
+    content for `model_name` tag in metrics output when `served_model_name` is
+    not specified."""
+    task: Literal[TaskOption, Literal["draft"]] = "auto"
+    """The task to use the model for. Each vLLM instance only supports one
+    task, even if the same model can be used for multiple tasks. When the model
+    only supports one task, "auto" can be used to select it; otherwise, you
+    must specify explicitly which task to use."""
+    tokenizer: str = None  # type: ignore
+    """Name or path of the Hugging Face tokenizer to use. If unspecified, model
+    name or path will be used."""
+    tokenizer_mode: TokenizerMode = "auto"
+    """Tokenizer mode:\n
+    - "auto" will use the fast tokenizer if available.\n
+    - "slow" will always use the slow tokenizer.\n
+    - "mistral" will always use the tokenizer from `mistral_common`.\n
+    - "custom" will use --tokenizer to select the preregistered tokenizer."""
+    trust_remote_code: bool = False
+    """Trust remote code (e.g., from HuggingFace) when downloading the model
+    and tokenizer."""
+    dtype: Union[ModelDType, torch.dtype] = "auto"
+    """Data type for model weights and activations:\n
+    - "auto" will use FP16 precision for FP32 and FP16 models, and BF16
+    precision for BF16 models.\n
+    - "half" for FP16. Recommended for AWQ quantization.\n
+    - "float16" is the same as "half".\n
+    - "bfloat16" for a balance between precision and range.\n
+    - "float" is shorthand for FP32 precision.\n
+    - "float32" for FP32 precision."""
+    seed: Optional[int] = None
+    """Random seed for reproducibility."""
+    hf_config_path: Optional[str] = None
+    """Name or path of the Hugging Face config to use. If unspecified, model
+    name or path will be used."""
+    allowed_local_media_path: str = ""
+    """Allowing API requests to read local images or videos from directories
+    specified by the server file system. This is a security risk. Should only
+    be enabled in trusted environments."""
+    revision: Optional[str] = None
+    """The specific model version to use. It can be a branch name, a tag name,
+    or a commit id. If unspecified, will use the default version."""
+    code_revision: Optional[str] = None
+    """The specific revision to use for the model code on the Hugging Face Hub.
+    It can be a branch name, a tag name, or a commit id. If unspecified, will
+    use the default version."""
+    rope_scaling: dict[str, Any] = field(default_factory=dict)
+    """RoPE scaling configuration in JSON format. For example,
+    `{"rope_type":"dynamic","factor":2.0}`."""
+    rope_theta: Optional[float] = None
+    """RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE
+    theta improves the performance of the scaled model."""
+    tokenizer_revision: Optional[str] = None
+    """The specific revision to use for the tokenizer on the Hugging Face Hub.
+    It can be a branch name, a tag name, or a commit id. If unspecified, will
+    use the default version."""
+    max_model_len: int = None  # type: ignore
+    """Model context length (prompt and output). If unspecified, will be
+    automatically derived from the model config.
+
+    When passing via `--max-model-len`, supports k/m/g/K/M/G in human-readable
+    format. Examples:\n
+    - 1k -> 1000\n
+    - 1K -> 1024\n
+    - 25.6k -> 25,600"""
+    spec_target_max_model_len: Optional[int] = None
+    """Specify the the maximum length for spec decoding draft models."""
+    quantization: Optional[QuantizationMethods] = None
+    """Method used to quantize the weights. If `None`, we first check the
+    `quantization_config` attribute in the model config file. If that is
+    `None`, we assume the model weights are not quantized and use `dtype` to
+    determine the data type of the weights."""
+    enforce_eager: bool = False
+    """Whether to always use eager-mode PyTorch. If True, we will disable CUDA
+    graph and always execute the model in eager mode. If False, we will use
+    CUDA graph and eager execution in hybrid for maximal performance and
+    flexibility."""
+    max_seq_len_to_capture: int = 8192
+    """Maximum sequence len covered by CUDA graphs. When a sequence has context
+    length larger than this, we fall back to eager mode. Additionally for
+    encoder-decoder models, if the sequence length of the encoder input is
+    larger than this, we fall back to the eager mode."""
+    max_logprobs: int = 20
+    """Maximum number of log probabilities to return when `logprobs` is
+    specified in `SamplingParams`. The default value comes the default for the
+    OpenAI Chat Completions API."""
+    disable_sliding_window: bool = False
+    """Whether to disable sliding window. If True, we will disable the sliding
+    window functionality of the model, capping to sliding window size. If the
+    model does not support sliding window, this argument is ignored."""
+    disable_cascade_attn: bool = False
+    """Disable cascade attention for V1. While cascade attention does not
+    change the mathematical correctness, disabling it could be useful for
+    preventing potential numerical issues. Note that even if this is set to
+    False, cascade attention will be only used when the heuristic tells that
+    it's beneficial."""
+    skip_tokenizer_init: bool = False
+    """Skip initialization of tokenizer and detokenizer. Expects valid
+    `prompt_token_ids` and `None` for prompt from the input. The generated
+    output will contain token ids."""
+    served_model_name: Optional[Union[str, list[str]]] = None
+    """The model name(s) used in the API. If multiple names are provided, the
+    server will respond to any of the provided names. The model name in the
+    model field of a response will be the first name in this list. If not
+    specified, the model name will be the same as the `--model` argument. Noted
+    that this name(s) will also be used in `model_name` tag content of
+    prometheus metrics, if multiple names provided, metrics tag will take the
+    first one."""
+    limit_mm_per_prompt: dict[str, int] = field(default_factory=dict)
+    """Maximum number of data items per modality per prompt. Only applicable
+    for multimodal models."""
+    use_async_output_proc: bool = True
+    """Whether to use async output processor."""
+    config_format: Union[str, ConfigFormat] = ConfigFormat.AUTO.value
+    """The format of the model config to load:\n
+    - "auto" will try to load the config in hf format if available else it
+    will try to load in mistral format.\n
+    - "hf" will load the config in hf format.\n
+    - "mistral" will load the config in mistral format."""
+    hf_token: Optional[Union[bool, str]] = None
+    """The token to use as HTTP bearer authorization for remote files . If
+    `True`, will use the token generated when running `huggingface-cli login`
+    (stored in `~/.huggingface`)."""
+    hf_overrides: HfOverrides = field(default_factory=dict)
+    """If a dictionary, contains arguments to be forwarded to the Hugging Face
+    config. If a callable, it is called to update the HuggingFace config. When
+    specified via CLI, the argument must be a valid JSON string."""
+    mm_processor_kwargs: Optional[dict[str, Any]] = None
+    """Arguments to be forwarded to the model's processor for multi-modal data,
+    e.g., image processor. Overrides for the multi-modal processor obtained
+    from `AutoProcessor.from_pretrained`. The available overrides depend on the
+    model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
+    When specified via CLI, the argument must be a valid JSON string."""
+    disable_mm_preprocessor_cache: bool = False
+    """If `True`, disable caching of the multi-modal preprocessor/mapper (not
+    recommended)."""
+    override_neuron_config: dict[str, Any] = field(default_factory=dict)
+    """Initialize non-default neuron config or override default neuron config
+    that are specific to Neuron devices, this argument will be used to
+    configure the neuron config that can not be gathered from the vllm
+    arguments. e.g. `{"cast_logits_dtype": "bloat16"}`. When specified via CLI,
+    the argument must be a valid JSON string."""
+    pooler_config: Optional["PoolerConfig"] = field(init=False)
+    """Pooler config which controls the behaviour of output pooling in pooling
+    models."""
+    override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
+    """Initialize non-default pooling config or override default pooling config
+    for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`.
+    When specified via CLI, the argument must be a valid JSON string."""
+    logits_processor_pattern: Optional[str] = None
+    """Optional regex pattern specifying valid logits processor qualified names
+    that can be passed with the `logits_processors` extra completion argument.
+    Defaults to `None`, which allows no processors."""
+    generation_config: str = "auto"
+    """The folder path to the generation config. Defaults to `"auto"`, the
+    generation config will be loaded from model path. If set to `"vllm"`, no
+    generation config is loaded, vLLM defaults will be used. If set to a folder
+    path, the generation config will be loaded from the specified folder path.
+    If `max_new_tokens` is specified in generation config, then it sets a
+    server-wide limit on the number of output tokens for all requests."""
+    override_generation_config: dict[str, Any] = field(default_factory=dict)
+    """Overrides or sets generation config. e.g. `{"temperature": 0.5}`. If
+    used with `--generation-config auto`, the override parameters will be
+    merged with the default config from the model. If used with
+    `--generation-config vllm`, only the override parameters are used.
+    When specified via CLI, the argument must be a valid JSON string."""
+    enable_sleep_mode: bool = False
+    """Enable sleep mode for the engine (only cuda platform is supported)."""
+    model_impl: Union[str, ModelImpl] = ModelImpl.AUTO.value
+    """Which implementation of the model to use:\n
+    - "auto" will try to use the vLLM implementation, if it exists, and fall
+    back to the Transformers implementation if no vLLM implementation is
+    available.\n
+    - "vllm" will use the vLLM model implementation.\n
+    - "transformers" will use the Transformers model implementation."""

     def compute_hash(self) -> str:
         """
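The new dataclass style puts each field's documentation in a string literal directly beneath the field, which is what lets `get_attr_docs`/`get_kwargs` (used later in this diff) reuse the same text as CLI `--help` output. A minimal sketch of how such attribute docstrings can be harvested — an illustration of the idea only, not vLLM's actual `get_attr_docs` implementation:

```python
import ast
import inspect
from dataclasses import dataclass


def get_attr_docs_sketch(cls: type) -> dict[str, str]:
    """Pair each annotated class attribute with the string literal below it."""
    cls_node = ast.parse(inspect.getsource(cls)).body[0]
    docs: dict[str, str] = {}
    prev_field = None
    for node in cls_node.body:
        if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
            # Remember the field name; the next statement may be its docstring.
            prev_field = node.target.id
        elif (prev_field is not None and isinstance(node, ast.Expr)
              and isinstance(node.value, ast.Constant)
              and isinstance(node.value.value, str)):
            docs[prev_field] = inspect.cleandoc(node.value.value)
            prev_field = None
        else:
            prev_field = None
    return docs


@dataclass
class TinyConfig:  # hypothetical stand-in for a vLLM config dataclass
    model: str = "facebook/opt-125m"
    """Name or path of the Hugging Face model to use."""
    max_logprobs: int = 20
    """Maximum number of log probabilities to return."""


print(get_attr_docs_sketch(TinyConfig)["model"])
# -> Name or path of the Hugging Face model to use.
```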
@@ -342,92 +428,43 @@ class ModelConfig:
         assert_hashable(str_factors)
         return hashlib.sha256(str(factors).encode()).hexdigest()

-    def __init__(
-        self,
-        model: str,
-        task: Literal[TaskOption, Literal["draft"]],
-        tokenizer: str,
-        tokenizer_mode: str,
-        trust_remote_code: bool,
-        dtype: Union[str, torch.dtype],
-        seed: int,
-        hf_config_path: Optional[str] = None,
-        allowed_local_media_path: str = "",
-        revision: Optional[str] = None,
-        code_revision: Optional[str] = None,
-        rope_scaling: Optional[dict[str, Any]] = None,
-        rope_theta: Optional[float] = None,
-        tokenizer_revision: Optional[str] = None,
-        max_model_len: Optional[int] = None,
-        spec_target_max_model_len: Optional[int] = None,
-        quantization: Optional[str] = None,
-        enforce_eager: Optional[bool] = None,
-        max_seq_len_to_capture: Optional[int] = None,
-        max_logprobs: int = 20,
-        disable_sliding_window: bool = False,
-        disable_cascade_attn: bool = False,
-        skip_tokenizer_init: bool = False,
-        served_model_name: Optional[Union[str, list[str]]] = None,
-        limit_mm_per_prompt: Optional[dict[str, int]] = None,
-        mm_processor_kwargs: Optional[dict[str, Any]] = None,
-        disable_mm_preprocessor_cache: bool = False,
-        use_async_output_proc: bool = True,
-        config_format: ConfigFormat = ConfigFormat.AUTO,
-        hf_token: Optional[Union[bool, str]] = None,
-        hf_overrides: Optional[HfOverrides] = None,
-        override_neuron_config: Optional[dict[str, Any]] = None,
-        override_pooler_config: Optional["PoolerConfig"] = None,
-        logits_processor_pattern: Optional[str] = None,
-        generation_config: str = "auto",
-        enable_sleep_mode: bool = False,
-        override_generation_config: Optional[dict[str, Any]] = None,
-        model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
-    ) -> None:
-        self.model = maybe_model_redirect(model)
-        self.tokenizer = maybe_model_redirect(tokenizer)
+    def __post_init__(self) -> None:
+        self.model = maybe_model_redirect(self.model)
+        # The tokenizer is consistent with the model by default.
+        if self.tokenizer is None:
+            self.tokenizer = self.model
+        if self.tokenizer_revision is None:
+            self.tokenizer_revision = self.revision
+        self.tokenizer = maybe_model_redirect(self.tokenizer)

-        self.hf_config_path = hf_config_path
-        if isinstance(hf_config_path, str):
-            self.hf_config_path = maybe_model_redirect(hf_config_path)
+        if isinstance(self.hf_config_path, str):
+            self.hf_config_path = maybe_model_redirect(self.hf_config_path)

-        self.tokenizer_mode = tokenizer_mode
-        self.trust_remote_code = trust_remote_code
-        self.allowed_local_media_path = allowed_local_media_path
-        self.seed = seed
-        self.revision = revision
-        self.code_revision = code_revision
-        self.rope_scaling = rope_scaling
-        self.rope_theta = rope_theta
-        self.model_impl = model_impl
-
-        if hf_overrides is None:
-            hf_overrides = {}
-
-        if callable(hf_overrides):
+        if callable(self.hf_overrides):
             hf_overrides_kw = {}
-            hf_overrides_fn = hf_overrides
+            hf_overrides_fn = self.hf_overrides
         else:
-            hf_overrides_kw = hf_overrides
+            hf_overrides_kw = self.hf_overrides
             hf_overrides_fn = None

-        if rope_scaling is not None:
-            hf_override: dict[str, Any] = {"rope_scaling": rope_scaling}
+        if self.rope_scaling:
+            hf_override: dict[str, Any] = {"rope_scaling": self.rope_scaling}
             hf_overrides_kw.update(hf_override)
-            hf_overrides_str = json.dumps(hf_overrides)
+            hf_overrides_str = json.dumps(hf_overrides_kw)
             msg = (
                 "`--rope-scaling` will be removed in a future release. "
                 f"'Please instead use `--hf-overrides '{hf_overrides_str}'`")
             warnings.warn(DeprecationWarning(msg), stacklevel=2)
-        if rope_theta is not None:
-            hf_override = {"rope_theta": rope_theta}
+        if self.rope_theta is not None:
+            hf_override = {"rope_theta": self.rope_theta}
             hf_overrides_kw.update(hf_override)
-            hf_overrides_str = json.dumps(hf_overrides)
+            hf_overrides_str = json.dumps(hf_overrides_kw)
             msg = (
                 "`--rope-theta` will be removed in a future release. "
                 f"'Please instead use `--hf-overrides '{hf_overrides_str}'`")
             warnings.warn(DeprecationWarning(msg), stacklevel=2)

-        self.maybe_pull_model_tokenizer_for_s3(model, tokenizer)
+        self.maybe_pull_model_tokenizer_for_s3(self.model, self.tokenizer)

         if (backend := envs.VLLM_ATTENTION_BACKEND
             ) and backend == "FLASHINFER" and find_spec("flashinfer") is None:
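With `ModelConfig` now a default-initialisable dataclass, the derived-value logic that used to live in the long hand-written `__init__` moves into `__post_init__`. A toy illustration of the pattern (field names mirror `ModelConfig`, but the class below is purely illustrative):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TinyModelConfig:  # illustrative stand-in, not the real ModelConfig
    model: str = "facebook/opt-125m"
    tokenizer: Optional[str] = None
    revision: Optional[str] = None
    tokenizer_revision: Optional[str] = None

    def __post_init__(self) -> None:
        # Derived defaults now live here instead of in a hand-written __init__.
        if self.tokenizer is None:
            self.tokenizer = self.model
        if self.tokenizer_revision is None:
            self.tokenizer_revision = self.revision


cfg = TinyModelConfig(model="mistralai/Mistral-7B-v0.1", revision="main")
assert cfg.tokenizer == "mistralai/Mistral-7B-v0.1"
assert cfg.tokenizer_revision == "main"
```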
@@ -437,20 +474,6 @@ class ModelConfig:
                 "https://github.com/vllm-project/vllm/blob/main/docker/Dockerfile "  # noqa: E501
                 "for instructions on how to install it.")

-        # The tokenizer version is consistent with the model version by default.
-        if tokenizer_revision is None:
-            self.tokenizer_revision = revision
-        else:
-            self.tokenizer_revision = tokenizer_revision
-        self.quantization = quantization
-        self.enforce_eager = enforce_eager
-        self.max_seq_len_to_capture = max_seq_len_to_capture
-        self.max_logprobs = max_logprobs
-        self.disable_sliding_window = disable_sliding_window
-        self.disable_cascade_attn = disable_cascade_attn
-        self.skip_tokenizer_init = skip_tokenizer_init
-        self.enable_sleep_mode = enable_sleep_mode
-
         from vllm.platforms import current_platform

         if (self.enable_sleep_mode
@@ -458,9 +481,12 @@ class ModelConfig:
                 raise ValueError(
                     "Sleep mode is not supported on current platform.")

+        if isinstance(self.config_format, str):
+            self.config_format = ConfigFormat(self.config_format)
+
         hf_config = get_config(self.hf_config_path or self.model,
-                               trust_remote_code, revision, code_revision,
-                               config_format)
+                               self.trust_remote_code, self.revision,
+                               self.code_revision, self.config_format)

         if hf_overrides_kw:
             logger.info("Overriding HF config with %s", hf_overrides_kw)
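Because the dataclass default for `config_format` is now the plain string `ConfigFormat.AUTO.value`, `__post_init__` promotes a string back to the enum before use. The same string-to-enum promotion in isolation, using a stand-in enum since `ConfigFormat`'s exact definition is outside this diff:

```python
import enum


class ConfigFormatSketch(str, enum.Enum):
    # Hypothetical stand-in for ConfigFormat; lower-case string values assumed.
    AUTO = "auto"
    HF = "hf"
    MISTRAL = "mistral"


fmt = "hf"  # what a dataclass default or CLI-provided value looks like
if isinstance(fmt, str):
    fmt = ConfigFormatSketch(fmt)  # value-based lookup promotes str -> enum
assert fmt is ConfigFormatSketch.HF
assert ConfigFormatSketch.AUTO.value == "auto"  # round-trips to the str default
```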
@@ -476,13 +502,8 @@ class ModelConfig:
             "attention_chunk_size", None)
         self.encoder_config = self._get_encoder_config()
         self.hf_image_processor_config = get_hf_image_processor_config(
-            self.model, hf_token=hf_token, revision=revision)
-        self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
-        self.use_async_output_proc = use_async_output_proc
-
-        # Set enforce_eager to False if the value is unset.
-        if self.enforce_eager is None:
-            self.enforce_eager = False
+            self.model, hf_token=self.hf_token, revision=self.revision)
+        self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)

         interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
         sliding_window = getattr(self.hf_text_config, "sliding_window", None)
@@ -515,18 +536,14 @@ class ModelConfig:

         self.max_model_len = _get_and_verify_max_len(
             hf_config=self.hf_text_config,
-            max_model_len=max_model_len,
+            max_model_len=self.max_model_len,
             disable_sliding_window=self.disable_sliding_window,
             sliding_window_len=self.get_hf_config_sliding_window(),
-            spec_target_max_model_len=spec_target_max_model_len,
+            spec_target_max_model_len=self.spec_target_max_model_len,
             encoder_config=self.encoder_config)
-        self.served_model_name = get_served_model_name(model,
-                                                       served_model_name)
-        self.multimodal_config = self._init_multimodal_config(
-            limit_mm_per_prompt=limit_mm_per_prompt,
-            mm_processor_kwargs=mm_processor_kwargs,
-            disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
-        )
+        self.served_model_name = get_served_model_name(self.model,
+                                                       self.served_model_name)
+        self.multimodal_config = self._init_multimodal_config()
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()

@@ -535,24 +552,19 @@ class ModelConfig:
         self.has_noops = self._init_has_noops()
         self.has_inner_state = self._init_has_inner_state()

-        if current_platform.is_neuron():
-            self.override_neuron_config = override_neuron_config
-        else:
-            self.override_neuron_config = None
+        if (not current_platform.is_neuron() and self.override_neuron_config):
+            raise ValueError(
+                "`override_neuron_config` is only supported on Neuron.")

-        supported_tasks, task = self._resolve_task(task)
+        supported_tasks, task = self._resolve_task(self.task)
         self.supported_tasks = supported_tasks
-        self.task: Final = task
+        self.task = task
         if self.task in ("draft", "generate"):
             self.truncation_side = "left"
         else:
             self.truncation_side = "right"

-        self.pooler_config = self._init_pooler_config(override_pooler_config)
-        self.logits_processor_pattern = logits_processor_pattern
-
-        self.generation_config = generation_config
-        self.override_generation_config = override_generation_config or {}
+        self.pooler_config = self._init_pooler_config()

         self._verify_quantization()
         self._verify_cuda_graph()
@@ -591,26 +603,21 @@ class ModelConfig:
                 model, ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
             self.tokenizer = s3_tokenizer.dir

-    def _init_multimodal_config(
-        self,
-        limit_mm_per_prompt: Optional[dict[str, int]],
-        mm_processor_kwargs: Optional[dict[str, Any]],
-        disable_mm_preprocessor_cache: bool,
-    ) -> Optional["MultiModalConfig"]:
+    def _init_multimodal_config(self) -> Optional["MultiModalConfig"]:
         if self.registry.is_multimodal_model(self.architectures):
             return MultiModalConfig(
-                limit_per_prompt=limit_mm_per_prompt or {},
-                mm_processor_kwargs=mm_processor_kwargs or {},
-                disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
-            )
+                limit_per_prompt=self.limit_mm_per_prompt,
+                mm_processor_kwargs=self.mm_processor_kwargs,
+                disable_mm_preprocessor_cache=self.
+                disable_mm_preprocessor_cache)

-        if limit_mm_per_prompt:
+        if self.limit_mm_per_prompt:
             raise ValueError("`limit_mm_per_prompt` is only supported for "
                              "multimodal models.")
-        if mm_processor_kwargs:
+        if self.mm_processor_kwargs:
             raise ValueError("`mm_processor_kwargs` is only supported for "
                              "multimodal models.")
-        if disable_mm_preprocessor_cache:
+        if self.disable_mm_preprocessor_cache:
             raise ValueError("`disable_mm_preprocessor_cache` is only "
                              "supported for multimodal models.")

@@ -620,31 +627,32 @@ class ModelConfig:
             return get_sentence_transformer_tokenizer_config(
                 self.model, self.revision)

-    def _init_pooler_config(
-        self,
-        override_pooler_config: Optional["PoolerConfig"],
-    ) -> Optional["PoolerConfig"]:
+    def _init_pooler_config(self) -> Optional["PoolerConfig"]:

         if self.runner_type == "pooling":
-            user_config = override_pooler_config or PoolerConfig()
+            if isinstance(self.override_pooler_config, dict):
+                self.override_pooler_config = PoolerConfig(
+                    **self.override_pooler_config)
+
+            pooler_config = self.override_pooler_config or PoolerConfig()

             base_config = get_pooling_config(self.model, self.revision)
             if base_config is not None:
                 # Only set values that are not overridden by the user
                 for k, v in base_config.items():
-                    if getattr(user_config, k) is None:
-                        setattr(user_config, k, v)
+                    if getattr(pooler_config, k) is None:
+                        setattr(pooler_config, k, v)

             if self.is_matryoshka:
-                if user_config.normalize is None:
-                    user_config.normalize = True
-                elif not user_config.normalize:
+                if pooler_config.normalize is None:
+                    pooler_config.normalize = True
+                elif not pooler_config.normalize:
                     raise ValueError(
                         "`normalize` must be enabled (set to True) "
                         "for models that are compatible with "
                         "Matryoshka Representation.")

-            return user_config
+            return pooler_config

         return None

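`_init_pooler_config` now reads `self.override_pooler_config` and accepts either a `PoolerConfig` or a plain dict (e.g. parsed from a JSON CLI string), then fills unset fields from the model's base pooling config. The merge logic in isolation, with a hypothetical stand-in dataclass:

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class PoolerCfg:  # illustrative stand-in for vllm.config.PoolerConfig
    pooling_type: Optional[str] = None
    normalize: Optional[bool] = None


def resolve_pooler_config(override: Optional[Union[dict, PoolerCfg]],
                          base: dict) -> PoolerCfg:
    # Promote a dict override to the dataclass, then fill in any field the
    # user left as None from the base config (same shape as the diff above).
    if isinstance(override, dict):
        override = PoolerCfg(**override)
    pooler_config = override or PoolerCfg()
    for k, v in base.items():
        if getattr(pooler_config, k) is None:
            setattr(pooler_config, k, v)
    return pooler_config


cfg = resolve_pooler_config({"pooling_type": "MEAN"},
                            base={"pooling_type": "CLS", "normalize": True})
assert cfg.pooling_type == "MEAN" and cfg.normalize is True
```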
@@ -662,11 +670,11 @@ class ModelConfig:
         return self.registry.model_has_inner_state(self.architectures)

     def _verify_tokenizer_mode(self) -> None:
-        tokenizer_mode = self.tokenizer_mode.lower()
-        if tokenizer_mode not in ["auto", "slow", "mistral", "custom"]:
+        tokenizer_mode = cast(TokenizerMode, self.tokenizer_mode.lower())
+        if tokenizer_mode not in get_args(TokenizerMode):
             raise ValueError(
                 f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be "
-                "either 'auto', 'slow', 'mistral' or 'custom'.")
+                f"one of {get_args(TokenizerMode)}.")
         self.tokenizer_mode = tokenizer_mode

     def _get_preferred_task(
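Validating against `get_args(TokenizerMode)` keeps the check and the error message in sync with the `Literal` definition automatically. The same check stripped of the surrounding class, as a small self-contained sketch:

```python
from typing import Literal, cast, get_args

TokenizerMode = Literal["auto", "slow", "mistral", "custom"]


def verify_tokenizer_mode(value: str) -> TokenizerMode:
    # Lower-case the input, then check it against the Literal's allowed values.
    mode = cast(TokenizerMode, value.lower())
    if mode not in get_args(TokenizerMode):
        raise ValueError(f"Unknown tokenizer mode: {value}. "
                         f"Must be one of {get_args(TokenizerMode)}.")
    return mode


print(verify_tokenizer_mode("AUTO"))  # -> auto
```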
@@ -781,7 +789,8 @@ class ModelConfig:
             "quark", "nvfp4", "bitblas", "gptq_bitblas"
         ]
         if self.quantization is not None:
-            self.quantization = self.quantization.lower()
+            self.quantization = cast(QuantizationMethods,
+                                     self.quantization.lower())

         # Parse quantization method from the HF model config, if available.
         quant_cfg = self._parse_quant_hf_config()
@@ -857,8 +866,6 @@ class ModelConfig:
                     "non-quantized models.", self.quantization)

     def _verify_cuda_graph(self) -> None:
-        if self.max_seq_len_to_capture is None:
-            self.max_seq_len_to_capture = self.max_model_len
         self.max_seq_len_to_capture = min(self.max_seq_len_to_capture,
                                           self.max_model_len)
         ROCM_UNSUPPORTED_MODELS = ['mllama']
@@ -1294,7 +1301,7 @@ class ModelConfig:

     @property
     def runner_type(self) -> RunnerType:
-        return _TASK_RUNNER[self.task]
+        return _TASK_RUNNER[cast(_ResolvedTask, self.task)]

     @property
     def is_v1_compatible(self) -> bool:
@@ -2201,7 +2208,7 @@ class SpeculativeConfig:
    according to the log probability settings in SamplingParams."""

    # Draft model configuration
-    quantization: Optional[str] = None
+    quantization: Optional[QuantizationMethods] = None
    """Quantization method that was used to quantize the draft model weights.
    If `None`, we assume the model weights are not quantized. Note that it only
    takes effect when using the draft model-based speculative method."""
@@ -2386,7 +2393,6 @@ class SpeculativeConfig:
                 code_revision=self.code_revision,
                 tokenizer_revision=self.target_model_config.
                 tokenizer_revision,
-                max_model_len=None,
                 spec_target_max_model_len=self.target_model_config.
                 max_model_len,
                 quantization=self.quantization,
@@ -2793,30 +2799,31 @@ class PromptAdapterConfig:
 class MultiModalConfig:
     """Controls the behavior of multimodal models."""

-    limit_per_prompt: dict[str, int] = field(default_factory=dict)
+    limit_per_prompt: dict[str, int] = get_field(ModelConfig,
+                                                 "limit_mm_per_prompt")
     """
     The maximum number of input items allowed per prompt for each modality.
     This should be a JSON string that will be parsed into a dictionary.
     Defaults to 1 (V0) or 999 (V1) for each modality.

     For example, to allow up to 16 images and 2 videos per prompt:
-    :code:`{"images": 16, "videos": 2}`
+    `{"images": 16, "videos": 2}`
     """

     mm_processor_kwargs: Optional[dict[str, object]] = None
     """
     Overrides for the multi-modal processor obtained from
-    :meth:`transformers.AutoProcessor.from_pretrained`.
+    `transformers.AutoProcessor.from_pretrained`.

     The available overrides depend on the model that is being run.

     For example, for Phi-3-Vision:
-    :code:`{"num_crops": 4}`.
+    `{"num_crops": 4}`.
     """

     disable_mm_preprocessor_cache: bool = False
     """
-    If :code:`True`, disable caching of the processed multi-modal inputs.
+    If `True`, disable caching of the processed multi-modal inputs.
     """

     def compute_hash(self) -> str:
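`MultiModalConfig.limit_per_prompt` now borrows its default from `ModelConfig.limit_mm_per_prompt` via `get_field`. Plain class-attribute access cannot express a `default_factory`, so the helper copies the dataclass `Field` itself. A simplified sketch of that idea (not the exact `vllm.config.get_field` implementation):

```python
from dataclasses import MISSING, Field, dataclass, field, fields


def get_field_sketch(cls: type, name: str) -> Field:
    """Copy a dataclass field so another dataclass can reuse its default."""
    named_fields = {f.name: f for f in fields(cls)}
    f = named_fields[name]
    if f.default_factory is not MISSING:
        return field(default_factory=f.default_factory)
    if f.default is not MISSING:
        return field(default=f.default)
    raise ValueError(f"{cls.__name__}.{name} has no default")


@dataclass
class Source:  # hypothetical stand-in for ModelConfig
    limit_mm_per_prompt: dict[str, int] = field(default_factory=dict)


@dataclass
class Mirror:  # hypothetical stand-in for MultiModalConfig
    # Reuse Source's default_factory instead of repeating it.
    limit_per_prompt: dict[str, int] = get_field_sketch(Source,
                                                        "limit_mm_per_prompt")


assert Mirror().limit_per_prompt == {}
```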
@@ -2907,10 +2914,6 @@ class PoolerConfig:
                                usedforsecurity=False).hexdigest()
         return hash_str

-    @staticmethod
-    def from_json(json_str: str) -> "PoolerConfig":
-        return PoolerConfig(**json.loads(json_str))
-

 _STR_DTYPE_TO_TORCH_DTYPE = {
     "half": torch.float16,
@@ -20,15 +20,16 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          DeviceConfig, DistributedExecutorBackend,
                          GuidedDecodingBackend, GuidedDecodingBackendV1,
                          HfOverrides, KVTransferConfig, LoadConfig, LoadFormat,
-                         LoRAConfig, ModelConfig, ModelImpl, MultiModalConfig,
-                         ObservabilityConfig, ParallelConfig, PoolerConfig,
-                         PrefixCachingHashAlgo, PromptAdapterConfig,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerPoolConfig, VllmConfig,
-                         get_attr_docs, get_field)
+                         LoRAConfig, ModelConfig, ModelDType, ModelImpl,
+                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
+                         PoolerConfig, PrefixCachingHashAlgo,
+                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
+                         SpeculativeConfig, TaskOption, TokenizerMode,
+                         TokenizerPoolConfig, VllmConfig, get_attr_docs,
+                         get_field)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
-from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.plugins import load_general_plugins
 from vllm.reasoning import ReasoningParserManager
 from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
@@ -183,6 +184,9 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
             kwargs[name]["nargs"] = "+"
         elif contains_type(type_hints, int):
             kwargs[name]["type"] = int
+            # Special case for large integers
+            if name in {"max_model_len"}:
+                kwargs[name]["type"] = human_readable_int
         elif contains_type(type_hints, float):
             kwargs[name]["type"] = float
         elif contains_type(type_hints, dict):
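`human_readable_int` is what lets `--max-model-len` accept the k/K/m/M/g/G suffixes documented on `ModelConfig.max_model_len` (1k -> 1000, 1K -> 1024, 25.6k -> 25,600). A sketch of a parser with that behaviour; the real helper in vLLM may differ in details such as error handling:

```python
def human_readable_int_sketch(value: str) -> int:
    """Parse '1k' -> 1000, '1K' -> 1024, '25.6k' -> 25600, '2048' -> 2048."""
    value = value.strip()
    decimal = {"k": 1_000, "m": 1_000_000, "g": 1_000_000_000}
    binary = {"K": 1 << 10, "M": 1 << 20, "G": 1 << 30}
    suffix = value[-1]
    if suffix in decimal:
        return int(float(value[:-1]) * decimal[suffix])
    if suffix in binary:
        return int(float(value[:-1]) * binary[suffix])
    return int(value)


assert human_readable_int_sketch("1k") == 1000
assert human_readable_int_sketch("1K") == 1024
assert human_readable_int_sketch("25.6k") == 25600
```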
@@ -212,22 +216,23 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
 @dataclass
 class EngineArgs:
     """Arguments for vLLM engine."""
-    model: str = 'facebook/opt-125m'
-    served_model_name: Optional[Union[str, List[str]]] = None
-    tokenizer: Optional[str] = None
-    hf_config_path: Optional[str] = None
-    task: TaskOption = "auto"
-    skip_tokenizer_init: bool = False
-    tokenizer_mode: str = 'auto'
-    trust_remote_code: bool = False
-    allowed_local_media_path: str = ""
+    model: str = ModelConfig.model
+    served_model_name: Optional[Union[
+        str, List[str]]] = ModelConfig.served_model_name
+    tokenizer: Optional[str] = ModelConfig.tokenizer
+    hf_config_path: Optional[str] = ModelConfig.hf_config_path
+    task: TaskOption = ModelConfig.task
+    skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
+    tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
+    trust_remote_code: bool = ModelConfig.trust_remote_code
+    allowed_local_media_path: str = ModelConfig.allowed_local_media_path
     download_dir: Optional[str] = LoadConfig.download_dir
     load_format: str = LoadConfig.load_format
-    config_format: ConfigFormat = ConfigFormat.AUTO
-    dtype: str = 'auto'
+    config_format: str = ModelConfig.config_format
+    dtype: ModelDType = ModelConfig.dtype
     kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
-    seed: Optional[int] = None
-    max_model_len: Optional[int] = None
+    seed: Optional[int] = ModelConfig.seed
+    max_model_len: Optional[int] = ModelConfig.max_model_len
     # Note: Specifying a custom executor backend by passing a class
     # is intended for expert use only. The API may change without
     # notice.
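`EngineArgs` now sources its defaults from the config dataclasses (e.g. `ModelConfig.model`) instead of repeating literals. This works because fields with plain defaults are readable as class attributes, whereas fields declared with `default_factory` are not, which is presumably why `rope_scaling` and the other dict-valued fields go through `get_field` in the hunks below. A small demonstration of that distinction:

```python
from dataclasses import dataclass, field


@dataclass
class Cfg:  # hypothetical stand-in for ModelConfig
    model: str = "facebook/opt-125m"
    rope_scaling: dict = field(default_factory=dict)


# A plain default can be read straight off the class...
assert Cfg.model == "facebook/opt-125m"

# ...but a default_factory field leaves no class attribute behind, so a
# helper like get_field(ModelConfig, "rope_scaling") is needed instead.
try:
    Cfg.rope_scaling
except AttributeError as e:
    print(e)  # type object 'Cfg' has no attribute 'rope_scaling'
```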
@@ -245,8 +250,8 @@ class EngineArgs:
     enable_prefix_caching: Optional[bool] = CacheConfig.enable_prefix_caching
     prefix_caching_hash_algo: PrefixCachingHashAlgo = \
         CacheConfig.prefix_caching_hash_algo
-    disable_sliding_window: bool = False
-    disable_cascade_attn: bool = False
+    disable_sliding_window: bool = ModelConfig.disable_sliding_window
+    disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
     use_v2_block_manager: bool = True
     swap_space: float = CacheConfig.swap_space
     cpu_offload_gb: float = CacheConfig.cpu_offload_gb
@@ -258,18 +263,19 @@ class EngineArgs:
     long_prefill_token_threshold: int = \
         SchedulerConfig.long_prefill_token_threshold
     max_num_seqs: Optional[int] = SchedulerConfig.max_num_seqs
-    max_logprobs: int = 20  # Default value for OpenAI Chat Completions API
+    max_logprobs: int = ModelConfig.max_logprobs
     disable_log_stats: bool = False
-    revision: Optional[str] = None
-    code_revision: Optional[str] = None
-    rope_scaling: Optional[Dict[str, Any]] = None
-    rope_theta: Optional[float] = None
-    hf_token: Optional[Union[bool, str]] = None
-    hf_overrides: Optional[HfOverrides] = None
-    tokenizer_revision: Optional[str] = None
-    quantization: Optional[str] = None
-    enforce_eager: Optional[bool] = None
-    max_seq_len_to_capture: int = 8192
+    revision: Optional[str] = ModelConfig.revision
+    code_revision: Optional[str] = ModelConfig.code_revision
+    rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
+    rope_theta: Optional[float] = ModelConfig.rope_theta
+    hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
+    hf_overrides: Optional[HfOverrides] = \
+        get_field(ModelConfig, "hf_overrides")
+    tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
+    quantization: Optional[QuantizationMethods] = ModelConfig.quantization
+    enforce_eager: bool = ModelConfig.enforce_eager
+    max_seq_len_to_capture: int = ModelConfig.max_seq_len_to_capture
     disable_custom_all_reduce: bool = ParallelConfig.disable_custom_all_reduce
     # The following three fields are deprecated and will be removed in a future
     # release. Setting them will have no effect. Please remove them from your
@@ -280,8 +286,10 @@ class EngineArgs:
         get_field(TokenizerPoolConfig, "extra_config")
     limit_mm_per_prompt: dict[str, int] = \
         get_field(MultiModalConfig, "limit_per_prompt")
-    mm_processor_kwargs: Optional[Dict[str, Any]] = None
-    disable_mm_preprocessor_cache: bool = False
+    mm_processor_kwargs: Optional[Dict[str, Any]] = \
+        MultiModalConfig.mm_processor_kwargs
+    disable_mm_preprocessor_cache: bool = \
+        MultiModalConfig.disable_mm_preprocessor_cache
     # LoRA fields
     enable_lora: bool = False
     enable_lora_bias: bool = LoRAConfig.bias_enabled
@@ -323,7 +331,8 @@ class EngineArgs:
         DecodingConfig.disable_any_whitespace
     guided_decoding_disable_additional_properties: bool = \
         DecodingConfig.disable_additional_properties
-    logits_processor_pattern: Optional[str] = None
+    logits_processor_pattern: Optional[
+        str] = ModelConfig.logits_processor_pattern

     speculative_config: Optional[Dict[str, Any]] = None

@@ -331,22 +340,25 @@ class EngineArgs:
     show_hidden_metrics_for_version: Optional[str] = None
     otlp_traces_endpoint: Optional[str] = None
     collect_detailed_traces: Optional[str] = None
-    disable_async_output_proc: bool = False
+    disable_async_output_proc: bool = not ModelConfig.use_async_output_proc
     scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
     scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls

-    override_neuron_config: Optional[Dict[str, Any]] = None
-    override_pooler_config: Optional[PoolerConfig] = None
+    override_neuron_config: dict[str, Any] = \
+        get_field(ModelConfig, "override_neuron_config")
+    override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
+        ModelConfig.override_pooler_config
     compilation_config: Optional[CompilationConfig] = None
     worker_cls: str = ParallelConfig.worker_cls
     worker_extension_cls: str = ParallelConfig.worker_extension_cls

     kv_transfer_config: Optional[KVTransferConfig] = None

-    generation_config: Optional[str] = "auto"
-    override_generation_config: Optional[Dict[str, Any]] = None
-    enable_sleep_mode: bool = False
-    model_impl: str = "auto"
+    generation_config: str = ModelConfig.generation_config
+    enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
+    override_generation_config: dict[str, Any] = \
+        get_field(ModelConfig, "override_generation_config")
+    model_impl: str = ModelConfig.model_impl

     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales

@ -356,9 +368,6 @@ class EngineArgs:
|
|||||||
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
|
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if not self.tokenizer:
|
|
||||||
self.tokenizer = self.model
|
|
||||||
|
|
||||||
# support `EngineArgs(compilation_config={...})`
|
# support `EngineArgs(compilation_config={...})`
|
||||||
# without having to manually construct a
|
# without having to manually construct a
|
||||||
# CompilationConfig object
|
# CompilationConfig object
|
||||||
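Worth spelling out, since the hunks above lean on it everywhere: scalar defaults are now read directly off the config class, while mutable defaults (dicts) go through get_field so every EngineArgs instance gets its own copy rather than a shared one. Below is a minimal sketch of that idea with stand-in classes; the real helper lives in vllm.config, and the names here are illustrative only.

    import copy
    from dataclasses import MISSING, dataclass, field, fields

    def get_field(cls, name: str):
        """Return a re-usable copy of a dataclass field's default/default_factory."""
        f = next(f for f in fields(cls) if f.name == name)
        if f.default_factory is not MISSING:
            return field(default_factory=f.default_factory)
        return field(default=copy.deepcopy(f.default))

    @dataclass
    class FakeModelConfig:          # stand-in, not vLLM's ModelConfig
        model_impl: str = "auto"
        override_generation_config: dict = field(default_factory=dict)

    @dataclass
    class FakeEngineArgs:           # stand-in, not vLLM's EngineArgs
        # Scalar default: safe to reference the class attribute directly.
        model_impl: str = FakeModelConfig.model_impl
        # Mutable default: factory defaults have no class attribute, so copy
        # the field itself to avoid sharing one dict across instances.
        override_generation_config: dict = get_field(
            FakeModelConfig, "override_generation_config")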
@@ -375,80 +384,87 @@ class EngineArgs:
         """Shared CLI arguments for vLLM engine."""

         # Model arguments
-        parser.add_argument(
-            '--model',
-            type=str,
-            default=EngineArgs.model,
-            help='Name or path of the huggingface model to use.')
-        parser.add_argument(
-            '--task',
-            default=EngineArgs.task,
-            choices=get_args(TaskOption),
-            help='The task to use the model for. Each vLLM instance only '
-            'supports one task, even if the same model can be used for '
-            'multiple tasks. When the model only supports one task, ``"auto"`` '
-            'can be used to select it; otherwise, you must specify explicitly '
-            'which task to use.')
-        parser.add_argument(
-            '--tokenizer',
-            type=optional_type(str),
-            default=EngineArgs.tokenizer,
-            help='Name or path of the huggingface tokenizer to use. '
-            'If unspecified, model name or path will be used.')
-        parser.add_argument(
-            "--hf-config-path",
-            type=optional_type(str),
-            default=EngineArgs.hf_config_path,
-            help='Name or path of the huggingface config to use. '
-            'If unspecified, model name or path will be used.')
-        parser.add_argument(
-            '--skip-tokenizer-init',
-            action='store_true',
-            help='Skip initialization of tokenizer and detokenizer. '
-            'Expects valid prompt_token_ids and None for prompt from '
-            'the input. The generated output will contain token ids.')
-        parser.add_argument(
-            '--revision',
-            type=optional_type(str),
-            default=None,
-            help='The specific model version to use. It can be a branch '
-            'name, a tag name, or a commit id. If unspecified, will use '
-            'the default version.')
-        parser.add_argument(
-            '--code-revision',
-            type=optional_type(str),
-            default=None,
-            help='The specific revision to use for the model code on '
-            'Hugging Face Hub. It can be a branch name, a tag name, or a '
-            'commit id. If unspecified, will use the default version.')
-        parser.add_argument(
-            '--tokenizer-revision',
-            type=optional_type(str),
-            default=None,
-            help='Revision of the huggingface tokenizer to use. '
-            'It can be a branch name, a tag name, or a commit id. '
-            'If unspecified, will use the default version.')
-        parser.add_argument(
-            '--tokenizer-mode',
-            type=str,
-            default=EngineArgs.tokenizer_mode,
-            choices=['auto', 'slow', 'mistral', 'custom'],
-            help='The tokenizer mode.\n\n* "auto" will use the '
-            'fast tokenizer if available.\n* "slow" will '
-            'always use the slow tokenizer. \n* '
-            '"mistral" will always use the `mistral_common` tokenizer. \n* '
-            '"custom" will use --tokenizer to select the '
-            'preregistered tokenizer.')
-        parser.add_argument('--trust-remote-code',
-                            action='store_true',
-                            help='Trust remote code from huggingface.')
-        parser.add_argument(
-            '--allowed-local-media-path',
-            type=str,
-            help="Allowing API requests to read local images or videos "
-            "from directories specified by the server file system. "
-            "This is a security risk. "
-            "Should only be enabled in trusted environments.")
+        model_kwargs = get_kwargs(ModelConfig)
+        model_group = parser.add_argument_group(
+            title="ModelConfig",
+            description=ModelConfig.__doc__,
+        )
+        model_group.add_argument("--model", **model_kwargs["model"])
+        model_group.add_argument("--task", **model_kwargs["task"])
+        model_group.add_argument("--tokenizer", **model_kwargs["tokenizer"])
+        model_group.add_argument("--tokenizer-mode",
+                                 **model_kwargs["tokenizer_mode"])
+        model_group.add_argument("--trust-remote-code",
+                                 **model_kwargs["trust_remote_code"])
+        model_group.add_argument("--dtype", **model_kwargs["dtype"])
+        model_group.add_argument("--seed", **model_kwargs["seed"])
+        model_group.add_argument("--hf-config-path",
+                                 **model_kwargs["hf_config_path"])
+        model_group.add_argument("--allowed-local-media-path",
+                                 **model_kwargs["allowed_local_media_path"])
+        model_group.add_argument("--revision", **model_kwargs["revision"])
+        model_group.add_argument("--code-revision",
+                                 **model_kwargs["code_revision"])
+        model_group.add_argument("--rope-scaling",
+                                 **model_kwargs["rope_scaling"])
+        model_group.add_argument("--rope-theta", **model_kwargs["rope_theta"])
+        model_group.add_argument("--tokenizer-revision",
+                                 **model_kwargs["tokenizer_revision"])
+        model_group.add_argument("--max-model-len",
+                                 **model_kwargs["max_model_len"])
+        model_group.add_argument("--quantization", "-q",
+                                 **model_kwargs["quantization"])
+        model_group.add_argument("--enforce-eager",
+                                 **model_kwargs["enforce_eager"])
+        model_group.add_argument("--max-seq-len-to-capture",
+                                 **model_kwargs["max_seq_len_to_capture"])
+        model_group.add_argument("--max-logprobs",
+                                 **model_kwargs["max_logprobs"])
+        model_group.add_argument("--disable-sliding-window",
+                                 **model_kwargs["disable_sliding_window"])
+        model_group.add_argument("--disable-cascade-attn",
+                                 **model_kwargs["disable_cascade_attn"])
+        model_group.add_argument("--skip-tokenizer-init",
+                                 **model_kwargs["skip_tokenizer_init"])
+        model_group.add_argument("--served-model-name",
+                                 **model_kwargs["served_model_name"])
+        # This one is a special case because it is the
+        # opposite of ModelConfig.use_async_output_proc
+        model_group.add_argument(
+            "--disable-async-output-proc",
+            action="store_true",
+            default=EngineArgs.disable_async_output_proc,
+            help="Disable async output processing. This may result in "
+            "lower performance.")
+        model_group.add_argument("--config-format",
+                                 choices=[f.value for f in ConfigFormat],
+                                 **model_kwargs["config_format"])
+        # This one is a special case because it can bool
+        # or str. TODO: Handle this in get_kwargs
+        model_group.add_argument("--hf-token",
+                                 type=str,
+                                 nargs="?",
+                                 const=True,
+                                 default=model_kwargs["hf_token"]["default"],
+                                 help=model_kwargs["hf_token"]["help"])
+        model_group.add_argument("--hf-overrides",
+                                 **model_kwargs["hf_overrides"])
+        model_group.add_argument("--override-neuron-config",
+                                 **model_kwargs["override_neuron_config"])
+        model_group.add_argument("--override-pooler-config",
+                                 **model_kwargs["override_pooler_config"])
+        model_group.add_argument("--logits-processor-pattern",
+                                 **model_kwargs["logits_processor_pattern"])
+        model_group.add_argument("--generation-config",
+                                 **model_kwargs["generation_config"])
+        model_group.add_argument("--override-generation-config",
+                                 **model_kwargs["override_generation_config"])
+        model_group.add_argument("--enable-sleep-mode",
+                                 **model_kwargs["enable_sleep_mode"])
+        model_group.add_argument("--model-impl",
+                                 choices=[f.value for f in ModelImpl],
+                                 **model_kwargs["model_impl"])
+
         # Model loading arguments
         load_kwargs = get_kwargs(LoadConfig)
         load_group = parser.add_argument_group(
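The grouped registration above means a flag's type, default and help text are all derived from the corresponding ModelConfig field instead of being repeated by hand. Roughly, get_kwargs turns a config dataclass into one add_argument keyword-dict per field. The sketch below only shows the shape of that idea with a toy config; it is not the real implementation in vllm.engine.arg_utils, which also handles Literal choices, optional types and JSON-valued fields.

    from argparse import ArgumentParser
    from dataclasses import dataclass, fields

    @dataclass
    class TinyConfig:
        """Toy stand-in for ModelConfig."""
        model: str = "facebook/opt-125m"
        max_model_len: int = 2048

    def get_kwargs_sketch(cls):
        # One **kwargs dict per dataclass field, fed straight to add_argument.
        return {
            f.name: {
                "type": f.type,
                "default": f.default,
                "help": f"{f.name} (default: {f.default!r})",
            }
            for f in fields(cls)
        }

    parser = ArgumentParser()
    group = parser.add_argument_group(title=TinyConfig.__name__,
                                      description=TinyConfig.__doc__)
    kwargs = get_kwargs_sketch(TinyConfig)
    group.add_argument("--model", **kwargs["model"])
    group.add_argument("--max-model-len", **kwargs["max_model_len"])
    print(parser.parse_args(["--max-model-len", "4096"]))
    # Namespace(model='facebook/opt-125m', max_model_len=4096)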
@@ -465,38 +481,6 @@ class EngineArgs:
         load_group.add_argument('--use-tqdm-on-load',
                                 **load_kwargs["use_tqdm_on_load"])

-        parser.add_argument(
-            '--config-format',
-            default=EngineArgs.config_format,
-            choices=[f.value for f in ConfigFormat],
-            help='The format of the model config to load.\n\n'
-            '* "auto" will try to load the config in hf format '
-            'if available else it will try to load in mistral format ')
-        parser.add_argument(
-            '--dtype',
-            type=str,
-            default=EngineArgs.dtype,
-            choices=[
-                'auto', 'half', 'float16', 'bfloat16', 'float', 'float32'
-            ],
-            help='Data type for model weights and activations.\n\n'
-            '* "auto" will use FP16 precision for FP32 and FP16 models, and '
-            'BF16 precision for BF16 models.\n'
-            '* "half" for FP16. Recommended for AWQ quantization.\n'
-            '* "float16" is the same as "half".\n'
-            '* "bfloat16" for a balance between precision and range.\n'
-            '* "float" is shorthand for FP32 precision.\n'
-            '* "float32" for FP32 precision.')
-        parser.add_argument('--max-model-len',
-                            type=human_readable_int,
-                            default=EngineArgs.max_model_len,
-                            help='Model context length. If unspecified, will '
-                            'be automatically derived from the model config. '
-                            'Supports k/m/g/K/M/G in human-readable format.\n'
-                            'Examples:\n'
-                            '- 1k → 1000\n'
-                            '- 1K → 1024\n')
-
         # Guided decoding arguments
         guided_decoding_kwargs = get_kwargs(DecodingConfig)
         guided_decoding_group = parser.add_argument_group(
@@ -520,26 +504,6 @@ class EngineArgs:
             choices=list(ReasoningParserManager.reasoning_parsers),
             **guided_decoding_kwargs["reasoning_backend"])

-        parser.add_argument(
-            '--logits-processor-pattern',
-            type=optional_type(str),
-            default=None,
-            help='Optional regex pattern specifying valid logits processor '
-            'qualified names that can be passed with the `logits_processors` '
-            'extra completion argument. Defaults to None, which allows no '
-            'processors.')
-        parser.add_argument(
-            '--model-impl',
-            type=str,
-            default=EngineArgs.model_impl,
-            choices=[f.value for f in ModelImpl],
-            help='Which implementation of the model to use.\n\n'
-            '* "auto" will try to use the vLLM implementation if it exists '
-            'and fall back to the Transformers implementation if no vLLM '
-            'implementation is available.\n'
-            '* "vllm" will use the vLLM model implementation.\n'
-            '* "transformers" will use the Transformers model '
-            'implementation.\n')
         # Parallel arguments
         parallel_kwargs = get_kwargs(ParallelConfig)
         parallel_group = parser.add_argument_group(
@@ -592,10 +556,6 @@ class EngineArgs:
         cache_group.add_argument('--calculate-kv-scales',
                                  **cache_kwargs["calculate_kv_scales"])

-        parser.add_argument('--disable-sliding-window',
-                            action='store_true',
-                            help='Disables sliding window, '
-                            'capping to sliding window size.')
         parser.add_argument('--use-v2-block-manager',
                             action='store_true',
                             default=True,
@@ -605,73 +565,9 @@ class EngineArgs:
                             'Setting this flag to True or False'
                             ' has no effect on vLLM behavior.')

-        parser.add_argument('--seed',
-                            type=int,
-                            default=EngineArgs.seed,
-                            help='Random seed for operations.')
-        parser.add_argument(
-            '--max-logprobs',
-            type=int,
-            default=EngineArgs.max_logprobs,
-            help=('Max number of log probs to return logprobs is specified in'
-                  ' SamplingParams.'))
         parser.add_argument('--disable-log-stats',
                             action='store_true',
                             help='Disable logging statistics.')
-        # Quantization settings.
-        parser.add_argument('--quantization',
-                            '-q',
-                            type=optional_type(str),
-                            choices=[*QUANTIZATION_METHODS, None],
-                            default=EngineArgs.quantization,
-                            help='Method used to quantize the weights. If '
-                            'None, we first check the `quantization_config` '
-                            'attribute in the model config file. If that is '
-                            'None, we assume the model weights are not '
-                            'quantized and use `dtype` to determine the data '
-                            'type of the weights.')
-        parser.add_argument(
-            '--rope-scaling',
-            default=None,
-            type=json.loads,
-            help='RoPE scaling configuration in JSON format. '
-            'For example, ``{"rope_type":"dynamic","factor":2.0}``')
-        parser.add_argument('--rope-theta',
-                            default=None,
-                            type=float,
-                            help='RoPE theta. Use with `rope_scaling`. In '
-                            'some cases, changing the RoPE theta improves the '
-                            'performance of the scaled model.')
-        parser.add_argument(
-            '--hf-token',
-            type=str,
-            nargs='?',
-            const=True,
-            default=None,
-            help='The token to use as HTTP bearer authorization'
-            ' for remote files. If `True`, will use the token '
-            'generated when running `huggingface-cli login` '
-            '(stored in `~/.huggingface`).')
-        parser.add_argument('--hf-overrides',
-                            type=json.loads,
-                            default=EngineArgs.hf_overrides,
-                            help='Extra arguments for the HuggingFace config. '
-                            'This should be a JSON string that will be '
-                            'parsed into a dictionary.')
-        parser.add_argument('--enforce-eager',
-                            action='store_true',
-                            help='Always use eager-mode PyTorch. If False, '
-                            'will use eager mode and CUDA graph in hybrid '
-                            'for maximal performance and flexibility.')
-        parser.add_argument('--max-seq-len-to-capture',
-                            type=int,
-                            default=EngineArgs.max_seq_len_to_capture,
-                            help='Maximum sequence length covered by CUDA '
-                            'graphs. When a sequence has context length '
-                            'larger than this, we fall back to eager mode. '
-                            'Additionally for encoder-decoder models, if the '
-                            'sequence length of the encoder input is larger '
-                            'than this, we fall back to the eager mode.')

         # Tokenizer arguments
         tokenizer_kwargs = get_kwargs(TokenizerPoolConfig)
@@ -775,20 +671,6 @@ class EngineArgs:
             "Default to `original/**/*` to avoid repeated loading of llama's "
             "checkpoints.")

-        parser.add_argument(
-            "--served-model-name",
-            nargs="+",
-            type=str,
-            default=None,
-            help="The model name(s) used in the API. If multiple "
-            "names are provided, the server will respond to any "
-            "of the provided names. The model name in the model "
-            "field of a response will be the first name in this "
-            "list. If not specified, the model name will be the "
-            "same as the ``--model`` argument. Noted that this name(s) "
-            "will also be used in `model_name` tag content of "
-            "prometheus metrics, if multiple names provided, metrics "
-            "tag will take the first one.")
         parser.add_argument('--qlora-adapter-name-or-path',
                             type=str,
                             default=None,
@@ -822,13 +704,6 @@ class EngineArgs:
             "modules. This involves use of possibly costly and or blocking "
             "operations and hence might have a performance impact.")

-        parser.add_argument(
-            '--disable-async-output-proc',
-            action='store_true',
-            default=EngineArgs.disable_async_output_proc,
-            help="Disable async output processing. This may result in "
-            "lower performance.")
-
         # Scheduler arguments
         scheduler_kwargs = get_kwargs(SchedulerConfig)
         scheduler_group = parser.add_argument_group(
@@ -871,19 +746,6 @@ class EngineArgs:
         parser.add_argument('--scheduler-cls',
                             **scheduler_kwargs["scheduler_cls"])

-        parser.add_argument(
-            '--override-neuron-config',
-            type=json.loads,
-            default=None,
-            help="Override or set neuron device configuration. "
-            "e.g. ``{\"cast_logits_dtype\": \"bloat16\"}``.")
-        parser.add_argument(
-            '--override-pooler-config',
-            type=PoolerConfig.from_json,
-            default=None,
-            help="Override or set the pooling method for pooling models. "
-            "e.g. ``{\"pooling_type\": \"mean\", \"normalize\": false}``.")
-
         parser.add_argument('--compilation-config',
                             '-O',
                             type=CompilationConfig.from_cli,
@@ -920,34 +782,6 @@ class EngineArgs:
             help='The worker extension class on top of the worker cls, '
             'it is useful if you just want to add new functions to the worker '
             'class without changing the existing functions.')
-        parser.add_argument(
-            "--generation-config",
-            type=optional_type(str),
-            default="auto",
-            help="The folder path to the generation config. "
-            "Defaults to 'auto', the generation config will be loaded from "
-            "model path. If set to 'vllm', no generation config is loaded, "
-            "vLLM defaults will be used. If set to a folder path, the "
-            "generation config will be loaded from the specified folder path. "
-            "If `max_new_tokens` is specified in generation config, then "
-            "it sets a server-wide limit on the number of output tokens "
-            "for all requests.")
-
-        parser.add_argument(
-            "--override-generation-config",
-            type=json.loads,
-            default=None,
-            help="Overrides or sets generation config in JSON format. "
-            "e.g. ``{\"temperature\": 0.5}``. If used with "
-            "--generation-config=auto, the override parameters will be merged "
-            "with the default config from the model. If generation-config is "
-            "None, only the override parameters are used.")
-
-        parser.add_argument("--enable-sleep-mode",
-                            action="store_true",
-                            default=False,
-                            help="Enable sleep mode for the engine. "
-                            "(only cuda platform is supported)")

         parser.add_argument(
             "--additional-config",
@@ -966,16 +800,6 @@ class EngineArgs:
             "If enabled, the model will be able to generate reasoning content."
         )

-        parser.add_argument(
-            "--disable-cascade-attn",
-            action="store_true",
-            default=False,
-            help="Disable cascade attention for V1. While cascade attention "
-            "does not change the mathematical correctness, disabling it "
-            "could be useful for preventing potential numerical issues. "
-            "Note that even if this is set to False, cascade attention will be "
-            "only used when the heuristic tells that it's beneficial.")
-
         return parser

     @classmethod
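None of the flags deleted in the hunks above disappear from the CLI; they are re-registered from the corresponding config classes (see the ModelConfig group earlier), so existing invocations should keep parsing. A quick, hedged smoke test of that expectation, with a placeholder model name and values:

    from vllm.engine.arg_utils import EngineArgs
    from vllm.utils import FlexibleArgumentParser

    parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
    args = parser.parse_args([
        "--model", "facebook/opt-125m",
        "--dtype", "bfloat16",
        "--max-model-len", "8192",
        "--max-logprobs", "10",
    ])
    engine_args = EngineArgs.from_cli_args(args)
    print(engine_args.dtype, engine_args.max_model_len)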
@@ -1002,8 +826,7 @@ class EngineArgs:
             model=self.model,
             hf_config_path=self.hf_config_path,
             task=self.task,
-            # We know this is not None because we set it in __post_init__
-            tokenizer=cast(str, self.tokenizer),
+            tokenizer=self.tokenizer,
             tokenizer_mode=self.tokenizer_mode,
             trust_remote_code=self.trust_remote_code,
             allowed_local_media_path=self.allowed_local_media_path,
@@ -13,7 +13,7 @@ from typing_extensions import TypeVar, deprecated

 from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput,
                               BeamSearchSequence, get_beam_search_score)
-from vllm.config import CompilationConfig
+from vllm.config import CompilationConfig, ModelDType, TokenizerMode
 from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig,
                                    TaskOption)
 from vllm.engine.llm_engine import LLMEngine
@@ -32,6 +32,7 @@ from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.guided_decoding.guided_fields import (
     GuidedDecodingRequest, LLMGuidedOptions)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.outputs import (ClassificationRequestOutput, EmbeddingRequestOutput,
                           PoolingRequestOutput, RequestOutput,
                           ScoringRequestOutput)
@@ -163,20 +164,20 @@ class LLM:
         self,
         model: str,
         tokenizer: Optional[str] = None,
-        tokenizer_mode: str = "auto",
+        tokenizer_mode: TokenizerMode = "auto",
         skip_tokenizer_init: bool = False,
         trust_remote_code: bool = False,
         allowed_local_media_path: str = "",
         tensor_parallel_size: int = 1,
-        dtype: str = "auto",
-        quantization: Optional[str] = None,
+        dtype: ModelDType = "auto",
+        quantization: Optional[QuantizationMethods] = None,
         revision: Optional[str] = None,
         tokenizer_revision: Optional[str] = None,
         seed: Optional[int] = None,
         gpu_memory_utilization: float = 0.9,
         swap_space: float = 4,
         cpu_offload_gb: float = 0,
-        enforce_eager: Optional[bool] = None,
+        enforce_eager: bool = False,
         max_seq_len_to_capture: int = 8192,
         disable_custom_all_reduce: bool = False,
         disable_async_output_proc: bool = False,
@@ -189,12 +190,7 @@ class LLM:
         compilation_config: Optional[Union[int, dict[str, Any]]] = None,
         **kwargs,
     ) -> None:
-        '''
-        LLM constructor.
-
-        Note: if enforce_eager is unset (enforce_eager is None)
-        it defaults to False.
-        '''
+        """LLM constructor."""

         if "disable_log_stats" not in kwargs:
             kwargs["disable_log_stats"] = True
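On the LLM side the call pattern is unchanged; what the tightened signature buys is that literal-valued arguments such as dtype and tokenizer_mode are now checked against the ModelDType / TokenizerMode / QuantizationMethods literal types by static type checkers, and enforce_eager has a plain False default instead of None. A small usage sketch, where the model name and values are placeholders only:

    from vllm import LLM

    llm = LLM(
        model="facebook/opt-125m",   # any HF model path works here
        tokenizer_mode="auto",       # checked against TokenizerMode
        dtype="bfloat16",            # checked against ModelDType
        quantization=None,           # or a QuantizationMethods literal such as "awq"
        enforce_eager=False,         # CUDA graphs stay enabled (the new default)
    )
    outputs = llm.generate("Hello, my name is")
    print(outputs[0].outputs[0].text)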
@@ -12,6 +12,7 @@ from torch.nn.parameter import Parameter

 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.utils import set_weight_attrs
@@ -186,7 +187,7 @@ class AQLMConfig(QuantizationConfig):
                 f"out_group_size={self.out_group_size})")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "aqlm"

     @classmethod
@@ -7,6 +7,7 @@ import torch
 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
@@ -44,7 +45,7 @@ class AWQConfig(QuantizationConfig):
                 f"zero_point={self.zero_point}, "
                 f"modules_to_not_convert={self.modules_to_not_convert})")

-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "awq"

     def get_supported_act_dtypes(self) -> List[torch.dtype]:
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.awq import (AWQConfig,
                                                          is_layer_skipped_awq)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -73,7 +74,7 @@ class AWQMarlinConfig(QuantizationConfig):
                 f"modules_to_not_convert={self.modules_to_not_convert})")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "awq_marlin"

     @classmethod
@@ -101,8 +102,8 @@ class AWQMarlinConfig(QuantizationConfig):
                    modules_to_not_convert, config)

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
         is_valid_user_quant = (user_quant is None or user_quant == "marlin"
                                or user_quant == "awq_marlin")
@@ -2,11 +2,16 @@

 import inspect
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

 import torch
 from torch import nn

+if TYPE_CHECKING:
+    from vllm.model_executor.layers.quantization import QuantizationMethods
+else:
+    QuantizationMethods = str
+

 class QuantizeMethodBase(ABC):
     """Base class for different quantized methods."""
@@ -66,7 +71,7 @@ class QuantizationConfig(ABC):
         self.packed_modules_mapping: Dict[str, List[str]] = dict()

     @abstractmethod
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         """Name of the quantization method."""
         raise NotImplementedError

@@ -99,8 +104,8 @@ class QuantizationConfig(ABC):
         raise NotImplementedError

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         """
         Detects if this quantization method can support a given checkpoint
         format by overriding the user specified quantization method --
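The if TYPE_CHECKING / else fallback above is the usual trick for annotating with a type whose import would otherwise be circular: type checkers see the real QuantizationMethods literal, while at runtime the name degrades to plain str so the module keeps importing cleanly. A generic illustration of the same pattern, with placeholder module and type names that are not vLLM's:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by type checkers; never imported at runtime,
        # so no circular-import problem.
        from mypkg.registry import MethodName
    else:
        MethodName = str

    class BaseConfig:
        def get_name(self) -> MethodName:
            raise NotImplementedError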
@@ -5,6 +5,7 @@ import torch

 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.utils.bitblas_utils import (
@@ -100,7 +101,7 @@ class BitBLASConfig(QuantizationConfig):
                 f"quant_method={self.quant_method})")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "bitblas"

     @classmethod
@@ -139,8 +140,8 @@ class BitBLASConfig(QuantizationConfig):
                    lm_head_quantized)

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         # compat: autogptq >=0.8.0 use checkpoint_format: str
         # compat: autogptq <=0.7.1 is_bitblas_format: bool
         is_bitblas_format = (hf_quant_cfg.get("checkpoint_format") == "bitblas"
@@ -7,6 +7,7 @@ import torch
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.utils import direct_register_custom_op
@@ -56,7 +57,7 @@ class BitsAndBytesConfig(QuantizationConfig):
                 f"llm_int8_skip_modules={self.llm_int8_skip_modules})")

     @classmethod
-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "bitsandbytes"

     @classmethod
@@ -16,6 +16,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (  # noqa: E501
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
@@ -71,7 +72,7 @@ class CompressedTensorsConfig(QuantizationConfig):
     def get_min_capability(cls) -> int:
         return 70

-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "compressed-tensors"

     def get_quant_method(
@@ -7,6 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.utils import set_weight_attrs
@@ -41,8 +42,8 @@ class DeepSpeedFPConfig(QuantizationConfig):
                 f"group_size={self.group_size}")

     @classmethod
-    def get_name(cls) -> str:
-        return "DeepSpeedFP"
+    def get_name(cls) -> QuantizationMethods:
+        return "deepspeedfp"

     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "DeepSpeedFPConfig":
@@ -8,6 +8,7 @@ from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
 from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.utils import set_weight_attrs
@@ -20,7 +21,7 @@ class ExpertsInt8Config(QuantizationConfig):
         super().__init__()

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "experts_int8"

     @classmethod
@@ -9,6 +9,7 @@ from torch.nn.parameter import Parameter
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
@@ -38,7 +39,7 @@ class FBGEMMFp8Config(QuantizationConfig):
         self.fp8_linear = Fp8LinearOp()

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "fbgemm_fp8"

     @classmethod
@@ -16,6 +16,7 @@ from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
                                                   FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -83,7 +84,7 @@ class Fp8Config(QuantizationConfig):
         self.weight_block_size = weight_block_size

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "fp8"

     @classmethod
@@ -13,6 +13,7 @@ from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -31,7 +32,7 @@ class GGUFConfig(QuantizationConfig):
     def __repr__(self) -> str:
         return ("GGUFConfig()")

-    def get_name(self) -> str:
+    def get_name(self) -> QuantizationMethods:
         return "gguf"

     def get_supported_act_dtypes(self) -> List[torch.dtype]:
@@ -10,6 +10,7 @@ from torch.nn.parameter import Parameter

 from vllm import _custom_ops as ops
 from vllm.model_executor.layers.linear import LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.utils.gptq_utils import (
@@ -79,7 +80,7 @@ class GPTQConfig(QuantizationConfig):
                 f"dynamic={self.dynamic}")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "gptq"

     @classmethod
@@ -7,6 +7,7 @@ from torch.nn.parameter import Parameter
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
@@ -123,7 +124,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
                 f"quant_method={self.quant_method})")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "gptq_bitblas"

     @classmethod
@@ -151,8 +152,8 @@ class GPTQBitBLASConfig(QuantizationConfig):
                    lm_head_quantized)

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         can_convert = cls.is_gptq_bitblas_compatible(hf_quant_cfg)

         is_valid_user_quant = (user_quant is None or user_quant == "bitblas"
@@ -11,6 +11,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kernels.mixed_precision import (
@@ -100,7 +101,7 @@ class GPTQMarlinConfig(QuantizationConfig):
                 f"dynamic={self.dynamic}")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "gptq_marlin"

     @classmethod
@@ -130,8 +131,8 @@ class GPTQMarlinConfig(QuantizationConfig):
                    lm_head_quantized, dynamic, config)

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         can_convert = cls.is_gptq_marlin_compatible(hf_quant_cfg)

         is_valid_user_quant = (user_quant is None or user_quant == "marlin"
@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.parameter import (BasevLLMParameter,
@@ -85,7 +86,7 @@ class GPTQMarlin24Config(QuantizationConfig):
                 self.quant_type, self.group_size)

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "gptq_marlin_24"

     @classmethod
@@ -108,8 +109,8 @@ class GPTQMarlin24Config(QuantizationConfig):
         return cls(weight_bits, group_size)

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         is_marlin_24_format = (
             hf_quant_cfg.get("checkpoint_format") == "marlin_24")

@@ -8,6 +8,7 @@ from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils.marlin_utils import (
@@ -50,7 +51,7 @@ class HQQMarlinConfig(QuantizationConfig):
                 f"group_size={self.group_size})")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "hqq"

     @classmethod
@@ -6,6 +6,7 @@ import torch

 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.awq import (AWQLinearMethod,
                                                          is_layer_skipped_awq)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -58,7 +59,7 @@ class IPEXConfig(QuantizationConfig):
                 f"group_size={self.group_size})")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "ipex"

     @classmethod
@@ -97,8 +98,8 @@ class IPEXConfig(QuantizationConfig):
                    lm_head_quantized)

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         if not current_platform.is_cpu() and not current_platform.is_xpu():
             return None

@@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
@@ -63,7 +64,7 @@ class MarlinConfig(QuantizationConfig):
                 f"lm_head_quantized={self.lm_head_quantized})")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "marlin"

     @classmethod
@@ -87,8 +88,8 @@ class MarlinConfig(QuantizationConfig):
         return cls(group_size, lm_head_quantized)

     @classmethod
-    def override_quantization_method(cls, hf_quant_cfg,
-                                     user_quant) -> Optional[str]:
+    def override_quantization_method(
+            cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
         # compat: autogptq >=0.8.0 use checkpoint_format: str
         # compat: autogptq <=0.7.1 is_marlin_format: bool
         is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"
@@ -11,6 +11,7 @@ from vllm._custom_ops import (cutlass_scaled_fp4_mm,
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
+from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -42,7 +43,7 @@ class ModelOptFp8Config(QuantizationConfig):
                        " the format is experimental and could change.")

     @classmethod
-    def get_name(cls) -> str:
+    def get_name(cls) -> QuantizationMethods:
         return "modelopt"

     @classmethod
@@ -184,8 +185,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
         self.exclude_modules = exclude_modules

     @classmethod
-    def get_name(cls) -> str:
-        return "modelopt_nvfp4"
+    def get_name(cls) -> QuantizationMethods:
+        return "nvfp4"

     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@ -9,6 +9,7 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
|||||||
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
|
||||||
from vllm.model_executor.layers.linear import (LinearBase,
|
from vllm.model_executor.layers.linear import (LinearBase,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig, QuantizeMethodBase)
|
QuantizationConfig, QuantizeMethodBase)
|
||||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||||
@ -64,7 +65,7 @@ class MoeWNA16Config(QuantizationConfig):
|
|||||||
self.modules_to_not_convert = modules_to_not_convert
|
self.modules_to_not_convert = modules_to_not_convert
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_name(cls) -> str:
|
def get_name(cls) -> QuantizationMethods:
|
||||||
return "moe_wna16"
|
return "moe_wna16"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -100,8 +101,8 @@ class MoeWNA16Config(QuantizationConfig):
|
|||||||
lm_head_quantized, modules_to_not_convert, config)
|
lm_head_quantized, modules_to_not_convert, config)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def override_quantization_method(cls, hf_quant_cfg,
|
def override_quantization_method(
|
||||||
user_quant) -> Optional[str]:
|
cls, hf_quant_cfg, user_quant) -> Optional[QuantizationMethods]:
|
||||||
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
|
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
|
||||||
if can_convert and user_quant == "moe_wna16":
|
if can_convert and user_quant == "moe_wna16":
|
||||||
return cls.get_name()
|
return cls.get_name()
|
||||||
|
|||||||
@ -6,6 +6,7 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
from torch.nn import Module
|
from torch.nn import Module
|
||||||
|
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
|
|
||||||
@ -30,7 +31,7 @@ class NeuronQuantConfig(QuantizationConfig):
|
|||||||
self.dequant_dtype = dequant_dtype
|
self.dequant_dtype = dequant_dtype
|
||||||
self.quantize_method = quantize_method
|
self.quantize_method = quantize_method
|
||||||
|
|
||||||
def get_name(self) -> str:
|
def get_name(self) -> QuantizationMethods:
|
||||||
return "neuron_quant"
|
return "neuron_quant"
|
||||||
|
|
||||||
def get_supported_act_dtypes(self) -> List[str]:
|
def get_supported_act_dtypes(self) -> List[str]:
|
||||||
|
|||||||
@ -9,6 +9,7 @@ from vllm import _custom_ops as ops
|
|||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.linear import (LinearBase,
|
from vllm.model_executor.layers.linear import (LinearBase,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizeMethodBase)
|
QuantizeMethodBase)
|
||||||
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
|
from vllm.model_executor.layers.quantization.fp8 import (Fp8Config,
|
||||||
@ -50,7 +51,7 @@ class PTPCFp8Config(Fp8Config):
|
|||||||
ignored_layers=ignored_layers)
|
ignored_layers=ignored_layers)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_name(cls) -> str:
|
def get_name(cls) -> QuantizationMethods:
|
||||||
return "ptpc_fp8"
|
return "ptpc_fp8"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -8,6 +8,7 @@ from torch.nn.parameter import Parameter
|
|||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.parameter import (BasevLLMParameter,
|
from vllm.model_executor.parameter import (BasevLLMParameter,
|
||||||
@ -84,7 +85,7 @@ class QQQConfig(QuantizationConfig):
|
|||||||
self.weight_bits, self.group_size)
|
self.weight_bits, self.group_size)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_name(cls) -> str:
|
def get_name(cls) -> QuantizationMethods:
|
||||||
return "qqq"
|
return "qqq"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -8,6 +8,7 @@ import torch
|
|||||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
|
||||||
UnquantizedLinearMethod)
|
UnquantizedLinearMethod)
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
|
from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501
|
||||||
QuantizationConfig, QuantizeMethodBase)
|
QuantizationConfig, QuantizeMethodBase)
|
||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
@ -47,7 +48,7 @@ class QuarkConfig(QuantizationConfig):
|
|||||||
def get_min_capability(cls) -> int:
|
def get_min_capability(cls) -> int:
|
||||||
return 70
|
return 70
|
||||||
|
|
||||||
def get_name(self) -> str:
|
def get_name(self) -> QuantizationMethods:
|
||||||
return "quark"
|
return "quark"
|
||||||
|
|
||||||
def get_quant_method(self, layer: torch.nn.Module,
|
def get_quant_method(self, layer: torch.nn.Module,
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import torch.nn.functional as F
|
|||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
@ -20,7 +21,7 @@ class TorchAOConfig(QuantizationConfig):
|
|||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
return f"TorchAOConfig({self.torchao_config})"
|
return f"TorchAOConfig({self.torchao_config})"
|
||||||
|
|
||||||
def get_name(self) -> str:
|
def get_name(self) -> QuantizationMethods:
|
||||||
return "torchao"
|
return "torchao"
|
||||||
|
|
||||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||||
|
|||||||
@ -7,6 +7,7 @@ from torch.nn import Module
|
|||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
|
|
||||||
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationMethods
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
from vllm.model_executor.parameter import ModelWeightParameter
|
from vllm.model_executor.parameter import ModelWeightParameter
|
||||||
@ -27,7 +28,7 @@ class Int8TpuConfig(QuantizationConfig):
|
|||||||
f"Unsupported activation scheme {activation_scheme}")
|
f"Unsupported activation scheme {activation_scheme}")
|
||||||
self.activation_scheme = activation_scheme
|
self.activation_scheme = activation_scheme
|
||||||
|
|
||||||
def get_name(self) -> str:
|
def get_name(self) -> QuantizationMethods:
|
||||||
return "tpu_int8"
|
return "tpu_int8"
|
||||||
|
|
||||||
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
def get_supported_act_dtypes(self) -> List[torch.dtype]:
|
||||||
|
|||||||
@ -1496,7 +1496,7 @@ def get_rope(
|
|||||||
if key in _ROPE_DICT:
|
if key in _ROPE_DICT:
|
||||||
return _ROPE_DICT[key]
|
return _ROPE_DICT[key]
|
||||||
|
|
||||||
if rope_scaling is None:
|
if not rope_scaling:
|
||||||
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base,
|
||||||
is_neox_style, dtype)
|
is_neox_style, dtype)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -180,7 +180,6 @@ def _get_neuron_config_after_override(default_neuron_config,
|
|||||||
NeuronConfig, QuantizationConfig,
|
NeuronConfig, QuantizationConfig,
|
||||||
SparseAttnConfig)
|
SparseAttnConfig)
|
||||||
|
|
||||||
overridden_neuron_config = overridden_neuron_config or {}
|
|
||||||
sparse_attn = overridden_neuron_config.pop("sparse_attn", {})
|
sparse_attn = overridden_neuron_config.pop("sparse_attn", {})
|
||||||
if sparse_attn:
|
if sparse_attn:
|
||||||
overridden_neuron_config["sparse_attn"] = SparseAttnConfig(
|
overridden_neuron_config["sparse_attn"] = SparseAttnConfig(
|
||||||
|
|||||||
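Every hunk above makes the same annotation change: get_name() and override_quantization_method() now return QuantizationMethods rather than a bare str, with the matching import added to each backend. Assuming QuantizationMethods is a Literal alias over the registered backend names (as the new import suggests), the narrower annotation lets a type checker reject a config that reports an unregistered method name. Below is a minimal, self-contained sketch of that pattern; the alias members and the class are illustrative stand-ins, not vLLM's actual definitions.

from typing import Literal, Optional, get_args

# Illustrative stand-in for vLLM's QuantizationMethods Literal alias.
QuantizationMethods = Literal["marlin", "moe_wna16", "tpu_int8"]


class MarlinLikeConfig:
    """Toy config showing the narrowed return annotations."""

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        # A type checker flags any return value outside the Literal members.
        return "marlin"

    @classmethod
    def override_quantization_method(
            cls, hf_quant_cfg: dict,
            user_quant: Optional[str]) -> Optional[QuantizationMethods]:
        # Return a registered method name when the checkpoint can be handled.
        if hf_quant_cfg.get("checkpoint_format") == "marlin":
            return cls.get_name()
        return None


if __name__ == "__main__":
    assert MarlinLikeConfig.get_name() in get_args(QuantizationMethods)
    print(MarlinLikeConfig.override_quantization_method(
        {"checkpoint_format": "marlin"}, user_quant=None))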