diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 1ae765a2260f..56fa597e2013 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -7,7 +7,8 @@ import torch
 from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+from vllm.model_executor.layers.linear import (LinearBase,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.quantization.base_config import (
@@ -125,9 +126,7 @@ class MoeWNA16Config(QuantizationConfig):
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
             return UnquantizedLinearMethod()
-        elif isinstance(layer, FusedMoE):
-            return MoeWNA16Method(self)
-        else:
+        elif isinstance(layer, LinearBase):
             if self.linear_quant_method == "gptq":
                 if self.use_marlin:
                     return GPTQMarlinConfig.from_config(
@@ -144,6 +143,9 @@ class MoeWNA16Config(QuantizationConfig):
                         self.full_config).get_quant_method(layer, prefix)
             else:
                 raise ValueError("moe_wna16 only support gptq and awq.")
+        elif isinstance(layer, FusedMoE):
+            return MoeWNA16Method(self)
+        return None
 
 
 def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
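
For reference, here is a minimal standalone sketch (not vLLM code) of the dispatch order this patch establishes in `MoeWNA16Config.get_quant_method`: the linear-quantization path now fires only for `LinearBase` layers, `FusedMoE` layers get `MoeWNA16Method`, and any other layer type falls through to `None` instead of the old `else:` catch-all. The stand-in classes and return strings below are hypothetical placeholders used purely to illustrate the control flow.

```python
# Standalone sketch of the patched dispatch order; class names mirror the
# isinstance checks in the diff, everything else is a placeholder.
from typing import Optional


class LinearBase:  # stand-in for vllm.model_executor.layers.linear.LinearBase
    pass


class FusedMoE:  # stand-in for vllm...fused_moe.layer.FusedMoE
    pass


class SomeOtherLayer:  # e.g. an attention or embedding module
    pass


def get_quant_method_sketch(layer: object) -> Optional[str]:
    """Mirrors the new control flow: LinearBase first, then FusedMoE,
    otherwise None (the old code used an `else:` catch-all instead)."""
    if isinstance(layer, LinearBase):
        return "delegate to GPTQ/AWQ (marlin or plain) linear quant method"
    elif isinstance(layer, FusedMoE):
        return "MoeWNA16Method"
    return None


print(get_quant_method_sketch(LinearBase()))      # linear path
print(get_quant_method_sketch(FusedMoE()))        # MoE path
print(get_quant_method_sketch(SomeOtherLayer()))  # None: no quant method applied
```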