[Quantization] Automatically infer AWQ modules_to_not_convert field (#26909)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-10-21 09:49:28 +08:00 committed by GitHub
parent bfe0b4bd2a
commit 352c0c8a28
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 96 additions and 45 deletions

View File

@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Union
from typing import TYPE_CHECKING, Any, Union
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from vllm import _custom_ops as ops
from vllm.logger import init_logger
@ -13,12 +14,17 @@ from vllm.model_executor.layers.linear import (
LinearMethodBase,
UnquantizedLinearMethod,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from vllm.transformers_utils.config import get_safetensors_params_metadata
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.models.utils import WeightsMapper
logger = init_logger(__name__)
@ -57,7 +63,7 @@ class AWQConfig(QuantizationConfig):
f"modules_to_not_convert={self.modules_to_not_convert})"
)
def get_name(self) -> QuantizationMethods:
def get_name(self) -> "QuantizationMethods":
return "awq"
def get_supported_act_dtypes(self) -> list[torch.dtype]:
@ -90,7 +96,12 @@ class AWQConfig(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> Union["LinearMethodBase", "QuantizeMethodBase"] | None:
if isinstance(layer, LinearBase):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
if is_layer_skipped(
prefix,
self.modules_to_not_convert,
self.packed_modules_mapping,
skip_with_substr=True,
):
return UnquantizedLinearMethod()
return AWQLinearMethod(self)
elif isinstance(layer, FusedMoE):
@ -128,9 +139,26 @@ class AWQConfig(QuantizationConfig):
return AWQMoEMethod(awq_marlin_config, layer.moe_config)
return None
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
if self.modules_to_not_convert:
self.modules_to_not_convert = hf_to_vllm_mapper.apply_list(
self.modules_to_not_convert
)
def is_layer_skipped_awq(prefix: str, modules_to_not_convert: list[str]):
return any(module_name in prefix for module_name in modules_to_not_convert)
def maybe_update_config(self, model_name: str, revision: str | None = None):
if self.modules_to_not_convert:
return
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
metadata = get_safetensors_params_metadata(model_name, revision=revision)
layers = {param_name.rsplit(".", 1)[0] for param_name in metadata}
quant_layers: set[str] = {
param_name.rsplit(".", 1)[0]
for param_name, info in metadata.items()
if (dtype := info.get("dtype", None))
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
}
self.modules_to_not_convert = list(layers - quant_layers)
class AWQLinearMethod(LinearMethodBase):

View File

@ -2,9 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Callable
from typing import Any, Optional
from typing import TYPE_CHECKING, Any, Optional
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.nn import Parameter
import vllm.model_executor.layers.fused_moe # noqa
@ -27,8 +28,7 @@ from vllm.model_executor.layers.linear import (
UnquantizedLinearMethod,
set_weight_attrs,
)
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.awq import AWQConfig, is_layer_skipped_awq
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
@ -49,10 +49,16 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
verify_marlin_supported,
verify_marlin_supports_shape,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.transformers_utils.config import get_safetensors_params_metadata
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.models.utils import WeightsMapper
logger = init_logger(__name__)
@ -106,7 +112,7 @@ class AWQMarlinConfig(QuantizationConfig):
)
@classmethod
def get_name(cls) -> QuantizationMethods:
def get_name(cls) -> "QuantizationMethods":
return "awq_marlin"
@classmethod
@ -142,7 +148,7 @@ class AWQMarlinConfig(QuantizationConfig):
@classmethod
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> QuantizationMethods | None:
) -> Optional["QuantizationMethods"]:
can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (
user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
@ -171,7 +177,9 @@ class AWQMarlinConfig(QuantizationConfig):
if isinstance(layer, LinearBase) or (
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
if is_layer_skipped(
prefix, self.modules_to_not_convert, self.packed_modules_mapping
):
return UnquantizedLinearMethod()
# Check if the layer is supported by AWQMarlin.
if not check_marlin_supports_layer(layer, self.group_size):
@ -186,9 +194,7 @@ class AWQMarlinConfig(QuantizationConfig):
elif isinstance(layer, FusedMoE):
from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config
if is_layer_skipped_awq(
prefix, getattr(self, "modules_to_not_convert", [])
):
if is_layer_skipped(prefix, getattr(self, "modules_to_not_convert", [])):
return UnquantizedFusedMoEMethod(layer.moe_config)
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
@ -226,6 +232,27 @@ class AWQMarlinConfig(QuantizationConfig):
quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point
)
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
if self.modules_to_not_convert:
self.modules_to_not_convert = hf_to_vllm_mapper.apply_list(
self.modules_to_not_convert
)
def maybe_update_config(self, model_name: str, revision: str | None = None):
if self.modules_to_not_convert:
return
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
metadata = get_safetensors_params_metadata(model_name, revision=revision)
layers = {param_name.rsplit(".", 1)[0] for param_name in metadata}
quant_layers: set[str] = {
param_name.rsplit(".", 1)[0]
for param_name, info in metadata.items()
if (dtype := info.get("dtype", None))
and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
}
self.modules_to_not_convert = list(layers - quant_layers)
class AWQMarlinLinearMethod(LinearMethodBase):
"""Linear method for AWQ Marlin.

View File

@ -32,6 +32,7 @@ from vllm.utils.collection_utils import is_list_of
if TYPE_CHECKING:
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.models.utils import WeightsMapper
else:
QuantizationMethods = str
@ -164,7 +165,7 @@ class GPTQConfig(QuantizationConfig):
return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
def apply_vllm_mapper(self, hf_to_vllm_mapper):
def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
if self.modules_in_block_to_quantize is not None:
self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
self.modules_in_block_to_quantize

View File

@ -24,12 +24,10 @@ from vllm.model_executor.layers.quantization import (
QuantizationConfig,
QuantizationMethods,
)
from vllm.model_executor.layers.quantization.awq import (
AWQLinearMethod,
is_layer_skipped_awq,
)
from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8LinearMethod
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
@ -139,7 +137,9 @@ class IPEXConfig(QuantizationConfig):
) -> Optional["LinearMethodBase"]:
if isinstance(layer, LinearBase):
if self.method == "awq":
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
if is_layer_skipped(
prefix, self.modules_to_not_convert, self.packed_modules_mapping
):
return UnquantizedLinearMethod()
return IPEXAWQLinearMethod(self)
if self.method == "gptq":

View File

@ -285,7 +285,18 @@ def is_layer_skipped(
prefix: str,
ignored_layers: list[str],
fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
*,
skip_with_substr: bool = False,
) -> bool:
def prefix_full_match(prefix: str, ignored_layers: list[str]) -> bool:
return prefix in ignored_layers
# For case like: ignored_layers = ["self_attn"]
def substr_match(prefix: str, ignored_layers: list[str]) -> bool:
return any(layer in prefix for layer in ignored_layers)
match_func = substr_match if skip_with_substr else prefix_full_match
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
proj_name = prefix.split(".")[-1]
@ -302,7 +313,7 @@ def is_layer_skipped(
is_skipped = None
for shard_prefix in shard_prefixes:
is_shard_skipped = shard_prefix in ignored_layers
is_shard_skipped = match_func(shard_prefix, ignored_layers)
if is_skipped is None:
is_skipped = is_shard_skipped
@ -312,16 +323,16 @@ def is_layer_skipped(
"are quantized. All shards of fused layers "
"to have the same precision."
)
elif "experts" in prefix:
elif "experts" in prefix and not skip_with_substr:
expert_ignore_layers = filter(
lambda layer_name: "experts" in layer_name, ignored_layers
)
return any(
[
prefix in layer_name
for layer_name in ignored_layers
if "experts" in layer_name
]
prefix in layer_name if not skip_with_substr else layer_name in prefix
for layer_name in expert_ignore_layers
)
else:
is_skipped = prefix in ignored_layers
is_skipped = match_func(prefix, ignored_layers)
assert is_skipped is not None
return is_skipped

View File

@ -42,8 +42,6 @@ from typing_extensions import TypeVar
from vllm.config import VllmConfig
from vllm.config.multimodal import BaseDummyOptions
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
from vllm.model_executor.layers.resampler import (
BaseResampler,
Resampler2,
@ -1514,11 +1512,6 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
super().__init__(vllm_config=vllm_config, prefix=prefix)
assert self.version == (4, 0)
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)):
return None
return quant_config
def init_llm(
self,
vllm_config: VllmConfig,
@ -1532,7 +1525,6 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> nn.Module:
quant_config = self._maybe_ignore_quant_config(quant_config)
model = Idefics2VisionTransformer(
config.vision_config,
quant_config=quant_config,
@ -1550,7 +1542,6 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> nn.Module:
quant_config = self._maybe_ignore_quant_config(quant_config)
with set_default_torch_dtype(torch.float16):
# The resampler in 4.0 remains consistent with the one in 2.5/2.6.
resampler = Resampler2_5(
@ -1619,11 +1610,6 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
super().__init__(vllm_config=vllm_config, prefix=prefix)
assert self.version == (4, 5)
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)):
return None
return quant_config
def init_llm(
self,
vllm_config: VllmConfig,
@ -1637,7 +1623,6 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> nn.Module:
quant_config = self._maybe_ignore_quant_config(quant_config)
model = Idefics2VisionTransformer(
config.vision_config,
quant_config=quant_config,
@ -1655,7 +1640,6 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
quant_config: QuantizationConfig | None = None,
prefix: str = "",
) -> nn.Module:
quant_config = self._maybe_ignore_quant_config(quant_config)
with set_default_torch_dtype(torch.float16):
# The resampler in 4.0 remains consistent with the one in 2.5/2.6.
resampler = Resampler4_5(