From 5d3be3ba4c109f3ce5f0ffb78245296f3b9aa5ac Mon Sep 17 00:00:00 2001
From: Varun Sundar Rabindranath
Date: Mon, 27 Oct 2025 10:32:50 -0400
Subject: [PATCH] [Bugfix][LoRA][FusedMoE] Select MxFP4 Backend based on LoRA
 Enablement (#27487)

Signed-off-by: Varun Sundar Rabindranath
Co-authored-by: Varun Sundar Rabindranath
---
 .../model_executor/layers/fused_moe/config.py |  2 ++
 vllm/model_executor/layers/fused_moe/layer.py | 15 ++++++++++++---
 .../layers/quantization/mxfp4.py              | 23 ++++++++++++++++++++---
 3 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/config.py b/vllm/model_executor/layers/fused_moe/config.py
index 5403d4e62f85e..2394053329802 100644
--- a/vllm/model_executor/layers/fused_moe/config.py
+++ b/vllm/model_executor/layers/fused_moe/config.py
@@ -825,6 +825,8 @@ class FusedMoEConfig:
 
     is_act_and_mul: bool = True
 
+    is_lora_enabled: bool = False
+
     def __post_init__(self):
         if self.dp_size > 1:
             logger.debug_once(
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index c144aa23e46e4..9b826f05fe307 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -982,6 +982,7 @@ def maybe_roundup_hidden_size(
     act_dtype: torch.dtype,
     quant_config: QuantizationConfig | None,
     moe_parallel_config: FusedMoEParallelConfig,
+    is_lora_enabled: bool,
 ) -> int:
     """
     Given layer hidden size and MoE configurations, round up hidden_size
@@ -992,6 +993,9 @@ def maybe_roundup_hidden_size(
         act_dtype: Data type of the layer activations.
         quant_config: Fused MoE quantization configuration.
         moe_parallel_config: Fused MoE parallelization strategy configuration.
+        is_lora_enabled: True if LoRA is enabled for the engine. This is
+            used to select the Mxfp4Backend when mxfp4 quantization is
+            in use.
 
     Return:
         Rounded up hidden_size if rounding up is required based on the configs.
@@ -1015,7 +1019,7 @@ def maybe_roundup_hidden_size(
         get_mxfp4_backend,
     )
 
-    current_mxfp4_backend = get_mxfp4_backend()
+    current_mxfp4_backend = get_mxfp4_backend(is_lora_enabled)
     if (
         current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
         or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
@@ -1139,7 +1143,11 @@ class FusedMoE(CustomOp):
 
         # Round up hidden size if needed.
        hidden_size = maybe_roundup_hidden_size(
-            hidden_size, moe_in_dtype, quant_config, self.moe_parallel_config
+            hidden_size,
+            moe_in_dtype,
+            quant_config,
+            self.moe_parallel_config,
+            is_lora_enabled=self.vllm_config.lora_config is not None,
         )
 
         # For smuggling this layer into the fused moe custom op
@@ -1270,8 +1278,9 @@ class FusedMoE(CustomOp):
             max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
             has_bias=has_bias,
             is_act_and_mul=is_act_and_mul,
+            is_lora_enabled=vllm_config.lora_config is not None,
         )
-        self.moe_config = moe
+        self.moe_config: FusedMoEConfig = moe
         self.moe_quant_config: FusedMoEQuantConfig | None = None
         self.quant_config = quant_config
 
diff --git a/vllm/model_executor/layers/quantization/mxfp4.py b/vllm/model_executor/layers/quantization/mxfp4.py
index 6ffaa558887a1..597ee1b6bafe1 100644
--- a/vllm/model_executor/layers/quantization/mxfp4.py
+++ b/vllm/model_executor/layers/quantization/mxfp4.py
@@ -73,8 +73,24 @@ class Mxfp4Backend(Enum):
     TRITON = 6
 
 
-def get_mxfp4_backend():
+def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
+    """
+    Not all MXFP4 backends support LoRA. Select backends that are known to
+    have LoRA support.
+    """
+    if not current_platform.is_cuda():
+        return Mxfp4Backend.NONE
+
+    logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
+    return Mxfp4Backend.MARLIN
+
+
+def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
     # Backend Selection
+
+    if with_lora_support:
+        return get_mxfp4_backend_with_lora()
+
     if current_platform.is_cuda():
         if (
             current_platform.is_device_capability(90)
@@ -183,13 +199,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         super().__init__(moe)
         self.topk_indices_dtype = None
         self.moe = moe
-        self.mxfp4_backend = get_mxfp4_backend()
+        self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
         self.max_capture_size = (
             get_current_vllm_config().compilation_config.max_cudagraph_capture_size
         )
 
         assert self.mxfp4_backend != Mxfp4Backend.NONE, (
-            "No MXFP4 MoE backend (FlashInfer/Marlin/Triton) available."
+            f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found "
+            "no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton). "
             "Please check your environment and try again."
         )
         self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
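
The snippet below (separate from the patch above) is a minimal, self-contained
sketch of the selection flow this change introduces: when LoRA is enabled,
backend selection short-circuits to the LoRA-capable path (Marlin on CUDA),
otherwise the regular platform-specific selection runs. The trimmed
Mxfp4Backend enum and the is_cuda parameter are simplified stand-ins for the
real vllm.model_executor.layers.quantization.mxfp4 module, not its actual API.

from enum import Enum


class Mxfp4Backend(Enum):
    # Trimmed, illustrative stand-in for the real enum; only TRITON = 6 is
    # visible in the diff above, the other values are placeholders.
    NONE = 0
    MARLIN = 1
    TRITON = 6


def get_mxfp4_backend_with_lora(is_cuda: bool) -> Mxfp4Backend:
    # Only backends known to support LoRA are eligible: Marlin on CUDA,
    # otherwise no backend is available.
    return Mxfp4Backend.MARLIN if is_cuda else Mxfp4Backend.NONE


def get_mxfp4_backend(with_lora_support: bool, is_cuda: bool = True) -> Mxfp4Backend:
    # LoRA enablement is checked before any platform/capability probing.
    if with_lora_support:
        return get_mxfp4_backend_with_lora(is_cuda)
    # Without LoRA, vLLM runs its full FlashInfer/Marlin/Triton selection;
    # TRITON merely stands in for that path in this sketch.
    return Mxfp4Backend.TRITON if is_cuda else Mxfp4Backend.NONE


if __name__ == "__main__":
    # FusedMoE derives the flag from the engine config: lora_config is not None.
    lora_config = object()  # pretend LoRA was enabled for the engine
    assert get_mxfp4_backend(lora_config is not None) is Mxfp4Backend.MARLIN
    assert get_mxfp4_backend(with_lora_support=False) is Mxfp4Backend.TRITON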