From 183a70967a90ef06e614ad020ead7d27e87b7688 Mon Sep 17 00:00:00 2001
From: JartX
Date: Mon, 1 Sep 2025 05:33:40 +0200
Subject: [PATCH] [BUGFIX] GPTQ quantization compatibility for Qwen3 MOE
 models (AutoGPTQ and AutoRound-GPTQ) (#23994)

Signed-off-by: JartX
Signed-off-by: Isotr0py
Co-authored-by: Isotr0py
---
 vllm/model_executor/layers/quantization/gptq.py        |  8 +++++++-
 vllm/model_executor/layers/quantization/gptq_marlin.py |  3 +++
 vllm/model_executor/models/qwen3_moe.py                | 10 +++++++---
 3 files changed, 17 insertions(+), 4 deletions(-)

diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py
index f18c936bac605..2272709f93091 100644
--- a/vllm/model_executor/layers/quantization/gptq.py
+++ b/vllm/model_executor/layers/quantization/gptq.py
@@ -37,6 +37,7 @@ class GPTQConfig(QuantizationConfig):
         desc_act: bool,
         lm_head_quantized: bool,
         dynamic: dict[str, dict[str, Union[int, bool]]],
+        autoround_version: str = "",
     ) -> None:
         # GPTQModel use `dynamic` config property to allow per module
         # quantization config so each module can be individually optimized.
@@ -74,6 +75,9 @@ class GPTQConfig(QuantizationConfig):
                 "Currently, only 2/3/4/8-bit weight quantization is "
                 f"supported for GPTQ, but got {self.weight_bits} bits.")
 
+        # used to identify GPTQ model quantized by autoround
+        self.autoround_version = autoround_version
+
     def __repr__(self) -> str:
         return (f"GPTQConfig(weight_bits={self.weight_bits}, "
                 f"group_size={self.group_size}, "
@@ -108,8 +112,10 @@ class GPTQConfig(QuantizationConfig):
         desc_act = cls.get_from_keys(config, ["desc_act"])
         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
                                                  default=False)
+        autoround_version = cls.get_from_keys_or(config, ["autoround_version"],
+                                                 default="")
         return cls(weight_bits, group_size, desc_act, lm_head_quantized,
-                   dynamic)
+                   dynamic, autoround_version)
 
     def get_quant_method(
             self, layer: torch.nn.Module, prefix: str
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 350975966668e..3644d91f64e3c 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -119,6 +119,9 @@ class GPTQMarlinConfig(QuantizationConfig):
 
         self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]
 
+        # used to identify GPTQ model quantized by autoround
+        self.autoround_version = full_config.get("autoround_version", "")
+
     def __repr__(self) -> str:
         return (f"GPTQMarlinConfig(quant_type={self.quant_type}, "
                 f"group_size={self.group_size}, "
diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 94e6a66bea5cb..a7e0a00350e6b 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -159,9 +159,13 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
 
     def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
         # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
-        # seems to avoid gate quantization.
-        # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+        # seems to avoid gate quantization while AutoRound does.
+        # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4,
+        # and https://huggingface.co/jart25/Qwen3-Coder-30B-A3B-Instruct-Int4-gptq
+        if isinstance(
+                quant_config,
+                (GPTQConfig,
+                 GPTQMarlinConfig)) and not quant_config.autoround_version:
             return None
         return quant_config
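
Note (not part of the patch): a minimal sketch of how the new autoround_version field is intended to flow, assuming a checkpoint whose config.json carries an AutoRound-style GPTQ quantization_config. The dict contents and the "0.4.6" version string below are illustrative assumptions, not taken from a real model:

    # Sketch: AutoRound-exported GPTQ checkpoints carry an "autoround_version"
    # key in their quantization_config; AutoGPTQ checkpoints do not, so the
    # field defaults to "".
    from vllm.model_executor.layers.quantization.gptq import GPTQConfig

    quantization_config = {
        "bits": 4,                     # weight bits (2/3/4/8 accepted by GPTQConfig)
        "group_size": 128,
        "desc_act": False,
        "lm_head": False,
        "autoround_version": "0.4.6",  # hypothetical value; absent for AutoGPTQ exports
    }

    quant_config = GPTQConfig.from_config(quantization_config)

    # Mirrors the patched Qwen3MoeSparseMoeBlock._maybe_ignore_quant_config:
    # AutoGPTQ checkpoints (empty autoround_version) leave the MoE gate
    # unquantized, so its quant config is dropped; AutoRound checkpoints
    # quantize the gate and keep it.
    gate_quant_config = (None if not quant_config.autoround_version
                         else quant_config)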