[Quantization] Improve AWQ logic (#19431)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2025-06-12 19:02:11 +08:00 committed by GitHub
parent c9280e6346
commit 73e2e0118f


@@ -1,19 +1,23 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Any, Optional
+from typing import Any, Optional, Union
 
 import torch
 
 from vllm import _custom_ops as ops
+from vllm.logger import init_logger
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.parameter import (GroupQuantScaleParameter,
                                            PackedvLLMParameter)
 
+logger = init_logger(__name__)
+
 
 class AWQConfig(QuantizationConfig):
     """Config class for AWQ.
@@ -74,12 +78,42 @@ class AWQConfig(QuantizationConfig):
             config, ["modules_to_not_convert"], None)
         return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
 
-    def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["LinearMethodBase"]:
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[Union["LinearMethodBase", "QuantizeMethodBase"]]:
         if isinstance(layer, LinearBase):
             if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                 return UnquantizedLinearMethod()
             return AWQLinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            # Lazy import to avoid circular import.
+            from .awq_marlin import AWQMarlinConfig, AWQMoEMethod
+            from .moe_wna16 import MoeWNA16Config
+            from .utils.marlin_utils import check_moe_marlin_supports_layer
+            if not check_moe_marlin_supports_layer(layer, self.group_size):
+                logger.warning_once(
+                    f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
+                    "Falling back to Moe WNA16 kernels.")
+                config = {
+                    "quant_method": "awq",
+                    "bits": self.weight_bits,
+                    "group_size": self.group_size,
+                    "zero_point": self.zero_point,
+                    "lm_head": False,
+                }
+                return MoeWNA16Config.from_config(config).get_quant_method(
+                    layer, prefix)
+            marlin_compatible_config_dict = {
+                "quant_method": "awq",
+                "bits": self.weight_bits,
+                "group_size": self.group_size,
+                "zero_point": self.zero_point,
+                "lm_head": False,
+                "modules_to_not_convert": self.modules_to_not_convert,
+            }
+            awq_marlin_config = AWQMarlinConfig.from_config(
+                marlin_compatible_config_dict)
+            return AWQMoEMethod(awq_marlin_config)
         return None
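
For reference, the dispatch this hunk introduces can be exercised directly from a quantization config dict. A minimal sketch, assuming a working vLLM install and that the changed file is vllm/model_executor/layers/quantization/awq.py (layer construction is elided, since LinearBase and FusedMoE modules need full model and parallel state to instantiate; the modules_to_not_convert value is only an example):

from vllm.model_executor.layers.quantization.awq import AWQConfig

# Build the config the same way vLLM does when it parses the checkpoint's
# quantization_config; these keys match AWQConfig.from_config().
cfg = AWQConfig.from_config({
    "bits": 4,
    "group_size": 128,
    "zero_point": True,
    "modules_to_not_convert": ["visual"],  # optional; example value
})

# With this change, cfg.get_quant_method(layer, prefix) returns:
#   - UnquantizedLinearMethod for LinearBase layers whose prefix is skipped,
#   - AWQLinearMethod for other LinearBase layers,
#   - AWQMoEMethod (built from an AWQMarlinConfig) for FusedMoE layers that
#     pass check_moe_marlin_supports_layer(),
#   - a MoeWNA16Config-provided method for FusedMoE layers that do not
#     (with a one-time warning),
#   - None for any other layer type.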