Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-10 09:06:03 +08:00)
[Quantization] Automatically infer AWQ modules_to_not_convert field (#26909)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent bfe0b4bd2a
commit 352c0c8a28
@@ -1,9 +1,10 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-from typing import Any, Union
+from typing import TYPE_CHECKING, Any, Union

 import torch
+from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE

 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
@@ -13,12 +14,17 @@ from vllm.model_executor.layers.linear import (
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
-from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
+from vllm.transformers_utils.config import get_safetensors_params_metadata
+
+if TYPE_CHECKING:
+    from vllm.model_executor.layers.quantization import QuantizationMethods
+    from vllm.model_executor.models.utils import WeightsMapper

 logger = init_logger(__name__)
@@ -57,7 +63,7 @@ class AWQConfig(QuantizationConfig):
             f"modules_to_not_convert={self.modules_to_not_convert})"
         )

-    def get_name(self) -> QuantizationMethods:
+    def get_name(self) -> "QuantizationMethods":
        return "awq"

    def get_supported_act_dtypes(self) -> list[torch.dtype]:
@@ -90,7 +96,12 @@ class AWQConfig(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Union["LinearMethodBase", "QuantizeMethodBase"] | None:
         if isinstance(layer, LinearBase):
-            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+            if is_layer_skipped(
+                prefix,
+                self.modules_to_not_convert,
+                self.packed_modules_mapping,
+                skip_with_substr=True,
+            ):
                 return UnquantizedLinearMethod()
             return AWQLinearMethod(self)
         elif isinstance(layer, FusedMoE):
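Passing skip_with_substr=True here preserves the substring semantics of the removed is_layer_skipped_awq helper, which tested `module_name in prefix`, so an entry such as "visual" still excludes every layer under that subtree. A minimal sketch of that behaviour; the layer names below are made up:

```python
# Substring semantics preserved by skip_with_substr=True (equivalent to the
# removed is_layer_skipped_awq helper). Layer names are illustrative only.
modules_to_not_convert = ["visual", "lm_head"]

def skipped_by_substr(prefix: str, ignored: list[str]) -> bool:
    return any(name in prefix for name in ignored)

assert skipped_by_substr("model.visual.blocks.0.attn.qkv", modules_to_not_convert)
assert not skipped_by_substr("model.layers.0.self_attn.qkv_proj", modules_to_not_convert)
```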
@@ -128,9 +139,26 @@ class AWQConfig(QuantizationConfig):
             return AWQMoEMethod(awq_marlin_config, layer.moe_config)
         return None

+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        if self.modules_to_not_convert:
+            self.modules_to_not_convert = hf_to_vllm_mapper.apply_list(
+                self.modules_to_not_convert
+            )
+
+    def maybe_update_config(self, model_name: str, revision: str | None = None):
+        if self.modules_to_not_convert:
+            return
+
+        unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+        metadata = get_safetensors_params_metadata(model_name, revision=revision)
+        layers = {param_name.rsplit(".", 1)[0] for param_name in metadata}
+        quant_layers: set[str] = {
+            param_name.rsplit(".", 1)[0]
+            for param_name, info in metadata.items()
+            if (dtype := info.get("dtype", None))
+            and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
+        }
+        self.modules_to_not_convert = list(layers - quant_layers)
-
-def is_layer_skipped_awq(prefix: str, modules_to_not_convert: list[str]):
-    return any(module_name in prefix for module_name in modules_to_not_convert)


 class AWQLinearMethod(LinearMethodBase):
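The new maybe_update_config infers modules_to_not_convert by set difference: a layer whose parameters are all stored in a float dtype (fp16/bf16/fp32) in the checkpoint's safetensors metadata is treated as not converted, while AWQ-quantized layers carry packed integer tensors. Below is a standalone sketch of that rule on hand-written metadata; the parameter names and dtype strings are illustrative, and in vLLM the dict comes from get_safetensors_params_metadata(model_name, revision=...):

```python
# Standalone sketch of the dtype-based inference in maybe_update_config,
# run on fabricated metadata rather than a real checkpoint.
import torch
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE

metadata = {
    "model.layers.0.self_attn.q_proj.qweight": {"dtype": "I32"},
    "model.layers.0.self_attn.q_proj.scales": {"dtype": "F16"},
    "model.visual.blocks.0.attn.qkv.weight": {"dtype": "F16"},
    "lm_head.weight": {"dtype": "F16"},
}
unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]

layers = {name.rsplit(".", 1)[0] for name in metadata}
quant_layers = {
    name.rsplit(".", 1)[0]
    for name, info in metadata.items()
    if (dtype := info.get("dtype")) is not None
    and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
}
# Layers whose params are all float are inferred as not converted.
print(sorted(layers - quant_layers))
# -> ['lm_head', 'model.visual.blocks.0.attn.qkv']
```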
@@ -2,9 +2,10 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 from collections.abc import Callable
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import torch
+from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
 from torch.nn import Parameter

 import vllm.model_executor.layers.fused_moe  # noqa
@@ -27,8 +28,7 @@ from vllm.model_executor.layers.linear import (
     UnquantizedLinearMethod,
     set_weight_attrs,
 )
-from vllm.model_executor.layers.quantization import QuantizationMethods
-from vllm.model_executor.layers.quantization.awq import AWQConfig, is_layer_skipped_awq
+from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
@@ -49,10 +49,16 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
     verify_marlin_supported,
     verify_marlin_supports_shape,
 )
+from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
 from vllm.model_executor.parameter import GroupQuantScaleParameter, PackedvLLMParameter
 from vllm.platforms import current_platform
 from vllm.scalar_type import scalar_types
+from vllm.transformers_utils.config import get_safetensors_params_metadata
+
+if TYPE_CHECKING:
+    from vllm.model_executor.layers.quantization import QuantizationMethods
+    from vllm.model_executor.models.utils import WeightsMapper

 logger = init_logger(__name__)
@@ -106,7 +112,7 @@ class AWQMarlinConfig(QuantizationConfig):
         )

     @classmethod
-    def get_name(cls) -> QuantizationMethods:
+    def get_name(cls) -> "QuantizationMethods":
         return "awq_marlin"

     @classmethod
@@ -142,7 +148,7 @@ class AWQMarlinConfig(QuantizationConfig):
     @classmethod
     def override_quantization_method(
         cls, hf_quant_cfg, user_quant
-    ) -> QuantizationMethods | None:
+    ) -> Optional["QuantizationMethods"]:
         can_convert = cls.is_awq_marlin_compatible(hf_quant_cfg)
         is_valid_user_quant = (
             user_quant is None or user_quant == "marlin" or user_quant == "awq_marlin"
@@ -171,7 +177,9 @@ class AWQMarlinConfig(QuantizationConfig):
         if isinstance(layer, LinearBase) or (
             isinstance(layer, ParallelLMHead) and self.lm_head_quantized
         ):
-            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+            if is_layer_skipped(
+                prefix, self.modules_to_not_convert, self.packed_modules_mapping
+            ):
                 return UnquantizedLinearMethod()
             # Check if the layer is supported by AWQMarlin.
             if not check_marlin_supports_layer(layer, self.group_size):
@@ -186,9 +194,7 @@ class AWQMarlinConfig(QuantizationConfig):
         elif isinstance(layer, FusedMoE):
             from vllm.model_executor.layers.quantization.moe_wna16 import MoeWNA16Config

-            if is_layer_skipped_awq(
-                prefix, getattr(self, "modules_to_not_convert", [])
-            ):
+            if is_layer_skipped(prefix, getattr(self, "modules_to_not_convert", [])):
                 return UnquantizedFusedMoEMethod(layer.moe_config)
             if not check_moe_marlin_supports_layer(layer, self.group_size):
                 logger.warning_once(
@@ -226,6 +232,27 @@ class AWQMarlinConfig(QuantizationConfig):
             quant_type=cls.TYPE_MAP[num_bits], group_size=group_size, has_zp=zero_point
         )

+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
+        if self.modules_to_not_convert:
+            self.modules_to_not_convert = hf_to_vllm_mapper.apply_list(
+                self.modules_to_not_convert
+            )
+
+    def maybe_update_config(self, model_name: str, revision: str | None = None):
+        if self.modules_to_not_convert:
+            return
+
+        unquant_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+        metadata = get_safetensors_params_metadata(model_name, revision=revision)
+        layers = {param_name.rsplit(".", 1)[0] for param_name in metadata}
+        quant_layers: set[str] = {
+            param_name.rsplit(".", 1)[0]
+            for param_name, info in metadata.items()
+            if (dtype := info.get("dtype", None))
+            and _SAFETENSORS_TO_TORCH_DTYPE[dtype] not in unquant_dtypes
+        }
+        self.modules_to_not_convert = list(layers - quant_layers)
+

 class AWQMarlinLinearMethod(LinearMethodBase):
     """Linear method for AWQ Marlin.
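AWQMarlinConfig gains the same two helpers as AWQConfig. apply_vllm_mapper re-maps the inferred, HF-style module names onto vLLM's internal prefixes via the model's WeightsMapper (imported under TYPE_CHECKING from vllm.model_executor.models.utils). The stand-in class and prefix table below are invented purely to illustrate that renaming; the real WeightsMapper.apply_list may differ in detail:

```python
# Toy stand-in for WeightsMapper.apply_list, only to show why the config
# re-maps modules_to_not_convert: names inferred from the HF checkpoint must
# match vLLM's module prefixes. Class and mapping here are hypothetical.
class ToyWeightsMapper:
    def __init__(self, orig_to_new_prefix: dict[str, str]):
        self.orig_to_new_prefix = orig_to_new_prefix

    def apply_list(self, names: list[str]) -> list[str]:
        out = []
        for name in names:
            for old, new in self.orig_to_new_prefix.items():
                if name.startswith(old):
                    name = new + name[len(old):]
                    break
            out.append(name)
        return out

mapper = ToyWeightsMapper({"model.": "language_model.model."})
print(mapper.apply_list(["model.visual.blocks.0.attn.qkv", "lm_head"]))
# -> ['language_model.model.visual.blocks.0.attn.qkv', 'lm_head']
```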
@@ -32,6 +32,7 @@ from vllm.utils.collection_utils import is_list_of

 if TYPE_CHECKING:
     from vllm.model_executor.layers.quantization import QuantizationMethods
+    from vllm.model_executor.models.utils import WeightsMapper
 else:
     QuantizationMethods = str

@@ -164,7 +165,7 @@ class GPTQConfig(QuantizationConfig):

         return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)

-    def apply_vllm_mapper(self, hf_to_vllm_mapper):
+    def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"):
         if self.modules_in_block_to_quantize is not None:
             self.modules_in_block_to_quantize = hf_to_vllm_mapper.apply_list(
                 self.modules_in_block_to_quantize
@@ -24,12 +24,10 @@ from vllm.model_executor.layers.quantization import (
     QuantizationConfig,
     QuantizationMethods,
 )
-from vllm.model_executor.layers.quantization.awq import (
-    AWQLinearMethod,
-    is_layer_skipped_awq,
-)
+from vllm.model_executor.layers.quantization.awq import AWQLinearMethod
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8LinearMethod
 from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 from vllm.model_executor.utils import set_weight_attrs
 from vllm.platforms import current_platform

@@ -139,7 +137,9 @@ class IPEXConfig(QuantizationConfig):
     ) -> Optional["LinearMethodBase"]:
         if isinstance(layer, LinearBase):
             if self.method == "awq":
-                if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                if is_layer_skipped(
+                    prefix, self.modules_to_not_convert, self.packed_modules_mapping
+                ):
                     return UnquantizedLinearMethod()
                 return IPEXAWQLinearMethod(self)
             if self.method == "gptq":
@@ -285,7 +285,18 @@ def is_layer_skipped(
     prefix: str,
     ignored_layers: list[str],
     fused_mapping: Mapping[str, list[str]] = MappingProxyType({}),
+    *,
+    skip_with_substr: bool = False,
 ) -> bool:
+    def prefix_full_match(prefix: str, ignored_layers: list[str]) -> bool:
+        return prefix in ignored_layers
+
+    # For case like: ignored_layers = ["self_attn"]
+    def substr_match(prefix: str, ignored_layers: list[str]) -> bool:
+        return any(layer in prefix for layer in ignored_layers)
+
+    match_func = substr_match if skip_with_substr else prefix_full_match
+
     # prefix: model.layers.0.self_attn.q_proj
     # proj_name: q_proj
     proj_name = prefix.split(".")[-1]
@@ -302,7 +313,7 @@ def is_layer_skipped(

         is_skipped = None
         for shard_prefix in shard_prefixes:
-            is_shard_skipped = shard_prefix in ignored_layers
+            is_shard_skipped = match_func(shard_prefix, ignored_layers)

             if is_skipped is None:
                 is_skipped = is_shard_skipped
@@ -312,16 +323,16 @@ def is_layer_skipped(
                     "are quantized. All shards of fused layers "
                     "to have the same precision."
                 )
-    elif "experts" in prefix:
+    elif "experts" in prefix and not skip_with_substr:
+        expert_ignore_layers = filter(
+            lambda layer_name: "experts" in layer_name, ignored_layers
+        )
         return any(
-            [
-                prefix in layer_name
-                for layer_name in ignored_layers
-                if "experts" in layer_name
-            ]
+            prefix in layer_name if not skip_with_substr else layer_name in prefix
+            for layer_name in expert_ignore_layers
         )
     else:
-        is_skipped = prefix in ignored_layers
+        is_skipped = match_func(prefix, ignored_layers)

     assert is_skipped is not None
     return is_skipped
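With match_func threaded through, the fused-shard path still works: when the last component of the prefix is a packed module (for example qkv_proj), it is expanded via the fused mapping into its shard names and each shard is checked. A simplified sketch with an invented fused mapping and layer names; the real is_layer_skipped also handles the "experts" branch and raises when only some shards of a fused layer are ignored:

```python
# Simplified sketch of the fused-shard path in is_layer_skipped.
# The fused mapping and layer names below are illustrative only.
from collections.abc import Mapping

def is_layer_skipped_sketch(
    prefix: str,
    ignored_layers: list[str],
    fused_mapping: Mapping[str, list[str]],
    *,
    skip_with_substr: bool = False,
) -> bool:
    def prefix_full_match(p: str) -> bool:
        return p in ignored_layers

    def substr_match(p: str) -> bool:
        return any(layer in p for layer in ignored_layers)

    match_func = substr_match if skip_with_substr else prefix_full_match

    proj_name = prefix.split(".")[-1]
    if proj_name in fused_mapping:
        shards = [prefix.replace(proj_name, s) for s in fused_mapping[proj_name]]
        # The real implementation raises if only some shards are ignored.
        return all(match_func(s) for s in shards)
    return match_func(prefix)

fused = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
ignored = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.k_proj",
    "model.layers.0.self_attn.v_proj",
]
assert is_layer_skipped_sketch("model.layers.0.self_attn.qkv_proj", ignored, fused)
assert not is_layer_skipped_sketch("model.layers.1.self_attn.qkv_proj", ignored, fused)
```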
@@ -42,8 +42,6 @@ from typing_extensions import TypeVar
 from vllm.config import VllmConfig
 from vllm.config.multimodal import BaseDummyOptions
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.resampler import (
     BaseResampler,
     Resampler2,
@@ -1514,11 +1512,6 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
         assert self.version == (4, 0)

-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
-        if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)):
-            return None
-        return quant_config
-
     def init_llm(
         self,
         vllm_config: VllmConfig,
@@ -1532,7 +1525,6 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ) -> nn.Module:
-        quant_config = self._maybe_ignore_quant_config(quant_config)
         model = Idefics2VisionTransformer(
             config.vision_config,
             quant_config=quant_config,
@@ -1550,7 +1542,6 @@ class MiniCPMV4_0(MiniCPMVBaseModel, SupportsLoRA):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ) -> nn.Module:
-        quant_config = self._maybe_ignore_quant_config(quant_config)
         with set_default_torch_dtype(torch.float16):
             # The resampler in 4.0 remains consistent with the one in 2.5/2.6.
             resampler = Resampler2_5(
@@ -1619,11 +1610,6 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
         super().__init__(vllm_config=vllm_config, prefix=prefix)
         assert self.version == (4, 5)

-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
-        if isinstance(quant_config, (AWQConfig, AWQMarlinConfig)):
-            return None
-        return quant_config
-
     def init_llm(
         self,
         vllm_config: VllmConfig,
@@ -1637,7 +1623,6 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ) -> nn.Module:
-        quant_config = self._maybe_ignore_quant_config(quant_config)
         model = Idefics2VisionTransformer(
             config.vision_config,
             quant_config=quant_config,
@@ -1655,7 +1640,6 @@ class MiniCPMV4_5(MiniCPMVBaseModel, SupportsLoRA):
         quant_config: QuantizationConfig | None = None,
         prefix: str = "",
     ) -> nn.Module:
-        quant_config = self._maybe_ignore_quant_config(quant_config)
         with set_default_torch_dtype(torch.float16):
             # The resampler in 4.0 remains consistent with the one in 2.5/2.6.
             resampler = Resampler4_5(