[Quantization] Improve AWQ logic (#19431)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
Jee Jee Li 2025-06-12 19:02:11 +08:00 committed by GitHub
parent c9280e6346
commit 73e2e0118f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,19 +1,23 @@
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any, Optional from typing import Any, Optional, Union
import torch import torch
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.parameter import (GroupQuantScaleParameter, from vllm.model_executor.parameter import (GroupQuantScaleParameter,
PackedvLLMParameter) PackedvLLMParameter)
logger = init_logger(__name__)
class AWQConfig(QuantizationConfig): class AWQConfig(QuantizationConfig):
"""Config class for AWQ. """Config class for AWQ.
@ -74,12 +78,42 @@ class AWQConfig(QuantizationConfig):
config, ["modules_to_not_convert"], None) config, ["modules_to_not_convert"], None)
return cls(weight_bits, group_size, zero_point, modules_to_not_convert) return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
def get_quant_method(self, layer: torch.nn.Module, def get_quant_method(
prefix: str) -> Optional["LinearMethodBase"]: self, layer: torch.nn.Module, prefix: str
) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
if isinstance(layer, LinearBase): if isinstance(layer, LinearBase):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert): if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod() return UnquantizedLinearMethod()
return AWQLinearMethod(self) return AWQLinearMethod(self)
elif isinstance(layer, FusedMoE):
# Lazy import to avoid circular import.
from .awq_marlin import AWQMarlinConfig, AWQMoEMethod
from .moe_wna16 import MoeWNA16Config
from .utils.marlin_utils import check_moe_marlin_supports_layer
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_once(
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
"Falling back to Moe WNA16 kernels.")
config = {
"quant_method": "awq",
"bits": self.weight_bits,
"group_size": self.group_size,
"zero_point": self.zero_point,
"lm_head": False,
}
return MoeWNA16Config.from_config(config).get_quant_method(
layer, prefix)
marlin_compatible_config_dict = {
"quant_method": "awq",
"bits": self.weight_bits,
"group_size": self.group_size,
"zero_point": self.zero_point,
"lm_head": False,
"modules_to_not_convert": self.modules_to_not_convert,
}
awq_marlin_config = AWQMarlinConfig.from_config(
marlin_compatible_config_dict)
return AWQMoEMethod(awq_marlin_config)
return None return None