[Bugfix] Fix Qwen3 MoE GPTQ inference (#23490)

Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
This commit is contained in:
Isotr0py 2025-08-25 21:40:20 +08:00 committed by GitHub
parent e0329ed4b4
commit a9082a4d14
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -45,6 +45,9 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
@ -146,11 +149,20 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
enable_eplb=self.enable_eplb,
num_redundant_experts=self.n_redundant_experts)
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.gate")
self.gate = ReplicatedLinear(
config.hidden_size,
config.num_experts,
bias=False,
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=f"{prefix}.gate")
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid gate quantization.
# See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# NOTE: hidden_states can have either 1D or 2D shape.
@ -682,4 +694,4 @@ class Qwen3MoeForCausalLM(nn.Module, SupportsPP, SupportsLoRA,
return loader.load_weights(weights)
def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
return self.model.get_expert_mapping()
return self.model.get_expert_mapping()