[Bugfix] Fix FP8 MoE LoRA (#29890)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li 2025-12-05 02:17:49 +08:00 committed by GitHub
parent 6dcb07f676
commit 652ba93da3

@@ -124,12 +124,16 @@ class Fp8MoeBackend(Enum):
 def get_fp8_moe_backend(
-    block_quant: bool, moe_parallel_config: FusedMoEParallelConfig
+    block_quant: bool,
+    moe_parallel_config: FusedMoEParallelConfig,
+    with_lora_support: bool,
 ) -> Fp8MoeBackend:
     """
     Select the primary FP8 MoE backend
     Note: Shape-specific fallbacks may still occur at runtime.
     """
+    if with_lora_support:
+        return Fp8MoeBackend.TRITON
     # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
     if (
         current_platform.is_cuda()
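
The key behavioral change is the new `with_lora_support` parameter: when the MoE layer has LoRA enabled, the selector returns the Triton backend immediately, before the FlashInfer/SM-capability preferences are consulted. Below is a minimal sketch of that selection order; the reduced enum and the `prefer_flashinfer` flag are simplified stand-ins for vLLM's real `Fp8MoeBackend` enum and its platform checks, not the actual implementation.

```python
# Minimal sketch of the new selection order. `Fp8MoeBackendSketch` and
# `prefer_flashinfer` are simplified stand-ins for vLLM's Fp8MoeBackend
# enum and its platform/capability checks.
from enum import Enum, auto


class Fp8MoeBackendSketch(Enum):
    TRITON = auto()
    FLASHINFER = auto()


def select_backend(block_quant: bool, prefer_flashinfer: bool,
                   with_lora_support: bool) -> Fp8MoeBackendSketch:
    # LoRA is checked first: a LoRA-enabled MoE layer always gets the
    # Triton backend, regardless of what the GPU would otherwise prefer.
    if with_lora_support:
        return Fp8MoeBackendSketch.TRITON
    # block_quant influences further choices in the real function; it is
    # unused in this reduced sketch.
    if prefer_flashinfer:
        return Fp8MoeBackendSketch.FLASHINFER
    return Fp8MoeBackendSketch.TRITON


# Even on a GPU that would normally pick FlashInfer, LoRA forces Triton.
assert select_backend(True, True, True) is Fp8MoeBackendSketch.TRITON
```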
@@ -665,7 +669,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant: bool = self.weight_block_size is not None
         self.fp8_backend = get_fp8_moe_backend(
-            self.block_quant, layer.moe_parallel_config
+            self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
         )
         self.marlin_input_dtype = None
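
At the call site the flag is simply read from the layer's MoE config (`self.moe.is_lora_enabled`), so no new state is added to `Fp8MoEMethod`. A rough sketch of that threading follows; `MoeConfigStub` and the string backend names are placeholders, not vLLM's actual config class or enum.

```python
# Rough sketch of the call-site threading. `MoeConfigStub` and the string
# backend names stand in for vLLM's FusedMoE config and Fp8MoeBackend.
from dataclasses import dataclass


@dataclass
class MoeConfigStub:
    is_lora_enabled: bool


class Fp8MoEMethodSketch:
    def __init__(self, moe: MoeConfigStub, block_quant: bool) -> None:
        self.moe = moe
        self.block_quant = block_quant
        # The LoRA flag from the per-layer config now feeds backend selection.
        self.fp8_backend = "TRITON" if self.moe.is_lora_enabled else "AUTO"


assert Fp8MoEMethodSketch(MoeConfigStub(is_lora_enabled=True), False).fp8_backend == "TRITON"
```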
@@ -1084,6 +1088,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         from vllm.model_executor.layers.fused_moe import (
             BatchedDeepGemmExperts,
             BatchedTritonExperts,
+            TritonExperts,
             TritonOrDeepGemmExperts,
         )
@@ -1116,7 +1121,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 num_dispatchers=prepare_finalize.num_dispatchers(),
                 quant_config=self.moe_quant_config,
             )
+        elif self.moe.is_lora_enabled:
+            return TritonExperts(quant_config=self.moe_quant_config)
         elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
             # Select GEMM experts with block-scale when weights are block-quantized
             experts = select_cutlass_fp8_gemm_impl(
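
The same flag gates the modular-kernel path: a LoRA-enabled layer gets plain `TritonExperts` before the FlashInfer CUTLASS branch is even considered. The condensed sketch below only aims to reproduce that branch ordering; the stub classes and the single `quant_config` argument are placeholders for the imported expert implementations and their real constructor arguments.

```python
# Condensed sketch of the experts-selection order after this change. The
# classes are placeholders for vLLM's batched / Triton / CUTLASS experts;
# only the branch ordering is meant to be faithful.
class BatchedExpertsStub:
    def __init__(self, quant_config):
        self.quant_config = quant_config


class TritonExpertsStub(BatchedExpertsStub):
    pass


class CutlassExpertsStub(BatchedExpertsStub):
    pass


def select_gemm_impl_sketch(use_batched_format: bool, is_lora_enabled: bool,
                            use_flashinfer_cutlass: bool, quant_config):
    if use_batched_format:
        return BatchedExpertsStub(quant_config)
    # New branch: LoRA layers short-circuit to the Triton experts kernel.
    elif is_lora_enabled:
        return TritonExpertsStub(quant_config)
    elif use_flashinfer_cutlass:
        return CutlassExpertsStub(quant_config)
    return TritonExpertsStub(quant_config)


# LoRA wins over the FlashInfer CUTLASS preference.
assert isinstance(
    select_gemm_impl_sketch(False, True, True, quant_config=None),
    TritonExpertsStub,
)
```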