diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 48223c9f103e..0e3e13f5945e 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -124,12 +124,16 @@ class Fp8MoeBackend(Enum):
 
 
 def get_fp8_moe_backend(
-    block_quant: bool, moe_parallel_config: FusedMoEParallelConfig
+    block_quant: bool,
+    moe_parallel_config: FusedMoEParallelConfig,
+    with_lora_support: bool,
 ) -> Fp8MoeBackend:
     """
     Select the primary FP8 MoE backend
     Note: Shape-specific fallbacks may still occur at runtime.
     """
+    if with_lora_support:
+        return Fp8MoeBackend.TRITON
     # Prefer FlashInfer backends on supported GPUs; allow SM90 and SM100.
     if (
         current_platform.is_cuda()
@@ -665,7 +669,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant: bool = self.weight_block_size is not None
         self.fp8_backend = get_fp8_moe_backend(
-            self.block_quant, layer.moe_parallel_config
+            self.block_quant, layer.moe_parallel_config, self.moe.is_lora_enabled
         )
 
         self.marlin_input_dtype = None
@@ -1084,6 +1088,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         from vllm.model_executor.layers.fused_moe import (
             BatchedDeepGemmExperts,
             BatchedTritonExperts,
+            TritonExperts,
             TritonOrDeepGemmExperts,
         )
 
@@ -1116,7 +1121,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 num_dispatchers=prepare_finalize.num_dispatchers(),
                 quant_config=self.moe_quant_config,
             )
-
+        elif self.moe.is_lora_enabled:
+            return TritonExperts(quant_config=self.moe_quant_config)
         elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS:
             # Select GEMM experts with block-scale when weights are block-quantized
             experts = select_cutlass_fp8_gemm_impl(
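For context, a minimal standalone sketch of the selection precedence this patch introduces: LoRA support is checked before any platform- or shape-specific logic and unconditionally picks the Triton backend. The trimmed-down enum, the `select_backend_sketch` function, and the `flashinfer_supported` flag below are illustrative stand-ins, not vLLM's actual internals.

```python
from enum import Enum


class Fp8MoeBackend(Enum):
    # Illustrative subset; the real enum in fp8.py defines more backends.
    TRITON = "triton"
    FLASHINFER = "flashinfer"


def select_backend_sketch(
    block_quant: bool,
    flashinfer_supported: bool,
    with_lora_support: bool,
) -> Fp8MoeBackend:
    """Mirror of the new precedence: the LoRA check runs first."""
    if with_lora_support:
        # LoRA-enabled MoE layers always take the Triton experts path.
        return Fp8MoeBackend.TRITON
    if flashinfer_supported:
        return Fp8MoeBackend.FLASHINFER
    # Fall back to Triton when no specialized backend applies.
    return Fp8MoeBackend.TRITON


# A LoRA-enabled layer lands on Triton even if FlashInfer would be available.
assert select_backend_sketch(True, True, with_lora_support=True) is Fp8MoeBackend.TRITON
```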