[Bugfix] Allow skipping MoE in NVFP4 (fix for MTP) (#25987)

Signed-off-by: Benjamin Chislett <bchislett@nvidia.com>
Author: Benjamin Chislett, 2025-10-06 16:16:30 -04:00 (committed by GitHub)
parent f23b4c04fd
commit 2161efe978
5 changed files with 18 additions and 5 deletions


@@ -1194,6 +1194,8 @@ class FusedMoE(CustomOp):
             if quant_config is None
             else quant_config.get_quant_method(self, prefix)
         )
+        if quant_method is None:
+            quant_method = UnquantizedFusedMoEMethod(moe)
         assert quant_method is not None
         assert isinstance(quant_method, FusedMoEMethodBase)
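
In effect, FusedMoE now tolerates a quantization config that declines to handle a particular MoE layer. Below is a minimal, self-contained sketch of that fall-back; the classes and the helper name are illustrative stand-ins, not the actual vLLM definitions.

class FusedMoEMethodBase:  # stand-in for vLLM's FusedMoEMethodBase
    pass


class UnquantizedFusedMoEMethod(FusedMoEMethodBase):  # stand-in
    def __init__(self, moe):
        self.moe = moe


def select_moe_quant_method(layer, moe, quant_config, prefix):
    # Hypothetical helper mirroring how FusedMoE.__init__ resolves its method.
    quant_method = (
        UnquantizedFusedMoEMethod(moe)
        if quant_config is None
        else quant_config.get_quant_method(layer, prefix)
    )
    # New behaviour: a config may return None for layers it excludes
    # (e.g. MoE layers skipped by NVFP4 for MTP); fall back to the
    # unquantized method so the type assertion still holds.
    if quant_method is None:
        quant_method = UnquantizedFusedMoEMethod(moe)
    assert isinstance(quant_method, FusedMoEMethodBase)
    return quant_method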


@@ -884,8 +884,9 @@ class ModelOptNvFp4Config(QuantizationConfig):
     ) -> Optional["QuantizeMethodBase"]:
         from vllm.attention.layer import Attention  # Avoid circular import
 
+        skip_layer = self.is_layer_excluded(prefix)
         if isinstance(layer, LinearBase):
-            if self.is_layer_excluded(prefix):
+            if skip_layer:
                 return UnquantizedLinearMethod()
             # Check if this is a vision model layer that should not be quantized
             if "vision_tower" in prefix or "vision_model" in prefix:
@@ -894,6 +895,8 @@ class ModelOptNvFp4Config(QuantizationConfig):
         elif isinstance(layer, Attention):
             return ModelOptFp8KVCacheMethod(self)
         elif isinstance(layer, FusedMoE):
+            if skip_layer:
+                return None
             return ModelOptNvFp4FusedMoE(self, layer.moe_config, layer)
         return None
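
The exclusion test is the existing is_layer_excluded(prefix); what changes is only what an excluded layer yields per layer type: linear layers still get an explicit UnquantizedLinearMethod, while FusedMoE layers now yield None and rely on the fall-back shown earlier. A rough sketch of that dispatch follows, using hypothetical *Stub classes in place of the real vLLM types.

class LinearStub: ...
class FusedMoEStub: ...
class UnquantizedLinearStub: ...
class NvFp4LinearStub: ...
class NvFp4FusedMoEStub: ...


def get_quant_method_sketch(config, layer, prefix: str):
    # Same check as in the diff; everything else here is simplified.
    skip_layer = config.is_layer_excluded(prefix)
    if isinstance(layer, LinearStub):
        # Excluded linear layers are swapped for an unquantized method in place.
        return UnquantizedLinearStub() if skip_layer else NvFp4LinearStub()
    if isinstance(layer, FusedMoEStub):
        # Excluded MoE layers return None; FusedMoE then falls back to
        # UnquantizedFusedMoEMethod (see the FusedMoE hunk above).
        return None if skip_layer else NvFp4FusedMoEStub()
    return None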


@@ -55,6 +55,7 @@ class DeepseekV2Model(nn.Module):
                 DeepseekV2DecoderLayer(
                     vllm_config,
                     prefix=maybe_prefix(prefix, f"layers.{i + start_layer_id}"),
+                    config=self.config,
                 )
                 for i in range(self.config.num_hidden_layers)
             ]


@@ -48,7 +48,8 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
     def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
         super().__init__()
-        config = vllm_config.model_config.hf_config
+        config = vllm_config.speculative_config.draft_model_config.hf_config
+        self.config = config
         quant_config = vllm_config.quant_config
 
         self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -66,11 +67,15 @@ class DeepSeekMultiTokenPredictorLayer(nn.Module):
             )
         else:
             topk_indices_buffer = None
+
         self.shared_head = SharedHead(
             config=config, prefix=prefix, quant_config=quant_config
         )
         self.mtp_block = DeepseekV2DecoderLayer(
-            vllm_config, prefix, topk_indices_buffer
+            vllm_config,
+            prefix,
+            config=self.config,
+            topk_indices_buffer=topk_indices_buffer,
         )
 
     def forward(


@@ -1055,10 +1055,12 @@ class DeepseekV2DecoderLayer(nn.Module):
         self,
         vllm_config: VllmConfig,
         prefix: str,
+        config: Optional[DeepseekV2Config] = None,
         topk_indices_buffer: Optional[torch.Tensor] = None,
     ) -> None:
         super().__init__()
-        config = vllm_config.model_config.hf_config
+        if config is None:
+            config = vllm_config.model_config.hf_config
         model_config = vllm_config.model_config
         cache_config = vllm_config.cache_config
@@ -1200,7 +1202,7 @@ class DeepseekV2Model(nn.Module):
         self.start_layer, self.end_layer, self.layers = make_layers(
             config.num_hidden_layers,
             lambda prefix: DeepseekV2DecoderLayer(
-                vllm_config, prefix, topk_indices_buffer
+                vllm_config, prefix, topk_indices_buffer=topk_indices_buffer
            ),
             prefix=f"{prefix}.layers",
         )
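
Taken together, the model-side changes make the decoder layer's HF config injectable: the target model keeps the old default (vllm_config.model_config.hf_config), while the MTP predictor layer now reads the draft model's config from vllm_config.speculative_config.draft_model_config and forwards it as config=self.config. A condensed sketch of that pattern follows, with SimpleNamespace objects and placeholder dict keys standing in for vLLM's config plumbing and an illustrative DecoderLayerSketch class in place of DeepseekV2DecoderLayer.

from types import SimpleNamespace
from typing import Optional


class DecoderLayerSketch:
    # Illustrative stand-in for DeepseekV2DecoderLayer's new config handling.
    def __init__(self, vllm_config, prefix: str, config: Optional[object] = None) -> None:
        # Target-model path: no explicit config, fall back to the target
        # hf_config (the old behaviour). Draft/MTP path: the caller passes
        # the draft model's hf_config explicitly.
        if config is None:
            config = vllm_config.model_config.hf_config
        self.config = config


# Usage: the target model relies on the default, the MTP predictor overrides it.
vllm_config = SimpleNamespace(
    model_config=SimpleNamespace(hf_config={"num_hidden_layers": 4}),
    speculative_config=SimpleNamespace(
        draft_model_config=SimpleNamespace(hf_config={"num_hidden_layers": 1})
    ),
)
target_layer = DecoderLayerSketch(vllm_config, "layers.0")
mtp_layer = DecoderLayerSketch(
    vllm_config,
    "layers.4",
    config=vllm_config.speculative_config.draft_model_config.hf_config,
)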