mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-17 03:55:01 +08:00
Explicitly explain quant method override ordering and ensure all overrides are ordered (#17256)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
b6dd32aa07
commit
c7941cca18
@ -28,6 +28,7 @@ import vllm.envs as envs
|
|||||||
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
|
from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS,
|
||||||
|
QuantizationMethods,
|
||||||
get_quantization_config)
|
get_quantization_config)
|
||||||
from vllm.model_executor.models import ModelRegistry
|
from vllm.model_executor.models import ModelRegistry
|
||||||
from vllm.platforms import CpuArchEnum, current_platform
|
from vllm.platforms import CpuArchEnum, current_platform
|
||||||
@ -767,12 +768,43 @@ class ModelConfig:
|
|||||||
"compressed-tensors")
|
"compressed-tensors")
|
||||||
quant_cfg["quant_method"] = quant_method
|
quant_cfg["quant_method"] = quant_method
|
||||||
|
|
||||||
|
# Quantization methods which are overrides (i.e. they have an
|
||||||
|
# `override_quantization_method` method) must be checked in order
|
||||||
|
# of preference (this is particularly important for GPTQ).
|
||||||
|
overrides = [
|
||||||
|
"marlin",
|
||||||
|
"bitblas",
|
||||||
|
"gptq_marlin_24",
|
||||||
|
"gptq_marlin",
|
||||||
|
"gptq_bitblas",
|
||||||
|
"awq_marlin",
|
||||||
|
"ipex",
|
||||||
|
"moe_wna16",
|
||||||
|
]
|
||||||
|
quantization_methods = [
|
||||||
|
q for q in supported_quantization if q not in overrides
|
||||||
|
]
|
||||||
|
# Any custom overrides will be in quantization_methods, so we place
|
||||||
|
# them at the start of the list to give custom overrides preference
|
||||||
|
# over the built-in ones.
|
||||||
|
quantization_methods = quantization_methods + overrides
|
||||||
|
|
||||||
# Detect which checkpoint it is
|
# Detect which checkpoint it is
|
||||||
for name in QUANTIZATION_METHODS:
|
for name in quantization_methods:
|
||||||
method = get_quantization_config(name)
|
method = get_quantization_config(name)
|
||||||
quantization_override = method.override_quantization_method(
|
quantization_override = method.override_quantization_method(
|
||||||
quant_cfg, self.quantization)
|
quant_cfg, self.quantization)
|
||||||
if quantization_override:
|
if quantization_override is not None:
|
||||||
|
# Raise error if the override is not custom (custom would
|
||||||
|
# be in QUANTIZATION_METHODS but not QuantizationMethods)
|
||||||
|
# and hasn't been added to the overrides list.
|
||||||
|
if (name in get_args(QuantizationMethods)
|
||||||
|
and name not in overrides):
|
||||||
|
raise ValueError(
|
||||||
|
f"Quantization method {name} is an override but "
|
||||||
|
"is has not been added to the `overrides` list "
|
||||||
|
"above. This is necessary to ensure that the "
|
||||||
|
"overrides are checked in order of preference.")
|
||||||
quant_method = quantization_override
|
quant_method = quantization_override
|
||||||
self.quantization = quantization_override
|
self.quantization = quantization_override
|
||||||
break
|
break
|
||||||
|
|||||||
@ -1,11 +1,11 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
|
||||||
from typing import Dict, List, Type
|
from typing import Literal, Type, get_args
|
||||||
|
|
||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig)
|
QuantizationConfig)
|
||||||
|
|
||||||
QUANTIZATION_METHODS: List[str] = [
|
QuantizationMethods = Literal[
|
||||||
"aqlm",
|
"aqlm",
|
||||||
"awq",
|
"awq",
|
||||||
"deepspeedfp",
|
"deepspeedfp",
|
||||||
@ -15,8 +15,6 @@ QUANTIZATION_METHODS: List[str] = [
|
|||||||
"fbgemm_fp8",
|
"fbgemm_fp8",
|
||||||
"modelopt",
|
"modelopt",
|
||||||
"nvfp4",
|
"nvfp4",
|
||||||
# The order of gptq methods is important for config.py iteration over
|
|
||||||
# override_quantization_method(..)
|
|
||||||
"marlin",
|
"marlin",
|
||||||
"bitblas",
|
"bitblas",
|
||||||
"gguf",
|
"gguf",
|
||||||
@ -36,6 +34,7 @@ QUANTIZATION_METHODS: List[str] = [
|
|||||||
"moe_wna16",
|
"moe_wna16",
|
||||||
"torchao",
|
"torchao",
|
||||||
]
|
]
|
||||||
|
QUANTIZATION_METHODS: list[str] = list(get_args(QuantizationMethods))
|
||||||
|
|
||||||
# The customized quantization methods which will be added to this dict.
|
# The customized quantization methods which will be added to this dict.
|
||||||
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}
|
_CUSTOMIZED_METHOD_TO_QUANT_CONFIG = {}
|
||||||
@ -111,7 +110,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
|||||||
from .torchao import TorchAOConfig
|
from .torchao import TorchAOConfig
|
||||||
from .tpu_int8 import Int8TpuConfig
|
from .tpu_int8 import Int8TpuConfig
|
||||||
|
|
||||||
method_to_config: Dict[str, Type[QuantizationConfig]] = {
|
method_to_config: dict[str, Type[QuantizationConfig]] = {
|
||||||
"aqlm": AQLMConfig,
|
"aqlm": AQLMConfig,
|
||||||
"awq": AWQConfig,
|
"awq": AWQConfig,
|
||||||
"deepspeedfp": DeepSpeedFPConfig,
|
"deepspeedfp": DeepSpeedFPConfig,
|
||||||
@ -120,8 +119,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
|||||||
"fbgemm_fp8": FBGEMMFp8Config,
|
"fbgemm_fp8": FBGEMMFp8Config,
|
||||||
"modelopt": ModelOptFp8Config,
|
"modelopt": ModelOptFp8Config,
|
||||||
"nvfp4": ModelOptNvFp4Config,
|
"nvfp4": ModelOptNvFp4Config,
|
||||||
# The order of gptq methods is important for config.py iteration over
|
|
||||||
# override_quantization_method(..)
|
|
||||||
"marlin": MarlinConfig,
|
"marlin": MarlinConfig,
|
||||||
"bitblas": BitBLASConfig,
|
"bitblas": BitBLASConfig,
|
||||||
"gguf": GGUFConfig,
|
"gguf": GGUFConfig,
|
||||||
@ -150,6 +147,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"QuantizationConfig",
|
"QuantizationConfig",
|
||||||
|
"QuantizationMethods",
|
||||||
"get_quantization_config",
|
"get_quantization_config",
|
||||||
"QUANTIZATION_METHODS",
|
"QUANTIZATION_METHODS",
|
||||||
]
|
]
|
||||||
Loading…
x
Reference in New Issue
Block a user