From b9986454fe8ba80e2a109d069397b6b59aae658b Mon Sep 17 00:00:00 2001
From: Srikanth Srinivas
Date: Sun, 2 Feb 2025 21:46:19 -0800
Subject: [PATCH] Fix for attention layers to remain unquantized during moe_wna16 quant (#12570)

Fix AWQ quant loading of the new R1 model.

The new optimized MoE kernels for models with a large number of experts,
`moe_wna16`, use AWQ quant, which requires the attention layers to stay in
16-bit precision.

The current merge has broken this: `get_quant_method` must again return None
for layers that are neither linear nor fused-MoE layers (such as attention)
so that they remain unquantized.
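Below is a condensed sketch of the dispatch this patch restores. It is not
the verbatim code from the diff: the GPTQ and Marlin branches are elided,
and the free-standing `get_quant_method_sketch` wrapper (taking the config
explicitly) exists only for illustration.

```python
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.awq import AWQConfig
from vllm.model_executor.layers.quantization.moe_wna16 import (
    MoeWNA16Method, is_layer_skipped_quant)


def get_quant_method_sketch(config, layer, prefix: str):
    """Condensed view of MoeWNA16Config.get_quant_method after this patch."""
    if is_layer_skipped_quant(prefix, config.modules_to_not_convert):
        # Modules explicitly excluded from quantization.
        return UnquantizedLinearMethod()
    elif isinstance(layer, LinearBase):
        # Dense linear layers are delegated to the wrapped AWQ config
        # (the GPTQ and Marlin branches from the diff are omitted here).
        return AWQConfig.from_config(
            config.full_config).get_quant_method(layer, prefix)
    elif isinstance(layer, FusedMoE):
        # Expert weights use the quantized fused-MoE method.
        return MoeWNA16Method(config)
    # Anything else (notably the attention layers) gets no quant method,
    # so it is built in 16-bit. This trailing `return None` is the fix.
    return None
```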
---------

Signed-off-by: Srikanth Srinivas
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Signed-off-by: Beim
Signed-off-by: rshaw@neuralmagic.com
Signed-off-by: mgoin
Signed-off-by: npanpaliya
Signed-off-by: Aleksandr Malyshev
Signed-off-by: Lucas Wilkinson
Signed-off-by: simon-mo
Signed-off-by: Cody Yu
Signed-off-by: Chen Zhang
Signed-off-by: Tyler Michael Smith
Signed-off-by: Ryan N
Signed-off-by: Brian Dellabetta
Signed-off-by: Jee Jee Li
Signed-off-by: Rahul Tuli
Signed-off-by: Russell Bryant
Signed-off-by: simon-mo
Signed-off-by: Vicente Herrera
Signed-off-by: Jinzhen Lin
Signed-off-by: Woosuk Kwon
Signed-off-by: Shawn Du
Signed-off-by: Kunshang Ji
Signed-off-by: youkaichao
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Beim <805908499@qq.com>
Co-authored-by: Robert Shaw <114415538+robertgshaw2-redhat@users.noreply.github.com>
Co-authored-by: mgoin
Co-authored-by: simon-mo
Co-authored-by: Nishidha
Co-authored-by: Lucas Wilkinson
Co-authored-by: Aleksandr Malyshev <164964928+maleksan85@users.noreply.github.com>
Co-authored-by: Aleksandr Malyshev
Co-authored-by: Woosuk Kwon
Co-authored-by: simon-mo
Co-authored-by: Michael Goin
Co-authored-by: Zhuohan Li
Co-authored-by: Tyler Michael Smith
Co-authored-by: Alexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Co-authored-by: Cody Yu
Co-authored-by: Chen Zhang
Co-authored-by: Kevin H. Luu
Co-authored-by: Tyler Michael Smith
Co-authored-by: Ryan Nguyen <96593302+xpbowler@users.noreply.github.com>
Co-authored-by: Brian Dellabetta
Co-authored-by: fade_away <1028552010@qq.com>
Co-authored-by: weilong.yu
Co-authored-by: Jee Jee Li
Co-authored-by: Eldar Kurtic
Co-authored-by: Rahul Tuli
Co-authored-by: Russell Bryant
Co-authored-by: Vicente Herrera
Co-authored-by: Jinzhen Lin
Co-authored-by: Shawn Du
Co-authored-by: Kunshang Ji
Co-authored-by: youkaichao
---
 vllm/model_executor/layers/quantization/moe_wna16.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/moe_wna16.py b/vllm/model_executor/layers/quantization/moe_wna16.py
index 1ae765a2260f3..56fa597e20131 100644
--- a/vllm/model_executor/layers/quantization/moe_wna16.py
+++ b/vllm/model_executor/layers/quantization/moe_wna16.py
@@ -7,7 +7,8 @@ import torch
 from vllm.distributed import get_tensor_model_parallel_rank, get_tp_group
 from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported)
-from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+from vllm.model_executor.layers.linear import (LinearBase,
+                                               UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.quantization.base_config import (
@@ -125,9 +126,7 @@ class MoeWNA16Config(QuantizationConfig):
                          prefix: str) -> Optional["QuantizeMethodBase"]:
         if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
             return UnquantizedLinearMethod()
-        elif isinstance(layer, FusedMoE):
-            return MoeWNA16Method(self)
-        else:
+        elif isinstance(layer, LinearBase):
             if self.linear_quant_method == "gptq":
                 if self.use_marlin:
                     return GPTQMarlinConfig.from_config(
@@ -144,6 +143,9 @@ class MoeWNA16Config(QuantizationConfig):
                         self.full_config).get_quant_method(layer, prefix)
             else:
                 raise ValueError("moe_wna16 only support gptq and awq.")
+        elif isinstance(layer, FusedMoE):
+            return MoeWNA16Method(self)
+        return None
 
 
 def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):