[Bugfix][LoRA][FusedMoE] Select MxFP4 Backend based on LoRA Enablement (#27487)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Varun Sundar Rabindranath 2025-10-27 10:32:50 -04:00 committed by GitHub
parent 4f882be4a0
commit 5d3be3ba4c
3 changed files with 34 additions and 6 deletions

View File

@@ -825,6 +825,8 @@ class FusedMoEConfig:
     is_act_and_mul: bool = True
 
+    is_lora_enabled: bool = False
+
     def __post_init__(self):
         if self.dp_size > 1:
             logger.debug_once(

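For context, a minimal sketch of what the new flag amounts to. FusedMoEConfigSketch is a hypothetical, trimmed stand-in for FusedMoEConfig (the real dataclass carries many more fields):

from dataclasses import dataclass

@dataclass
class FusedMoEConfigSketch:
    # Hypothetical stand-in: only the fields visible in the hunk above.
    is_act_and_mul: bool = True
    is_lora_enabled: bool = False  # new flag introduced by this commit

# The FusedMoE layer fills the flag from the engine config, roughly as
# is_lora_enabled=vllm_config.lora_config is not None (see the hunks below).
cfg = FusedMoEConfigSketch(is_lora_enabled=True)
assert cfg.is_lora_enabled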
View File

@@ -982,6 +982,7 @@ def maybe_roundup_hidden_size(
     act_dtype: torch.dtype,
     quant_config: QuantizationConfig | None,
     moe_parallel_config: FusedMoEParallelConfig,
+    is_lora_enabled: bool,
 ) -> int:
     """
     Given layer hidden size and MoE configurations, round up hidden_size
@@ -992,6 +993,9 @@ def maybe_roundup_hidden_size(
         act_dtype: Data type of the layer activations.
         quant_config: Fused MoE quantization configuration.
         moe_parallel_config: Fused MoE parallelization strategy configuration.
+        is_lora_enabled: True if LoRA is enabled for the engine. This is
+            used, in the case of mxfp4 quantization, to select the
+            Mxfp4Backend.
 
     Return:
         Rounded up hidden_size if rounding up is required based on the configs.
@@ -1015,7 +1019,7 @@ def maybe_roundup_hidden_size(
             get_mxfp4_backend,
         )
 
-        current_mxfp4_backend = get_mxfp4_backend()
+        current_mxfp4_backend = get_mxfp4_backend(is_lora_enabled)
         if (
             current_mxfp4_backend == Mxfp4Backend.SM90_FI_MXFP4_BF16
             or current_mxfp4_backend == Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS
@@ -1139,7 +1143,11 @@ class FusedMoE(CustomOp):
         # Round up hidden size if needed.
         hidden_size = maybe_roundup_hidden_size(
-            hidden_size, moe_in_dtype, quant_config, self.moe_parallel_config
+            hidden_size,
+            moe_in_dtype,
+            quant_config,
+            self.moe_parallel_config,
+            is_lora_enabled=self.vllm_config.lora_config is not None,
         )
 
         # For smuggling this layer into the fused moe custom op
@@ -1270,8 +1278,9 @@ class FusedMoE(CustomOp):
             max_num_tokens=envs.VLLM_MOE_DP_CHUNK_SIZE,
             has_bias=has_bias,
             is_act_and_mul=is_act_and_mul,
+            is_lora_enabled=vllm_config.lora_config is not None,
         )
-        self.moe_config = moe
+        self.moe_config: FusedMoEConfig = moe
         self.moe_quant_config: FusedMoEQuantConfig | None = None
         self.quant_config = quant_config

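Taken together, the hunks above thread the LoRA flag from the FusedMoE constructor into the hidden-size round-up decision. Below is a condensed, runnable sketch; every function here is a simplified stand-in for the vLLM code of the same name, and the multiple of 128 is an illustrative assumption (the hunk only shows which backends trigger rounding, not the amount):

from enum import Enum, auto

class Mxfp4Backend(Enum):
    # Stand-in members with arbitrary values; only those referenced above.
    NONE = auto()
    MARLIN = auto()
    SM90_FI_MXFP4_BF16 = auto()
    SM100_FI_MXFP4_MXFP8_CUTLASS = auto()

def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
    # Stand-in: LoRA forces the Marlin path; otherwise pretend a FlashInfer
    # backend was picked (the real selection inspects the device).
    if with_lora_support:
        return Mxfp4Backend.MARLIN
    return Mxfp4Backend.SM90_FI_MXFP4_BF16

def maybe_roundup_hidden_size(hidden_size: int, is_lora_enabled: bool) -> int:
    # Stand-in for the round-up decision: only the FlashInfer-style backends
    # require padding the hidden size.
    backend = get_mxfp4_backend(is_lora_enabled)
    if backend in (
        Mxfp4Backend.SM90_FI_MXFP4_BF16,
        Mxfp4Backend.SM100_FI_MXFP4_MXFP8_CUTLASS,
    ):
        return (hidden_size + 127) // 128 * 128  # illustrative multiple
    return hidden_size

# With LoRA enabled, Marlin is selected and the hidden size stays untouched.
print(maybe_roundup_hidden_size(2880, is_lora_enabled=True))   # 2880
print(maybe_roundup_hidden_size(2880, is_lora_enabled=False))  # 2944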
View File

@@ -73,8 +73,24 @@ class Mxfp4Backend(Enum):
     TRITON = 6
 
 
-def get_mxfp4_backend():
+def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
+    """
+    Not all MXFP4 backends support LoRA. Select backends that are known to
+    have LoRA support.
+    """
+    if not current_platform.is_cuda():
+        return Mxfp4Backend.NONE
+
+    logger.info_once("[get_mxfp4_backend_with_lora] Using Marlin backend")
+    return Mxfp4Backend.MARLIN
+
+
+def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
     # Backend Selection
+    if with_lora_support:
+        return get_mxfp4_backend_with_lora()
+
     if current_platform.is_cuda():
         if (
             current_platform.is_device_capability(90)
@@ -183,13 +199,14 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         super().__init__(moe)
         self.topk_indices_dtype = None
         self.moe = moe
-        self.mxfp4_backend = get_mxfp4_backend()
+        self.mxfp4_backend = get_mxfp4_backend(moe.is_lora_enabled)
         self.max_capture_size = (
             get_current_vllm_config().compilation_config.max_cudagraph_capture_size
         )
 
         assert self.mxfp4_backend != Mxfp4Backend.NONE, (
-            "No MXFP4 MoE backend (FlashInfer/Marlin/Triton) available."
+            f"get_mxfp4_backend(with_lora_support={moe.is_lora_enabled}) found "
+            "no compatible MXFP4 MoE backend (FlashInfer/Marlin/Triton). "
             "Please check your environment and try again."
         )
         self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {}
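
To summarize the new selection flow, a condensed, runnable sketch of the two functions above. _is_cuda and the TRITON fallback are stand-ins; the real code consults current_platform and several capability checks that are not part of this diff:

from enum import Enum

class Mxfp4Backend(Enum):
    # Stand-in values; the real enum (partially shown above) has more members.
    NONE = 0
    MARLIN = 1
    TRITON = 6

def _is_cuda() -> bool:
    # Stand-in for current_platform.is_cuda().
    return True

def get_mxfp4_backend_with_lora() -> Mxfp4Backend:
    # Mirrors the new helper: with LoRA, only Marlin is known to work, and
    # non-CUDA platforms get no MXFP4 backend at all.
    if not _is_cuda():
        return Mxfp4Backend.NONE
    return Mxfp4Backend.MARLIN

def get_mxfp4_backend(with_lora_support: bool) -> Mxfp4Backend:
    if with_lora_support:
        return get_mxfp4_backend_with_lora()
    # ... the pre-existing, capability-based selection continues here ...
    return Mxfp4Backend.TRITON  # placeholder result for the non-LoRA path

# Mxfp4MoEMethod.__init__ now calls get_mxfp4_backend(moe.is_lora_enabled),
# so a LoRA-enabled engine ends up on the Marlin backend:
assert get_mxfp4_backend(with_lora_support=True) is Mxfp4Backend.MARLIN
assert get_mxfp4_backend(with_lora_support=False) is Mxfp4Backend.TRITON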