From 2161efe9781cc0bee2f60342dffd0c1f7f0f2b57 Mon Sep 17 00:00:00 2001
From: Benjamin Chislett
Date: Mon, 6 Oct 2025 16:16:30 -0400
Subject: [PATCH] [Bugfix] Allow skipping MoE in NVFP4 (fix for MTP) (#25987)

Signed-off-by: Benjamin Chislett
---
 vllm/model_executor/layers/fused_moe/layer.py       | 2 ++
 vllm/model_executor/layers/quantization/modelopt.py | 5 ++++-
 vllm/model_executor/models/deepseek_eagle.py        | 1 +
 vllm/model_executor/models/deepseek_mtp.py          | 9 +++++++--
 vllm/model_executor/models/deepseek_v2.py           | 6 ++++--
 5 files changed, 18 insertions(+), 5 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 767f9cd46a934..9c8ccc6ec0085 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1194,6 +1194,8 @@ class FusedMoE(CustomOp):
             if quant_config is None
             else quant_config.get_quant_method(self, prefix)
         )
+        if quant_method is None:
+            quant_method = UnquantizedFusedMoEMethod(moe)
 
         assert quant_method is not None
         assert isinstance(quant_method, FusedMoEMethodBase)
diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py
index 8c074ebdc8db5..c285b10720d86 100644
--- a/vllm/model_executor/layers/quantization/modelopt.py
+++ b/vllm/model_executor/layers/quantization/modelopt.py
@@ -884,8 +884,9 @@ class ModelOptNvFp4Config(QuantizationConfig):
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
 
+        skip_layer = self.is_layer_excluded(prefix)
         if isinstance(layer, LinearBase):
-            if self.is_layer_excluded(prefix):
+            if skip_layer:
                 return UnquantizedLinearMethod()
             # Check if this is a vision model layer that should not be quantized
             if "vision_tower" in prefix or "vision_model" in prefix:
@@ -894,6 +895,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
         elif isinstance(layer, Attention):
             return ModelOptFp8KVCacheMethod(self)
         elif isinstance(layer, FusedMoE):
+            if skip_layer:
+                return None
             return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
 
         return None
diff --git a/vllm/model_executor/models/deepseek_eagle.py b/vllm/model_executor/models/deepseek_eagle.py
index 467468dcc01eb..faa7edd4bc3c3 100644
--- a/vllm/model_executor/models/deepseek_eagle.py
+++ b/vllm/model_executor/models/deepseek_eagle.py
@@ -55,6 +55,7 @@ class DeepseekV2Model(nn.Module):
                 DeepseekV2DecoderLayer(
                     vllm_config,
                     prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
+                    config=self.config,
                 )
                 for i in range(self.config.num_hidden_layers)
             ]
diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py
index 36c1e0cbe69ba..041dd6db73251 100644
--- a/vllm/model_executor/models/deepseek_mtp.py
+++ b/vllm/model_executor/models/deepseek_mtp.py
@@ -48,7 +48,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
 
     def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
         super().__init__()
-        config = vllm_config.model_config.hf_config
+        config = vllm_config.speculative_config.draft_model_config.hf_config
+        self.config = config
         quant_config = vllm_config.quant_config
 
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -66,11 +67,15 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
             )
         else:
             topk_indices_buffer = None
+
         self.shared_head = SharedHead(
             config=config, prefix=prefix, quant_config=quant_config
         )
         self.mtp_block = DeepseekV2DecoderLayer(
-            vllm_config, prefix, topk_indices_buffer
+            vllm_config,
+            prefix,
+            config=self.config,
+            topk_indices_buffer=topk_indices_buffer,
         )
 
     def forward(
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index f149b02e5522d..5b05d0e3a532d 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -1055,11 +1055,13 @@ class DeepseekV2DecoderLayer(nn.Module):
         self,
         vllm_config: VllmConfig,
         prefix: str,
+        config: Optional[DeepseekV2Config] = None,
         topk_indices_buffer: Optional[torch.Tensor] = None,
     ) -> None:
         super().__init__()
 
-        config = vllm_config.model_config.hf_config
+        if config is None:
+            config = vllm_config.model_config.hf_config
         model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
@@ -1200,7 +1202,7 @@ class DeepseekV2Model(nn.Module):
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: DeepseekV2DecoderLayer(
                 vllm_config, prefix, topk_indices_buffer
+                vllm_config, prefix, topk_indices_buffer=topk_indices_buffer
             ),
             prefix=f"{prefix}.layers",
         )
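
Note: taken together, the layer.py and modelopt.py hunks mean an NVFP4 config can now
return no quantization method for an excluded FusedMoE layer (such as the MTP draft
block's experts), and FusedMoE falls back to the unquantized path instead of tripping
its assert. Below is a minimal, self-contained sketch of that resolution order; the
NvFp4ConfigStub class and resolve_moe_method helper are hypothetical stand-ins written
only to illustrate the control flow, not real vLLM APIs.

# Sketch (not part of the patch) of the post-patch method resolution.

class UnquantizedFusedMoEMethod:
    """Stand-in for the vLLM class of the same name."""

    def __init__(self, moe):
        self.moe = moe


class NvFp4ConfigStub:
    """Hypothetical stand-in for ModelOptNvFp4Config.get_quant_method on MoE layers."""

    def __init__(self, exclude_prefixes):
        self.exclude_prefixes = exclude_prefixes

    def is_layer_excluded(self, prefix):
        return any(p in prefix for p in self.exclude_prefixes)

    def get_quant_method(self, layer, prefix):
        if self.is_layer_excluded(prefix):
            return None  # new behavior: excluded MoE layers get no quantized method
        return "ModelOptNvFp4FusedMoE"  # placeholder for the real quantized method


def resolve_moe_method(quant_config, layer, prefix, moe=None):
    # Mirrors FusedMoE.__init__: the config may now hand back None ...
    quant_method = (
        UnquantizedFusedMoEMethod(moe)
        if quant_config is None
        else quant_config.get_quant_method(layer, prefix)
    )
    # ... in which case we fall back instead of failing the assert.
    if quant_method is None:
        quant_method = UnquantizedFusedMoEMethod(moe)
    return quant_method


cfg = NvFp4ConfigStub(exclude_prefixes=["mtp_block"])
method = resolve_moe_method(cfg, layer=None, prefix="model.mtp_block.experts")
assert isinstance(method, UnquantizedFusedMoEMethod)  # excluded -> unquantized fallback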