diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index e198322ba7a89..615da58eeda28 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -678,6 +678,10 @@ class FusedMoE(CustomOp):
             and self.moe_config.use_flashinfer_cutlass_kernels
         )
 
+    @property
+    def use_marlin_kernels(self) -> bool:
+        return getattr(self.quant_method, "use_marlin", False)
+
     @property
     def use_dp_chunking(self) -> bool:
         return (
diff --git a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
index 06112ca51b6d5..6ec8b33ed9309 100644
--- a/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe/shared_fused_moe.py
@@ -28,17 +28,17 @@ class SharedFusedMoE(FusedMoE):
         super().__init__(**kwargs)
         self._shared_experts = shared_experts
 
-        # Disable shared expert overlap if we are using eplb, because of
-        # correctness issues, or if using flashinfer with DP, since there
-        # is nothing to be gained in this case. Disabling the overlap
-        # optimization also prevents the shared experts from being hidden
-        # from torch.compile.
+        # Disable shared expert overlap if:
+        # - we are using eplb, because of correctness issues
+        # - we are using flashinfer with DP, since there is nothing to gain
+        # - we are using marlin kernels
         self.use_overlapped = (
             use_overlapped
             and not (
                 # TODO(wentao): find the root cause and remove this condition
                 self.enable_eplb
                 or (self.moe_config.use_flashinfer_cutlass_kernels and self.dp_size > 1)
+                or self.use_marlin_kernels
             )
             and self._shared_experts is not None
         )
diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py
index 3e1f87b59a34d..3f6ea68072b40 100644
--- a/vllm/model_executor/layers/quantization/awq_marlin.py
+++ b/vllm/model_executor/layers/quantization/awq_marlin.py
@@ -424,6 +424,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
         if self.quant_config.weight_bits != 4:
             raise ValueError("AWQMoEMethod only supports 4bit now.")
         self.quant_type = scalar_types.uint4
+        self.use_marlin = True
 
     def create_weights(
         self,
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index 6257a410e9432..f1050c15f79e7 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -1342,6 +1342,7 @@ class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
                 f"{WNA16_SUPPORTED_BITS}",
             )
         self.quant_type = WNA16_SUPPORTED_TYPES_MAP[self.num_bits]
+        self.use_marlin = True
 
     def create_weights(
         self,
diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py
index 42a569e7770c0..68a122fd46c6b 100644
--- a/vllm/model_executor/layers/quantization/gptq_marlin.py
+++ b/vllm/model_executor/layers/quantization/gptq_marlin.py
@@ -482,6 +482,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             self.quant_type = scalar_types.uint8b128
         else:
             raise ValueError("GPTQMarlinMoEMethod only supports int4 and int8 now.")
+        self.use_marlin = True
 
     def create_weights(
         self,
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 8d7297a0a1b3b..7940b359a150c 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -216,6 +216,7 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
     def __init__(self, moe: FusedMoEConfig):
         super().__init__(moe)
         self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
+        self.use_marlin = self.mxfp4_backend == Mxfp4Backend.MARLIN
         self.max_capture_size = (
             get_current_vllm_config().compilation_config.max_cudagraph_capture_size
         )
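
For reference, a minimal self-contained sketch of the pattern this diff introduces: Marlin-backed MoE quant methods declare use_marlin = True in __init__, the MoE layer reads the flag through getattr(..., False) so quant methods that never set the attribute keep the old behavior, and the shared-expert overlap is disabled whenever the flag is set. The stub class names below are illustrative stand-ins, not the real vLLM classes, and the flashinfer/DP and shared-experts terms of the real condition are omitted for brevity.

# Illustrative sketch only: stand-in stubs, not the actual vLLM classes.
class PlainQuantMethod:
    """Quant method that never sets `use_marlin` (non-Marlin path)."""


class MarlinQuantMethod:
    """Quant method that dispatches to Marlin kernels."""

    def __init__(self) -> None:
        self.use_marlin = True


class SharedMoEStub:
    def __init__(self, quant_method, enable_eplb: bool = False, use_overlapped: bool = True) -> None:
        self.quant_method = quant_method
        self.enable_eplb = enable_eplb
        # Mirrors FusedMoE.use_marlin_kernels: getattr with a False default,
        # so quant methods that never declare the flag are unaffected.
        self.use_marlin_kernels = getattr(self.quant_method, "use_marlin", False)
        # Mirrors the SharedFusedMoE.use_overlapped decision (simplified).
        self.use_overlapped = use_overlapped and not (
            self.enable_eplb or self.use_marlin_kernels
        )


# Overlap stays enabled for a non-Marlin method, and is disabled for Marlin.
assert SharedMoEStub(PlainQuantMethod()).use_overlapped is True
assert SharedMoEStub(MarlinQuantMethod()).use_overlapped is False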