Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2025-12-14 18:38:38 +08:00)
[Bugfix] Fix Qwen3 MoE GPTQ inference (#23490)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
parent e0329ed4b4
commit a9082a4d14
@@ -45,6 +45,9 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQMarlinConfig)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
@@ -146,12 +149,21 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
             enable_eplb=self.enable_eplb,
             num_redundant_experts=self.n_redundant_experts)
 
-        self.gate = ReplicatedLinear(config.hidden_size,
-                                     config.num_experts,
-                                     bias=False,
-                                     quant_config=quant_config,
-                                     prefix=f"{prefix}.gate")
+        self.gate = ReplicatedLinear(
+            config.hidden_size,
+            config.num_experts,
+            bias=False,
+            quant_config=self._maybe_ignore_quant_config(quant_config),
+            prefix=f"{prefix}.gate")
+
+    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
+        # seems to avoid gate quantization.
+        # See: https://huggingface.co/Qwen/Qwen3-30B-A3B-GPTQ-Int4
+        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+            return None
+        return quant_config
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # NOTE: hidden_states can have either 1D or 2D shape.
         orig_shape = hidden_states.shape
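In short, the fix keeps the MoE router (the gate layer) unquantized when the model is loaded with a GPTQ or GPTQ-Marlin config, since GPTQ checkpoints such as Qwen/Qwen3-30B-A3B-GPTQ-Int4 leave that layer in full precision. The snippet below is a minimal, standalone sketch of that filter pattern only; the Fake*Config classes are hypothetical stand-ins for vLLM's GPTQConfig/GPTQMarlinConfig and are not part of the actual change.

# Minimal standalone sketch of the filter added above.
# The Fake*Config classes are hypothetical stand-ins, not vLLM classes.
from typing import Optional


class FakeGPTQConfig:            # stands in for GPTQConfig
    pass


class FakeGPTQMarlinConfig:      # stands in for GPTQMarlinConfig
    pass


class FakeOtherQuantConfig:      # any non-GPTQ quantization config
    pass


def maybe_ignore_quant_config(quant_config) -> Optional[object]:
    # GPTQ checkpoints (e.g. Qwen/Qwen3-30B-A3B-GPTQ-Int4) keep the MoE
    # gate unquantized, so the gate layer must not receive a GPTQ config.
    if isinstance(quant_config, (FakeGPTQConfig, FakeGPTQMarlinConfig)):
        return None
    return quant_config


if __name__ == "__main__":
    assert maybe_ignore_quant_config(FakeGPTQConfig()) is None
    assert maybe_ignore_quant_config(FakeGPTQMarlinConfig()) is None
    other = FakeOtherQuantConfig()
    assert maybe_ignore_quant_config(other) is other
    print("gate receives quant_config=None only for GPTQ-family configs")

With quant_config=None, the gate layer falls back to vLLM's unquantized linear path, which matches how AutoGPTQ exports leave the router weights unquantized.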