diff --git a/vllm/envs.py b/vllm/envs.py
index 7dcfabe3e044..53ce9ffe0a2d 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -159,6 +159,7 @@ if TYPE_CHECKING:
     VLLM_USE_FLASHINFER_MOE_FP8: bool = False
     VLLM_USE_FLASHINFER_MOE_FP4: bool = False
     VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "throughput"
+    VLLM_ALLOW_BATCHED_TRITON_FALLBACK: bool = False
     VLLM_XGRAMMAR_CACHE_MB: int = 0
     VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
     VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
@@ -1150,6 +1151,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
     "VLLM_USE_FLASHINFER_MOE_MXFP4_BF16": lambda: bool(
         int(os.getenv("VLLM_USE_FLASHINFER_MOE_MXFP4_BF16", "0"))
     ),
+    # If set to 1, allow falling back to the batched triton kernel when
+    # deepgemm is unavailable. By default (0), a RuntimeError is raised if
+    # deepgemm is expected but not available.
+    "VLLM_ALLOW_BATCHED_TRITON_FALLBACK": lambda: bool(
+        int(os.getenv("VLLM_ALLOW_BATCHED_TRITON_FALLBACK", "0"))
+    ),
     # Control the cache sized used by the xgrammar compiler. The default
     # of 512 MB should be enough for roughly 1000 JSON schemas.
     # It can be changed with this variable if needed for some reason.
diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
index 1b1af351a449..04265ac83b01 100644
--- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
@@ -3,6 +3,7 @@
 
 import torch
 
+import vllm.envs as envs
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
     BatchedDeepGemmExperts,
@@ -22,11 +23,8 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
     ):
         super().__init__(quant_config)
 
-        self.batched_triton_experts = BatchedTritonExperts(
-            max_num_tokens=max_num_tokens,
-            num_dispatchers=num_dispatchers,
-            quant_config=self.quant_config,
-        )
+        # Store the original request for deep gemm
+        deep_gemm_requested = allow_deep_gemm
 
         self.allow_deep_gemm = (
             allow_deep_gemm
@@ -44,6 +42,28 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
             else None
         )
 
+        # If deep gemm was requested but is not available (either due to an
+        # unsupported configuration or missing dependencies), check whether
+        # fallback to the batched triton kernel is allowed.
+        if deep_gemm_requested and self.batched_deep_gemm_experts is None:
+            if not envs.VLLM_ALLOW_BATCHED_TRITON_FALLBACK:
+                raise RuntimeError(
+                    "DeepGemm was requested but is not available. "
+                    "The batched triton kernel fallback is disabled by default. "
+                    "Set VLLM_ALLOW_BATCHED_TRITON_FALLBACK=1 to enable the fallback "
+                    "for debugging purposes."
+                )
+
+        self.batched_triton_experts = (
+            BatchedTritonExperts(
+                max_num_tokens=max_num_tokens,
+                num_dispatchers=num_dispatchers,
+                quant_config=self.quant_config,
+            )
+            if self.batched_deep_gemm_experts is None
+            else None
+        )
+
         assert (
             self.batched_deep_gemm_experts is not None
             or self.batched_triton_experts is not None
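
For reviewers, here is a minimal, standalone sketch of the gating behavior the diff above introduces. The helper `resolve_batched_experts` and its parameters are illustrative stand-ins, not vLLM APIs; only the env-var name and the raise-vs-fallback decision mirror the patch.

```python
import os


def resolve_batched_experts(
    deep_gemm_requested: bool,
    deep_gemm_supported: bool,
) -> str:
    """Sketch of which expert kernel the constructor would select."""
    allow_fallback = bool(
        int(os.getenv("VLLM_ALLOW_BATCHED_TRITON_FALLBACK", "0"))
    )
    deep_gemm_available = deep_gemm_requested and deep_gemm_supported
    if deep_gemm_requested and not deep_gemm_available:
        # Deep gemm was asked for but cannot be used: fail loudly unless the
        # fallback is explicitly enabled via the new env var.
        if not allow_fallback:
            raise RuntimeError(
                "DeepGemm was requested but is not available. "
                "Set VLLM_ALLOW_BATCHED_TRITON_FALLBACK=1 to enable the fallback."
            )
    return "batched_deep_gemm" if deep_gemm_available else "batched_triton"


if __name__ == "__main__":
    # With the env var unset this raises; with
    # VLLM_ALLOW_BATCHED_TRITON_FALLBACK=1 it prints "batched_triton".
    print(resolve_batched_experts(deep_gemm_requested=True, deep_gemm_supported=False))
```

Note the behavior change for non-requesting callers as well: the batched triton experts are now constructed only when deep gemm is not used, instead of unconditionally as before.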