diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 0131a330f70d2..4bedb951a33f5 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -17,6 +17,9 @@ from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoEMethodBase,
     FusedMoeWeightScaleSupported,
 )
+from vllm.model_executor.layers.fused_moe.unquantized_fused_moe_method import (
+    UnquantizedFusedMoEMethod,
+)
 from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
@@ -162,6 +165,8 @@ class MoeWNA16Config(QuantizationConfig):
         self, layer: torch.nn.Module, prefix: str
     ) -> Optional["QuantizeMethodBase"]:
         if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
+            if isinstance(layer, FusedMoE):
+                return UnquantizedFusedMoEMethod(layer.moe_config)
             return UnquantizedLinearMethod()
         elif isinstance(layer, LinearBase):
             # Avoid circular import
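
For reference, a minimal standalone sketch of the dispatch rule this hunk introduces: a layer that is skipped from quantization now receives a fused-MoE unquantized method when it is a FusedMoE module, and the plain unquantized linear method otherwise. All `Toy*` names below are hypothetical stand-ins for vLLM's `FusedMoE`, `UnquantizedLinearMethod`, and `UnquantizedFusedMoEMethod`, not the real implementations.

```python
# Toy reproduction of the patched branch of MoeWNA16Config.get_quant_method.
# Every Toy* class is a hypothetical stand-in, not vLLM's real class.
from typing import Optional


class ToyQuantizeMethod:  # stand-in for QuantizeMethodBase
    pass


class ToyUnquantizedLinearMethod(ToyQuantizeMethod):  # stand-in for UnquantizedLinearMethod
    pass


class ToyUnquantizedFusedMoEMethod(ToyQuantizeMethod):  # stand-in for UnquantizedFusedMoEMethod
    def __init__(self, moe_config: dict) -> None:
        self.moe_config = moe_config


class ToyLinear:  # stand-in for a LinearBase layer
    pass


class ToyFusedMoE:  # stand-in for a FusedMoE layer
    def __init__(self) -> None:
        self.moe_config = {"num_experts": 8}  # stand-in for FusedMoE.moe_config


def get_quant_method_for_skipped_layer(layer: object) -> Optional[ToyQuantizeMethod]:
    # Mirrors the new branch: a skipped FusedMoE gets a MoE-aware unquantized
    # method carrying its moe_config; any other skipped layer keeps the plain
    # unquantized linear method, exactly as the pre-patch code did.
    if isinstance(layer, ToyFusedMoE):
        return ToyUnquantizedFusedMoEMethod(layer.moe_config)
    return ToyUnquantizedLinearMethod()


assert isinstance(get_quant_method_for_skipped_layer(ToyFusedMoE()),
                  ToyUnquantizedFusedMoEMethod)
assert isinstance(get_quant_method_for_skipped_layer(ToyLinear()),
                  ToyUnquantizedLinearMethod)
```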