diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index 44aaa65218cc4..48341d199cb80 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -90,7 +90,7 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
| cutlass_fp8 | standard,batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],[`CutlassBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
| flashinfer | standard | nvfp4,fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
-| deep gemm+triton2 | standard,batched | all1 | G(128),A,T | silu, gelu | 6 | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
+| deep gemm+triton2 | standard | all1 | G(128),A,T | silu, gelu | 6 | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts] |
| marlin | standard,batched | 3 / N/A | 3 / N/A | silu,swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| trtllm | standard | mxfp4,nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
@@ -114,5 +114,5 @@ The following table shows "families" of modular kernels that are intended to wor
| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
|---------|-----------------------------------------|----------------------------------------------|
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,`TritonExperts`,`TritonOrDeepGemmExperts`,`CutlassExpertsFp8`, `MarlinExperts` |
-| deepep_low_latency,pplx | `DeepEPLLPrepareAndFinalize`,`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,`BatchedTritonExperts`,`BatchedTritonOrDeepGemmExperts`,`CutlassBatchedExpertsFp8`,`BatchedMarlinExperts` |
+| deepep_low_latency,pplx | `DeepEPLLPrepareAndFinalize`,`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,`BatchedTritonExperts`,`CutlassBatchedExpertsFp8`,`BatchedMarlinExperts` |
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py
index d79fdfbe07af3..99b168dc75548 100644
--- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py
+++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py
@@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
-from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (
- BatchedTritonOrDeepGemmExperts,
-)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
@@ -286,16 +283,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
needs_matching_quant=False,
needs_deep_gemm=True,
)
- register_experts(
- BatchedTritonOrDeepGemmExperts,
- batched_format,
- common_float_and_int_types,
- blocked_quantization_support=True,
- supports_chunking=False,
- supports_expert_map=False,
- needs_matching_quant=True,
- needs_deep_gemm=True,
- )
register_experts(
TritonOrDeepGemmExperts,
standard_format,
@@ -457,10 +444,6 @@ def make_fused_experts(
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedTritonExperts {kwargs} ...")
experts = BatchedTritonExperts(**kwargs)
- elif fused_experts_type == BatchedTritonOrDeepGemmExperts:
- kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
- print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
- experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == DeepGemmExperts:
print(f"Making DeepGemmExperts {quant_config} ...")
experts = DeepGemmExperts(quant_config)
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index 669abcb3d6ff1..9103e84aa7057 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -60,9 +60,6 @@ if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
- from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
- BatchedTritonOrDeepGemmExperts,
- )
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8,
CutlassExpertsFp8,
@@ -98,7 +95,6 @@ if HAS_TRITON:
"DeepGemmExperts",
"BatchedDeepGemmExperts",
"TritonOrDeepGemmExperts",
- "BatchedTritonOrDeepGemmExperts",
]
else:
# Some model classes directly use the custom ops. Add placeholders
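Call sites that imported the removed class through the package root need a one-line migration. A minimal sketch of the replacement imports, noting that the package-level export only exists under `HAS_TRITON`:

```python
# Before this change (no longer available):
# from vllm.model_executor.layers.fused_moe import BatchedTritonOrDeepGemmExperts

# After: BatchedDeepGemmExperts remains a package-level export, and the batched
# Triton implementation is imported from its own module, mirroring the
# compressed-tensors quant-method change further down in this diff.
from vllm.model_executor.layers.fused_moe import BatchedDeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts,
)
```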
diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
deleted file mode 100644
index e69e9fd307aeb..0000000000000
--- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
+++ /dev/null
@@ -1,180 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import torch
-
-import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
- BatchedDeepGemmExperts,
-)
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
-from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout
-
-
-class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
- def __init__(
- self,
- max_num_tokens: int,
- num_dispatchers: int,
- quant_config: FusedMoEQuantConfig,
- allow_deep_gemm: bool = False,
- ):
- super().__init__(quant_config)
-
- self.batched_triton_experts = BatchedTritonExperts(
- max_num_tokens=max_num_tokens,
- num_dispatchers=num_dispatchers,
- quant_config=self.quant_config,
- )
-
- self.allow_deep_gemm = (
- allow_deep_gemm
- and self.quant_config.use_fp8_w8a8
- and self.block_shape == get_mk_alignment_for_contiguous_layout()
- )
-
- self.batched_deep_gemm_experts = (
- BatchedDeepGemmExperts(
- max_num_tokens=max_num_tokens,
- num_dispatchers=num_dispatchers,
- quant_config=self.quant_config,
- )
- if self.allow_deep_gemm
- else None
- )
-
- assert (
- self.batched_deep_gemm_experts is not None
- or self.batched_triton_experts is not None
- )
-
- @property
- def activation_formats(
- self,
- ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
- if self.batched_triton_experts is not None:
- assert (
- self.batched_deep_gemm_experts is None
- or self.batched_deep_gemm_experts.activation_formats
- == self.batched_triton_experts.activation_formats
- )
- return self.batched_triton_experts.activation_formats
- else:
- assert self.batched_deep_gemm_experts is not None
- return self.batched_deep_gemm_experts.activation_formats
-
- def supports_chunking(self) -> bool:
- bdge = self.batched_deep_gemm_experts
- bte = self.batched_triton_experts
- return (bdge is None or bdge.supports_chunking()) and (
- bte is None or bte.supports_chunking()
- )
-
- def supports_expert_map(self) -> bool:
- bdge = self.batched_deep_gemm_experts
- bte = self.batched_triton_experts
- return (bdge is None or bdge.supports_expert_map()) and (
- bte is None or bte.supports_expert_map()
- )
-
- def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
- bdge = self.batched_deep_gemm_experts
- bte = self.batched_triton_experts
- bdge_war = bdge.finalize_weight_and_reduce_impl() if bdge else None
- bte_war = bte.finalize_weight_and_reduce_impl() if bte else None
- is_bdge_war = bdge_war is not None
- is_bte_war = bte_war is not None
-
- if is_bdge_war and is_bte_war:
- assert bdge_war == bte_war, (
- "Both implementations should agree on WeightAndReduce impls. "
- f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}"
- )
-
- if bdge_war is not None:
- return bdge_war
-
- assert bte_war is not None
- return bte_war
-
- def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
- return act_dtype
-
- def workspace_shapes(
- self,
- M: int,
- N: int,
- K: int,
- topk: int,
- global_num_experts: int,
- local_num_experts: int,
- expert_tokens_metadata: mk.ExpertTokensMetadata | None,
- ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
- # Note: the deep gemm workspaces are strictly larger than the triton
- # workspaces so we can be pessimistic here and allocate for DeepGemm
- # even if we fall back to triton later, e.g. if expert maps are set.
- if self.allow_deep_gemm:
- assert self.batched_deep_gemm_experts is not None
- return self.batched_deep_gemm_experts.workspace_shapes(
- M,
- N,
- K,
- topk,
- global_num_experts,
- local_num_experts,
- expert_tokens_metadata,
- )
- else:
- assert self.batched_triton_experts is not None
- return self.batched_triton_experts.workspace_shapes(
- M,
- N,
- K,
- topk,
- global_num_experts,
- local_num_experts,
- expert_tokens_metadata,
- )
-
- def apply(
- self,
- output: torch.Tensor,
- hidden_states: torch.Tensor,
- w1: torch.Tensor,
- w2: torch.Tensor,
- topk_weights: torch.Tensor,
- topk_ids: torch.Tensor,
- activation: str,
- global_num_experts: int,
- expert_map: torch.Tensor | None,
- a1q_scale: torch.Tensor | None,
- a2_scale: torch.Tensor | None,
- workspace13: torch.Tensor,
- workspace2: torch.Tensor,
- expert_tokens_meta: mk.ExpertTokensMetadata | None,
- apply_router_weight_on_input: bool,
- ):
- experts = (
- self.batched_deep_gemm_experts
- if self.allow_deep_gemm
- else self.batched_triton_experts
- )
- assert experts is not None
- experts.apply(
- output,
- hidden_states,
- w1,
- w2,
- topk_weights,
- topk_ids,
- activation,
- global_num_experts,
- expert_map,
- a1q_scale,
- a2_scale,
- workspace13,
- workspace2,
- expert_tokens_meta,
- apply_router_weight_on_input,
- )
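With the wrapper gone, code that constructed `BatchedTritonOrDeepGemmExperts(max_num_tokens, num_dispatchers, quant_config, allow_deep_gemm=...)` can reproduce its selection with the two remaining classes. A minimal sketch, not part of the vLLM API: the gating mirrors the wrapper's own `allow_deep_gemm` condition above plus the availability check added to `CompressedTensorsW8A8Fp8MoEMethod` below, and the helper name `make_batched_experts` is hypothetical.

```python
# Hypothetical helper taking the same arguments the deleted wrapper took.
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout
from vllm.utils.import_utils import has_deep_gemm


def make_batched_experts(
    max_num_tokens: int,
    num_dispatchers: int,
    quant_config: FusedMoEQuantConfig,
    allow_deep_gemm: bool = False,
) -> BatchedDeepGemmExperts | BatchedTritonExperts:
    # DeepGEMM handles only fp8 w8a8 with the block shape it expects, and must
    # actually be installed; everything else takes the batched Triton path.
    if (
        allow_deep_gemm
        and has_deep_gemm()
        and quant_config.use_fp8_w8a8
        and quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
    ):
        return BatchedDeepGemmExperts(
            max_num_tokens=max_num_tokens,
            num_dispatchers=num_dispatchers,
            quant_config=quant_config,
        )
    return BatchedTritonExperts(
        max_num_tokens=max_num_tokens,
        num_dispatchers=num_dispatchers,
        quant_config=quant_config,
    )
```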
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index c7368bf427fe1..d7fb6d2ca367d 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -90,8 +90,10 @@ from vllm.platforms import CpuArchEnum, current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import (
get_col_major_tma_aligned_tensor,
+ get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used,
)
+from vllm.utils.import_utils import has_deep_gemm
logger = init_logger(__name__)
@@ -1088,9 +1090,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
return experts
- # triton path
- from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
- BatchedTritonOrDeepGemmExperts,
+ from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+ BatchedDeepGemmExperts,
+ )
+ from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+ BatchedTritonExperts,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
@@ -1098,6 +1102,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
assert not self.rocm_aiter_moe_enabled and not self.use_marlin
+ use_deep_gemm = envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
+
if (
prepare_finalize.activation_format
== FusedMoEActivationFormat.BatchedExperts
@@ -1105,22 +1111,39 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
- logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
- return BatchedTritonOrDeepGemmExperts(
- max_num_tokens=max_num_tokens_per_rank,
- num_dispatchers=prepare_finalize.num_dispatchers(),
- quant_config=self.moe_quant_config,
- allow_deep_gemm=(
- envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
- ),
+ if use_deep_gemm and not has_deep_gemm():
+ raise RuntimeError(
+ "DeepGEMM requested for MoE layer but not installed."
+ )
+
+ compatible_with_deep_gemm = (
+ self.moe_quant_config.use_fp8_w8a8
+ and self.moe_quant_config.block_shape
+ == get_mk_alignment_for_contiguous_layout()
)
+
+            # Use DeepGEMM only when it was requested and this MoE layer is
+            # compatible with it; otherwise fall back to BatchedTritonExperts.
+ if use_deep_gemm and compatible_with_deep_gemm and has_deep_gemm():
+ logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
+ return BatchedDeepGemmExperts(
+ max_num_tokens=max_num_tokens_per_rank,
+ num_dispatchers=prepare_finalize.num_dispatchers(),
+ quant_config=self.moe_quant_config,
+ )
+ else:
+ logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
+ return BatchedTritonExperts(
+ max_num_tokens=max_num_tokens_per_rank,
+ num_dispatchers=prepare_finalize.num_dispatchers(),
+ quant_config=self.moe_quant_config,
+ )
+
else:
logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
return TritonOrDeepGemmExperts(
self.moe_quant_config,
- allow_deep_gemm=(
- envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
- ),
+ allow_deep_gemm=use_deep_gemm,
)
def get_fused_moe_quant_config(
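Finally, a hedged sketch of the environment flags that now feed `use_deep_gemm`. The variable names come straight from the hunk above; the snippet itself is only an illustration, not documented vLLM configuration practice.

```python
# Illustrative only: both flags must evaluate truthy for `use_deep_gemm` above
# to be True. If they are set but DeepGEMM is not installed, the batched branch
# now raises a RuntimeError instead of silently using BatchedTritonExperts.
import os

os.environ["VLLM_USE_DEEP_GEMM"] = "1"      # enable DeepGEMM globally
os.environ["VLLM_MOE_USE_DEEP_GEMM"] = "1"  # enable DeepGEMM for MoE layers
```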