diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index 551a4e7cebc5..0cf8b69f9f6b 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -1,9 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Any, Union
+from typing import TYPE_CHECKING, Any, Union
 
 import torch
+from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
@@ -13,12 +14,17 @@ from vllm.model_executor.layers.linear import (
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
-from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
+from vllm.transformers_utils.config import get_safetensors_params_metadata
+
+if TYPE_CHECKING:
+    from vllm.model_executor.layers.quantization import QuantizationMethods
+    from vllm.model_executor.models.utils import WeightsMapper
 
 logger = init_logger(__name__)
 
@@ -57,7 +63,7 @@ class AWQConfig(QuantizationConfig):
             f"modules_to_not_convert={self.modules_to_not_convert})"
         )
 
-    def get_name(self) -> QuantizationMethods:
+    def get_name(self) -> "QuantizationMethods":
         return "awq"
 
     def get_supported_act_dtypes(self) -> list[torch.dtype]:
@@ -90,7 +96,12 @@ class AWQConfig(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Union["LinearMethodBase", "QuantizeMethodBase"] | None:
         if isinstance(layer, LinearBase):
-            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+            if is_layer_skipped(
+                prefix,
+                self.modules_to_not_convert,
+                self.packed_modules_mapping,
+                skip_with_substr=True,
+            ):
                 return UnquantizedLinearMethod()
             return AWQLinearMethod(self)
         elif isinstance(layer, FusedMoE):
@@ -128,9 +139,26 @@ class AWQConfig(QuantizationConfig):
             return AWQMoEMethod(awq_marlin_config, layer.moe_config)
         return None
 
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        if self.modules_to_not_convert:
+            self.modules_to_not_convert = hf_to_vllm_mapper.apply_list(
+                self.modules_to_not_convert
+            )
 
-def is_layer_skipped_awq(prefix: str, modules_to_not_convert: list[str]):
-    return any(module_name in prefix for module_name in modules_to_not_convert)
+    def maybe_update_config(self, model_name: str, revision: str | None = None):
+        if self.modules_to_not_convert:
+            return
+
+        unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+        metadata = get_safetensors_params_metadata(model_name, revision=revision)
+        layers = {param_name.rsplit(".", 1)[0] for param_name in metadata}
+        quant_layers: set[str] = {
+            param_name.rsplit(".", 1)[0]
+            for param_name, info in metadata.items()
+            if (dtype := info.get("dtype", None))
+            and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
+        }
+        self.modules_to_not_convert = list(layers - quant_layers)
 
 
 class AWQLinearMethod(LinearMethodBase):
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index d96c657e0119..1b4e2cb87d1a 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -2,9 +2,10 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from collections.abc import Callable
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import torch
+from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
 from torch.nn import Parameter
 
 import vllm.model_executor.layers.fused_moe  # noqa
@@ -27,8 +28,7 @@ from vllm.model_executor.layers.linear import (
     UnquantizedLinearMethod,
     set_weight_attrs,
 )
-from vllm.model_executor.layers.quantization import QuantizationMethods
-from vllm.model_executor.layers.quantization.awq import AWQConfig, is_layer_skipped_awq
+from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -49,10 +49,16 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     verify_marlin_supported,
     verify_marlin_supports_shape,
 )
+from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
+from vllm.transformers_utils.config import get_safetensors_params_metadata
+
+if TYPE_CHECKING:
+    from vllm.model_executor.layers.quantization import QuantizationMethods
+    from vllm.model_executor.models.utils import WeightsMapper
 
 logger = init_logger(__name__)
 
@@ -106,7 +112,7 @@ class AWQMarlinConfig(QuantizationConfig):
         )
 
     @classmethod
-    def get_name(cls) -> QuantizationMethods:
+    def get_name(cls) -> "QuantizationMethods":
         return "awq_marlin"
 
     @classmethod
@@ -142,7 +148,7 @@ class AWQMarlinConfig(QuantizationConfig):
     @classmethod
     def override_quantization_method(
         cls, hf_quant_cfg, user_quant
-    ) -> QuantizationMethods | None:
+    ) -> Optional["QuantizationMethods"]:
         can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
         is_valid_user_quant = (
             user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
@@ -171,7 +177,9 @@ class AWQMarlinConfig(QuantizationConfig):
         if isinstance(layer, LinearBase) or (
            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
         ):
-            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+            if is_layer_skipped(
+                prefix, self.modules_to_not_convert, self.packed_modules_mapping
+            ):
                 return UnquantizedLinearMethod()
             # Check if the layer is supported by AWQMarlin.
             if not check_marlin_supports_layer(layer, self.group_size):
@@ -186,9 +194,7 @@ class AWQMarlinConfig(QuantizationConfig):
         elif isinstance(layer, FusedMoE):
             from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
 
-            if is_layer_skipped_awq(
-                prefix, getattr(self, "modules_to_not_convert", [])
-            ):
+            if is_layer_skipped(prefix, getattr(self, "modules_to_not_convert", [])):
                 return UnquantizedFusedMoEMethod(layer.moe_config)
             if not check_moe_marlin_supports_layer(layer, self.group_size):
                 logger.warning_once(
@@ -226,6 +232,27 @@ class AWQMarlinConfig(QuantizationConfig):
             quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point
         )
 
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        if self.modules_to_not_convert:
+            self.modules_to_not_convert = hf_to_vllm_mapper.apply_list(
+                self.modules_to_not_convert
+            )
+
+    def maybe_update_config(self, model_name: str, revision: str | None = None):
+        if self.modules_to_not_convert:
+            return
+
+        unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+        metadata = get_safetensors_params_metadata(model_name, revision=revision)
+        layers = {param_name.rsplit(".", 1)[0] for param_name in metadata}
+        quant_layers: set[str] = {
+            param_name.rsplit(".", 1)[0]
+            for param_name, info in metadata.items()
+            if (dtype := info.get("dtype", None))
+            and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
+        }
+        self.modules_to_not_convert = list(layers - quant_layers)
+
 
 class AWQMarlinLinearMethod(LinearMethodBase):
     """Linear method for AWQ Marlin.
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index a3cd68948bc8..b7bc3abeb724 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -32,6 +32,7 @@ from vllm.utils.collection_utils import is_list_of
 
 if TYPE_CHECKING:
     from vllm.model_executor.layers.quantization import QuantizationMethods
+    from vllm.model_executor.models.utils import WeightsMapper
 else:
     QuantizationMethods = str
 
@@ -164,7 +165,7 @@ class GPTQConfig(QuantizationConfig):
 
         return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
 
-    def apply_vllm_mapper(self, hf_to_vllm_mapper):
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
         if self.modules_in_block_to_quantize is not None:
             self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
                 self.modules_in_block_to_quantize
diff --git a/vllm/model_executor/layers/quantization/ipex_quant.py b/vllm/model_executor/layers/quantization/ipex_quant.py
index 8616e8f4516a..5b3aabfde0c1 100644
--- a/vllm/model_executor/layers/quantization/ipex_quant.py
+++ b/vllm/model_executor/layers/quantization/ipex_quant.py
@@ -24,12 +24,10 @@ from vllm.model_executor.layers.quantization import (
     QuantizationConfig,
     QuantizationMethods,
 )
-from vllm.model_executor.layers.quantization.awq import (
-    AWQLinearMethod,
-    is_layer_skipped_awq,
-)
+from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8LinearMethod
 from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform
 
@@ -139,7 +137,9 @@ class IPEXConfig(QuantizationConfig):
     ) -> Optional["LinearMethodBase"]:
         if isinstance(layer, LinearBase):
             if self.method == "awq":
-                if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                if is_layer_skipped(
+                    prefix, self.modules_to_not_convert, self.packed_modules_mapping
+                ):
                     return UnquantizedLinearMethod()
                 return IPEXAWQLinearMethod(self)
             if self.method == "gptq":
diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py
index c2ecf4c02828..d056d3404385 100644
--- a/vllm/model_executor/layers/quantization/utils/quant_utils.py
+++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py
@@ -285,7 +285,18 @@ def is_layer_skipped(
     prefix: str,
     ignored_layers: list[str],
     fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
+    *,
+    skip_with_substr: bool = False,
 ) -> bool:
+    def prefix_full_match(prefix: str, ignored_layers: list[str]) -> bool:
+        return prefix in ignored_layers
+
+    # For case like: ignored_layers = ["self_attn"]
+    def substr_match(prefix: str, ignored_layers: list[str]) -> bool:
+        return any(layer in prefix for layer in ignored_layers)
+
+    match_func = substr_match if skip_with_substr else prefix_full_match
+
     # prefix: model.layers.0.self_attn.q_proj
     # proj_name: q_proj
     proj_name = prefix.split(".")[-1]
@@ -302,7 +313,7 @@ def is_layer_skipped(
 
         is_skipped = None
         for shard_prefix in shard_prefixes:
-            is_shard_skipped = shard_prefix in ignored_layers
+            is_shard_skipped = match_func(shard_prefix, ignored_layers)
 
             if is_skipped is None:
                 is_skipped = is_shard_skipped
@@ -312,16 +323,16 @@ def is_layer_skipped(
                     "are quantized. All shards of fused layers "
                     "to have the same precision."
                 )
-    elif "experts" in prefix:
+    elif "experts" in prefix and not skip_with_substr:
+        expert_ignore_layers = filter(
+            lambda layer_name: "experts" in layer_name, ignored_layers
+        )
         return any(
-            [
-                prefix in layer_name
-                for layer_name in ignored_layers
-                if "experts" in layer_name
-            ]
+            prefix in layer_name if not skip_with_substr else layer_name in prefix
+            for layer_name in expert_ignore_layers
         )
     else:
-        is_skipped = prefix in ignored_layers
+        is_skipped = match_func(prefix, ignored_layers)
 
     assert is_skipped is not None
     return is_skipped
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index 147661babca1..09937706f8c5 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -42,8 +42,6 @@ from typing_extensions import TypeVar
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.resampler import (
     BaseResampler,
     Resampler2,
@@ -1514,11 +1512,6 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
         assert self.version == (4, 0)
 
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
-        if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)):
-            return None
-        return quant_config
-
     def init_llm(
         self,
         vllm_config: VllmConfig,
@@ -1532,7 +1525,6 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ) -> nn.Module:
-        quant_config = self._maybe_ignore_quant_config(quant_config)
         model = Idefics2VisionTransformer(
             config.vision_config,
             quant_config=quant_config,
@@ -1550,7 +1542,6 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ) -> nn.Module:
-        quant_config = self._maybe_ignore_quant_config(quant_config)
         with set_default_torch_dtype(torch.float16):
             # The resampler in 4.0 remains consistent with the one in 2.5/2.6.
             resampler = Resampler2_5(
@@ -1619,11 +1610,6 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
         assert self.version == (4, 5)
 
-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
-        if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)):
-            return None
-        return quant_config
-
     def init_llm(
         self,
         vllm_config: VllmConfig,
@@ -1637,7 +1623,6 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ) -> nn.Module:
-        quant_config = self._maybe_ignore_quant_config(quant_config)
         model = Idefics2VisionTransformer(
             config.vision_config,
             quant_config=quant_config,
@@ -1655,7 +1640,6 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ) -> nn.Module:
-        quant_config = self._maybe_ignore_quant_config(quant_config)
         with set_default_torch_dtype(torch.float16):
             # The resampler in 4.0 remains consistent with the one in 2.5/2.6.
             resampler = Resampler4_5(
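
The awq.py hunk passes skip_with_substr=True to is_layer_skipped, so an entry in modules_to_not_convert only has to occur somewhere inside a layer's prefix rather than match it verbatim. A minimal standalone sketch of the two matching modes (the helper bodies mirror the ones added to quant_utils.py; the example prefix and ignore list are made up and do not come from any real checkpoint):

    # Sketch only: mirrors prefix_full_match / substr_match from the patch,
    # without importing vLLM. Module names below are illustrative.
    def prefix_full_match(prefix: str, ignored_layers: list[str]) -> bool:
        # Default behavior: the prefix must appear verbatim in the ignore list.
        return prefix in ignored_layers

    def substr_match(prefix: str, ignored_layers: list[str]) -> bool:
        # skip_with_substr=True: any ignored name contained in the prefix matches.
        return any(layer in prefix for layer in ignored_layers)

    ignored = ["visual"]  # e.g. an AWQ checkpoint that leaves the vision tower unquantized
    prefix = "model.visual.blocks.0.attn.qkv_proj"
    print(prefix_full_match(prefix, ignored))  # False: no exact entry for this prefix
    print(substr_match(prefix, ignored))       # True: "visual" is a substring, layer is skipped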
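
When the checkpoint's quantization config carries no modules_to_not_convert, the new maybe_update_config derives it from safetensors metadata: a layer counts as unquantized only if all of its parameters are stored in float16, bfloat16, or float32. A rough sketch of that grouping logic, with a hand-written metadata dict and dtype table standing in for get_safetensors_params_metadata() and safetensors.torch._TYPES:

    import torch

    # Hand-written stand-in for the per-parameter safetensors metadata; real entries
    # come from get_safetensors_params_metadata(model_name, revision=...).
    _SAFETENSORS_TO_TORCH_DTYPE = {"F16": torch.float16, "BF16": torch.bfloat16, "I32": torch.int32}
    metadata = {
        "model.layers.0.self_attn.q_proj.qweight": {"dtype": "I32"},  # packed AWQ weight
        "model.layers.0.self_attn.q_proj.scales": {"dtype": "F16"},
        "model.visual.blocks.0.attn.qkv.weight": {"dtype": "F16"},    # unquantized vision layer
    }

    unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
    layers = {name.rsplit(".", 1)[0] for name in metadata}
    quant_layers = {
        name.rsplit(".", 1)[0]
        for name, info in metadata.items()
        if (dtype := info.get("dtype", None))
        and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
    }
    print(sorted(layers - quant_layers))  # ['model.visual.blocks.0.attn.qkv']

A layer is marked quantized as soon as any of its parameters uses a non-floating dtype, which is why q_proj stays quantized here even though its scales tensor is stored in fp16.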