diff --git a/vllm/envs.py b/vllm/envs.py
index 30c62e90e9fb..9421488051e5 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -147,6 +147,7 @@ if TYPE_CHECKING:
     VLLM_TPU_MOST_MODEL_LEN: int | None = None
     VLLM_TPU_USING_PATHWAYS: bool = False
     VLLM_USE_DEEP_GEMM: bool = True
+    VLLM_MOE_USE_DEEP_GEMM: bool = True
     VLLM_USE_DEEP_GEMM_E8M0: bool = True
     VLLM_DEEP_GEMM_WARMUP: Literal[
         "skip",
@@ -1116,6 +1117,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     ),
     # Allow use of DeepGemm kernels for fused moe ops.
     "VLLM_USE_DEEP_GEMM": lambda: bool(int(os.getenv("VLLM_USE_DEEP_GEMM", "1"))),
+    # Allow use of DeepGemm specifically for MoE fused ops (overrides only MoE).
+    "VLLM_MOE_USE_DEEP_GEMM": lambda: bool(
+        int(os.getenv("VLLM_MOE_USE_DEEP_GEMM", "1"))
+    ),
     # Whether to use E8M0 scaling when DeepGEMM is used on Blackwell GPUs.
     "VLLM_USE_DEEP_GEMM_E8M0": lambda: bool(
         int(os.getenv("VLLM_USE_DEEP_GEMM_E8M0", "1"))
@@ -1569,6 +1574,7 @@ def compute_hash() -> str:
         "VLLM_USE_FLASHINFER_SAMPLER",
         "VLLM_DISABLED_KERNELS",
         "VLLM_USE_DEEP_GEMM",
+        "VLLM_MOE_USE_DEEP_GEMM",
         "VLLM_USE_DEEP_GEMM_E8M0",
         "VLLM_USE_FUSED_MOE_GROUPED_TOPK",
         "VLLM_USE_FLASHINFER_MOE_FP16",
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index d32ae6674ee6..59567f2ca13c 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -966,10 +966,18 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                 max_num_tokens=max_num_tokens_per_rank,
                 num_dispatchers=prepare_finalize.num_dispatchers(),
                 quant_config=self.moe_quant_config,
+                allow_deep_gemm=(
+                    envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
+                ),
             )
         else:
             logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
-            return TritonOrDeepGemmExperts(self.moe_quant_config, allow_deep_gemm=True)
+            return TritonOrDeepGemmExperts(
+                self.moe_quant_config,
+                allow_deep_gemm=(
+                    envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
+                ),
+            )
 
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py
index c7d5b251cf4e..83d136600b77 100644
--- a/vllm/model_executor/layers/quantization/fp8.py
+++ b/vllm/model_executor/layers/quantization/fp8.py
@@ -158,7 +158,7 @@ def get_fp8_moe_backend(block_quant: bool) -> Fp8MoeBackend:
         return Fp8MoeBackend.MARLIN
 
     # deepGEMM on supported platforms with block-quantized weights
-    if envs.VLLM_USE_DEEP_GEMM and block_quant:
+    if envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM and block_quant:
         if not has_deep_gemm():
             logger.warning_once("DeepGEMM backend requested but not available.")
         elif is_deep_gemm_supported():
diff --git a/vllm/model_executor/warmup/deep_gemm_warmup.py b/vllm/model_executor/warmup/deep_gemm_warmup.py
index bdcebd498ef0..e0c584df8760 100644
--- a/vllm/model_executor/warmup/deep_gemm_warmup.py
+++ b/vllm/model_executor/warmup/deep_gemm_warmup.py
@@ -148,6 +148,9 @@ def _fp8_linear_may_use_deep_gemm(module: torch.nn.Module) -> bool:
 
 
 def _fused_moe_grouped_gemm_may_use_deep_gemm(module: torch.nn.Module) -> bool:
+    if not (envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM):
+        return False
+
     if not isinstance(module, FusedMoE):
         return False
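
A minimal usage sketch (not part of the diff; assumes a working vLLM install) showing how the new flag combines with the existing one, mirroring the `envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM` gate added above:

```python
# Disable DeepGEMM only for fused-MoE ops while keeping it for dense FP8 GEMMs.
# Both variables default to "1"; the MoE paths above require both to be truthy.
import os

os.environ["VLLM_USE_DEEP_GEMM"] = "1"
os.environ["VLLM_MOE_USE_DEEP_GEMM"] = "0"

import vllm.envs as envs

moe_allow_deep_gemm = envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
print(moe_allow_deep_gemm)  # False: MoE falls back to the Triton experts path
```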