diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py index 8cd9c0a7ef253..11a9d4ac5c1ae 100644 --- a/vllm/model_executor/layers/quantization/moe_wna16.py +++ b/vllm/model_executor/layers/quantization/moe_wna16.py @@ -6,16 +6,13 @@ from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import UnquantizedLinearMethod -from vllm.model_executor.layers.quantization.awq import (AWQConfig, - AWQLinearMethod) -from vllm.model_executor.layers.quantization.awq_marlin import ( - AWQMarlinConfig, AWQMarlinLinearMethod) +from vllm.model_executor.layers.quantization.awq import AWQConfig +from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase) -from vllm.model_executor.layers.quantization.gptq import (GPTQConfig, - GPTQLinearMethod) +from vllm.model_executor.layers.quantization.gptq import GPTQConfig from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig, GPTQMarlinLinearMethod) + GPTQMarlinConfig) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform @@ -131,18 +128,18 @@ class MoeWNA16Config(QuantizationConfig): else: if self.linear_quant_method == "gptq": if self.use_marlin: - return GPTQMarlinLinearMethod( - GPTQMarlinConfig.from_config(self.full_config)) + return GPTQMarlinConfig.from_config( + self.full_config).get_quant_method(layer, prefix) else: - return GPTQLinearMethod( - GPTQConfig.from_config(self.full_config)) + return GPTQConfig.from_config( + self.full_config).get_quant_method(layer, prefix) elif self.linear_quant_method == "awq": if self.use_marlin: - return AWQMarlinLinearMethod( - AWQMarlinConfig.from_config(self.full_config)) + return AWQMarlinConfig.from_config( + self.full_config).get_quant_method(layer, prefix) else: - return AWQLinearMethod( - AWQConfig.from_config(self.full_config)) + return AWQConfig.from_config( + self.full_config).get_quant_method(layer, prefix) else: raise ValueError("moe_wna16 only support gptq and awq.")