diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py
index f8bc3ab5e7d1e..fe42e26a17061 100644
--- a/vllm/model_executor/layers/quantization/awq.py
+++ b/vllm/model_executor/layers/quantization/awq.py
@@ -1,19 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
 
+logger = init_logger(__name__)
+
 
 class AWQConfig(QuantizationConfig):
     """Config class for AWQ.
@@ -74,12 +78,42 @@ class AWQConfig(QuantizationConfig):
             config, ["modules_to_not_convert"], None)
         return cls(weight_bits, group_size, zero_point,
                    modules_to_not_convert)
 
-    def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["LinearMethodBase"]:
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
         if isinstance(layer, LinearBase):
             if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                 return UnquantizedLinearMethod()
             return AWQLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            # Lazy import to avoid circular import.
+            from .awq_marlin import AWQMarlinConfig, AWQMoEMethod
+            from .moe_wna16 import MoeWNA16Config
+            from .utils.marlin_utils import check_moe_marlin_supports_layer
+            if not check_moe_marlin_supports_layer(layer, self.group_size):
+                logger.warning_once(
+                    f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
+                    "Falling back to Moe WNA16 kernels.")
+                config = {
+                    "quant_method": "awq",
+                    "bits": self.weight_bits,
+                    "group_size": self.group_size,
+                    "zero_point": self.zero_point,
+                    "lm_head": False,
+                }
+                return MoeWNA16Config.from_config(config).get_quant_method(
+                    layer, prefix)
+            marlin_compatible_config_dict = {
+                "quant_method": "awq",
+                "bits": self.weight_bits,
+                "group_size": self.group_size,
+                "zero_point": self.zero_point,
+                "lm_head": False,
+                "modules_to_not_convert": self.modules_to_not_convert,
+            }
+            awq_marlin_config = AWQMarlinConfig.from_config(
+                marlin_compatible_config_dict)
+            return AWQMoEMethod(awq_marlin_config)
         return None
 
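
Note on the FusedMoE branch: both dicts it builds follow the shape that
MoeWNA16Config.from_config and AWQMarlinConfig.from_config already parse,
i.e. the same keys an AWQ checkpoint's quantization_config section in
config.json provides. A minimal sketch of that translation, assuming a vLLM
install on a Marlin-capable (Ampere or newer) GPU; the dict name "cfg" and
the values are illustrative, standing in for the AWQConfig attributes:

    from vllm.model_executor.layers.quantization.awq_marlin import (
        AWQMarlinConfig)

    # Same shape as marlin_compatible_config_dict in the patch, with
    # example values in place of the AWQConfig attributes.
    cfg = {
        "quant_method": "awq",
        "bits": 4,                      # AWQConfig.weight_bits
        "group_size": 128,              # AWQConfig.group_size
        "zero_point": True,             # AWQConfig.zero_point
        "lm_head": False,               # LM head left unquantized
        "modules_to_not_convert": None,
    }
    marlin_config = AWQMarlinConfig.from_config(cfg)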