From a9082a4d144e516d0ee00bacfa0dca9609b6b2c3 Mon Sep 17 00:00:00 2001
From: Isotr0py
Date: Mon, 25 Aug 2025 21:40:20 +0800
Subject: [PATCH] [Bugfix] Fix Qwen3 MoE GPTQ inference (#23490)

Signed-off-by: Isotr0py
---
 vllm/model_executor/models/qwen3_moe.py | 24 ++++++++++++++++++------
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/models/qwen3_moe.py b/vllm/model_executor/models/qwen3_moe.py
index 2812f79a66b70..8498f61b35fdd 100644
--- a/vllm/model_executor/models/qwen3_moe.py
+++ b/vllm/model_executor/models/qwen3_moe.py
@@ -45,6 +45,9 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQMarlinConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
@@ -146,11 +149,20 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts)
 
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.num_experts,
-                                     bias=False,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.gate")
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=self._maybe_ignore_quant_config(quant_config),
+            prefix=f"{prefix}.gate")
+
+    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        # GPTQ configs do not have a list of ignored modules; however,
+        # AutoGPTQ seems to avoid gate quantization.
+        # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
+        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+            return None
+        return quant_config
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
@@ -682,4 +694,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
         return loader.load_weights(weights)
 
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
-        return self.model.get_expert_mapping()
\ No newline at end of file
+        return self.model.get_expert_mapping()
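
For readers without a vLLM checkout, the snippet below is a minimal standalone
sketch of the dispatch this patch adds. The classes here are local stand-ins,
not the real vLLM imports, so it runs without vLLM installed; only the
isinstance-based check mirrors the patch.

# Standalone sketch of _maybe_ignore_quant_config (stand-in classes, not
# the real vLLM configs).

class QuantizationConfig:
    """Stand-in for vllm's base QuantizationConfig."""

class GPTQConfig(QuantizationConfig):
    """Stand-in for vllm...quantization.gptq.GPTQConfig."""

class GPTQMarlinConfig(QuantizationConfig):
    """Stand-in for vllm...quantization.gptq_marlin.GPTQMarlinConfig."""

class AWQConfig(QuantizationConfig):
    """Stand-in for any non-GPTQ scheme, e.g. AWQ."""

def maybe_ignore_quant_config(quant_config):
    # Per the patch comment, AutoGPTQ checkpoints leave the MoE router
    # ("gate") unquantized, but GPTQ configs carry no ignore-list to say
    # so; dropping the config makes the gate a plain linear layer.
    if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
        return None
    return quant_config

# The gate is built unquantized for GPTQ-family configs...
assert maybe_ignore_quant_config(GPTQConfig()) is None
assert maybe_ignore_quant_config(GPTQMarlinConfig()) is None
# ...while other quantization schemes pass through untouched.
assert isinstance(maybe_ignore_quant_config(AWQConfig()), AWQConfig)

The practical effect, as the patch comment suggests: checkpoints such as
Qwen/Qwen3-30B-A3B-GPTQ-Int4 ship an unquantized gate, so constructing the
gate ReplicatedLinear with the GPTQ config would look for quantized gate
weights that do not exist; returning None keeps the gate on the unquantized
path while the experts remain quantized.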