diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index 9e2718057038d..e033032903e87 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -28,6 +28,7 @@ from vllm.model_executor.layers.fused_moe import (
     FusedMoeWeightScaleSupported,
 )
 from vllm.model_executor.layers.fused_moe.config import (
+    FusedMoEParallelConfig,
     FusedMoEQuantConfig,
     RoutingMethodType,
     fp8_w8a8_moe_quant_config,
@@ -118,7 +119,9 @@ class Fp8MoeBackend(Enum):
     TRITON = 6


-def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
+def get_fp8_moe_backend(
+    block_quant: bool, moe_parallel_config: FusedMoEParallelConfig
+) -> Fp8MoeBackend:
     """
     Select the primary FP8 MoE backend
     Note: Shape-specific fallbacks may still occur at runtime.
@@ -159,8 +162,19 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
         logger.info_once("Using Marlin backend for FP8 MoE")
         return Fp8MoeBackend.MARLIN

-    # deepGEMM on supported platforms with block-quantized weights
-    if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
+    # Determine if we should use DeepGEMM with block-quantized weights:
+    # - If explicitly set by user, respect their choice
+    # - If not explicitly set (default), disable when TP size is >= 8
+    moe_use_deep_gemm = envs.VLLM_MOE_USE_DEEP_GEMM
+    if not envs.is_set("VLLM_MOE_USE_DEEP_GEMM") and moe_parallel_config.tp_size >= 8:
+        moe_use_deep_gemm = False
+        logger.info_once(
+            "DeepGEMM MoE is disabled by default when TP size is >= 8. "
+            "Set VLLM_MOE_USE_DEEP_GEMM=1 to enable it.",
+            scope="local",
+        )
+
+    if envs.VLLM_USE_DEEP_GEMM and moe_use_deep_gemm and block_quant:
         if not has_deep_gemm():
             logger.warning_once(
                 "DeepGEMM backend requested but not available.", scope="local"
@@ -641,7 +655,9 @@ class Fp8MoEMethod(FusedMoEMethodBase):
         self.quant_config = quant_config
         self.weight_block_size = self.quant_config.weight_block_size
         self.block_quant: bool = self.weight_block_size is not None
-        self.fp8_backend = get_fp8_moe_backend(self.block_quant)
+        self.fp8_backend = get_fp8_moe_backend(
+            self.block_quant, layer.moe_parallel_config
+        )
         self.use_marlin = self.fp8_backend == Fp8MoeBackend.MARLIN
         self.flashinfer_moe_backend: FlashinferMoeBackend | None = None