diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py
index b69575c7e96d..56d1dfe135b3 100644
--- a/vllm/model_executor/layers/fused_moe/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/fused_moe.py
@@ -761,8 +761,8 @@ def get_moe_wna16_block_config(config: dict[str,
 
 def should_moe_wna16_use_cuda(num_valid_tokens: int, group_size: int,
                               num_experts: int, bit: int):
-    return bit == 4 and group_size in [32, 64, 128] and \
-        num_valid_tokens / num_experts <= 6
+    return current_platform.is_cuda() and bit == 4 and \
+        group_size in [32, 64, 128] and num_valid_tokens / num_experts <= 6
 
 
 def get_default_config(
diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index d3ab1be3bee0..f18c936bac60 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -10,10 +10,11 @@ import torch
 from torch.nn.parameter import Parameter
 
 from vllm import _custom_ops as ops
+from vllm.model_executor.layers.fused_moe.layer import FusedMoE
 from vllm.model_executor.layers.linear import LinearMethodBase
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+    QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.utils.gptq_utils import (
     get_linear_quant_method)
 from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
@@ -110,8 +111,23 @@ class GPTQConfig(QuantizationConfig):
         return cls(weight_bits, group_size, desc_act, lm_head_quantized,
                    dynamic)
 
-    def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["GPTQLinearMethod"]:
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[Union["GPTQLinearMethod", "QuantizeMethodBase"]]:
+        if isinstance(layer, FusedMoE):
+            # GPTQ MoE support: fall back to MoeWNA16 for broad compatibility
+            from .moe_wna16 import MoeWNA16Config
+
+            config = {
+                "quant_method": "gptq",
+                "bits": self.weight_bits,
+                "group_size": self.group_size,
+                "sym": True,  # GPTQ typically uses symmetric quantization
+                "lm_head": False,
+            }
+            return MoeWNA16Config.from_config(config).get_quant_method(
+                layer, prefix)
+
         return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
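
For reviewers, here is a minimal, self-contained sketch of the dispatch pattern the gptq.py hunk introduces: GPTQConfig.get_quant_method now checks whether the layer is a FusedMoE, re-packages its GPTQ settings as a MoeWNA16-style config, and delegates to that config's get_quant_method, while every other layer keeps the existing dense path. The classes below (FusedMoE, MoeWNA16Config, MoeWNA16Method, GPTQLinearMethod) are simplified stand-ins for illustration only, not vLLM's real implementations; only the shape of the dispatch mirrors the diff.

```python
# Sketch of the MoE fallback pattern from the gptq.py hunk above.
# All classes here are toy stand-ins, not vLLM's actual API.
from typing import Optional, Union


class FusedMoE:  # stand-in for vllm...fused_moe.layer.FusedMoE
    pass


class GPTQLinearMethod:  # stand-in for the dense GPTQ method
    pass


class MoeWNA16Method:  # stand-in for the method MoeWNA16Config hands back
    def __init__(self, config: dict):
        self.config = config


class MoeWNA16Config:  # stand-in for vLLM's MoeWNA16Config
    def __init__(self, config: dict):
        self.config = config

    @classmethod
    def from_config(cls, config: dict) -> "MoeWNA16Config":
        return cls(config)

    def get_quant_method(self, layer, prefix: str) -> MoeWNA16Method:
        return MoeWNA16Method(self.config)


class GPTQConfig:
    def __init__(self, weight_bits: int = 4, group_size: int = 128):
        self.weight_bits = weight_bits
        self.group_size = group_size

    def get_quant_method(
        self, layer, prefix: str
    ) -> Optional[Union[GPTQLinearMethod, MoeWNA16Method]]:
        # MoE layers: re-express the GPTQ settings as a MoeWNA16-style config
        # and let that config choose the quant method, as in the diff.
        if isinstance(layer, FusedMoE):
            config = {
                "quant_method": "gptq",
                "bits": self.weight_bits,
                "group_size": self.group_size,
                "sym": True,
                "lm_head": False,
            }
            return MoeWNA16Config.from_config(config).get_quant_method(
                layer, prefix)
        # Everything else keeps the existing dense GPTQ path.
        return GPTQLinearMethod()


if __name__ == "__main__":
    cfg = GPTQConfig(weight_bits=4, group_size=128)
    print(type(cfg.get_quant_method(FusedMoE(), "model.layers.0.mlp")))
    print(type(cfg.get_quant_method(object(), "model.layers.0.qkv_proj")))
```

The fused_moe.py hunk is a separate guard: should_moe_wna16_use_cuda now also requires current_platform.is_cuda(), so the WNA16 CUDA path is never selected on non-CUDA platforms; the MoE fallback sketched above does not depend on that check.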