diff --git a/docs/design/moe_kernel_features.md b/docs/design/moe_kernel_features.md
index 44aaa65218cc4..48341d199cb80 100644
--- a/docs/design/moe_kernel_features.md
+++ b/docs/design/moe_kernel_features.md
@@ -90,7 +90,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
 | cutlass_fp8 | standard,<br>batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],<br>[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],<br>[`CutlassBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
 | flashinfer | standard | nvfp4,<br>fp8 | T | 5 | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],<br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
 | gpt oss triton | standard | N/A | N/A | 5 | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],<br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
-| deep gemm+triton2 | standard,<br>batched | all1 | G(128),A,T | silu, gelu | 6 | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],<br>[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
 | marlin | standard,<br>batched | 3 / N/A | 3 / N/A | silu,<br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],<br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],<br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
 | trtllm | standard | mxfp4,<br>nvfp4 | G(16),G(32) | 5 | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
 | pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
@@ -114,5 +113,5 @@ The following table shows "families" of modular kernels that are intended to wor
 | backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
 |---------|-----------------------------------------|----------------------------------------------|
 | deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,<br>`TritonExperts`,<br>`TritonOrDeepGemmExperts`,<br>`CutlassExpertsFp8`,<br>`MarlinExperts` |
-| deepep_low_latency,<br>pplx | `DeepEPLLPrepareAndFinalize`,<br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,<br>`BatchedTritonExperts`,<br>`BatchedTritonOrDeepGemmExperts`,<br>`CutlassBatchedExpertsFp8`,<br>`BatchedMarlinExperts` |
+| deepep_low_latency,<br>pplx | `DeepEPLLPrepareAndFinalize`,<br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,<br>`BatchedTritonExperts`,<br>`CutlassBatchedExpertsFp8`,<br>`BatchedMarlinExperts` |
 | flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py
index d79fdfbe07af3..99b168dc75548 100644
--- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py
+++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py
@@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
 from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
     BatchedDeepGemmExperts,
 )
-from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (
-    BatchedTritonOrDeepGemmExperts,
-)
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEConfig,
     FusedMoEQuantConfig,
@@ -286,16 +283,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
         needs_matching_quant=False,
         needs_deep_gemm=True,
     )
-    register_experts(
-        BatchedTritonOrDeepGemmExperts,
-        batched_format,
-        common_float_and_int_types,
-        blocked_quantization_support=True,
-        supports_chunking=False,
-        supports_expert_map=False,
-        needs_matching_quant=True,
-        needs_deep_gemm=True,
-    )
     register_experts(
         TritonOrDeepGemmExperts,
         standard_format,
@@ -457,10 +444,6 @@ def make_fused_experts(
         kwargs = batch_kwargs | quant_kwargs
         print(f"Making BatchedTritonExperts {kwargs} ...")
         experts = BatchedTritonExperts(**kwargs)
-    elif fused_experts_type == BatchedTritonOrDeepGemmExperts:
-        kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
-        print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
-        experts = BatchedTritonOrDeepGemmExperts(**kwargs)
     elif fused_experts_type == DeepGemmExperts:
         print(f"Making DeepGemmExperts {quant_config} ...")
         experts = DeepGemmExperts(quant_config)
diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py
index 669abcb3d6ff1..9103e84aa7057 100644
--- a/vllm/model_executor/layers/fused_moe/__init__.py
+++ b/vllm/model_executor/layers/fused_moe/__init__.py
@@ -60,9 +60,6 @@ if HAS_TRITON:
     from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
         BatchedDeepGemmExperts,
     )
-    from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
-        BatchedTritonOrDeepGemmExperts,
-    )
     from vllm.model_executor.layers.fused_moe.cutlass_moe import (
         CutlassBatchedExpertsFp8,
         CutlassExpertsFp8,
@@ -98,7 +95,6 @@ if HAS_TRITON:
         "DeepGemmExperts",
         "BatchedDeepGemmExperts",
         "TritonOrDeepGemmExperts",
-        "BatchedTritonOrDeepGemmExperts",
     ]
 else:
     # Some model classes directly use the custom ops. Add placeholders
diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
deleted file mode 100644
index e69e9fd307aeb..0000000000000
--- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py
+++ /dev/null
@@ -1,180 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-import torch
-
-import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
-    BatchedDeepGemmExperts,
-)
-from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
-from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
-from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout
-
-
-class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
-    def __init__(
-        self,
-        max_num_tokens: int,
-        num_dispatchers: int,
-        quant_config: FusedMoEQuantConfig,
-        allow_deep_gemm: bool = False,
-    ):
-        super().__init__(quant_config)
-
-        self.batched_triton_experts = BatchedTritonExperts(
-            max_num_tokens=max_num_tokens,
-            num_dispatchers=num_dispatchers,
-            quant_config=self.quant_config,
-        )
-
-        self.allow_deep_gemm = (
-            allow_deep_gemm
-            and self.quant_config.use_fp8_w8a8
-            and self.block_shape == get_mk_alignment_for_contiguous_layout()
-        )
-
-        self.batched_deep_gemm_experts = (
-            BatchedDeepGemmExperts(
-                max_num_tokens=max_num_tokens,
-                num_dispatchers=num_dispatchers,
-                quant_config=self.quant_config,
-            )
-            if self.allow_deep_gemm
-            else None
-        )
-
-        assert (
-            self.batched_deep_gemm_experts is not None
-            or self.batched_triton_experts is not None
-        )
-
-    @property
-    def activation_formats(
-        self,
-    ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
-        if self.batched_triton_experts is not None:
-            assert (
-                self.batched_deep_gemm_experts is None
-                or self.batched_deep_gemm_experts.activation_formats
-                == self.batched_triton_experts.activation_formats
-            )
-            return self.batched_triton_experts.activation_formats
-        else:
-            assert self.batched_deep_gemm_experts is not None
-            return self.batched_deep_gemm_experts.activation_formats
-
-    def supports_chunking(self) -> bool:
-        bdge = self.batched_deep_gemm_experts
-        bte = self.batched_triton_experts
-        return (bdge is None or bdge.supports_chunking()) and (
-            bte is None or bte.supports_chunking()
-        )
-
-    def supports_expert_map(self) -> bool:
-        bdge = self.batched_deep_gemm_experts
-        bte = self.batched_triton_experts
-        return (bdge is None or bdge.supports_expert_map()) and (
-            bte is None or bte.supports_expert_map()
-        )
-
-    def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
-        bdge = self.batched_deep_gemm_experts
-        bte = self.batched_triton_experts
-        bdge_war = bdge.finalize_weight_and_reduce_impl() if bdge else None
-        bte_war = bte.finalize_weight_and_reduce_impl() if bte else None
-        is_bdge_war = bdge_war is not None
-        is_bte_war = bte_war is not None
-
-        if is_bdge_war and is_bte_war:
-            assert bdge_war == bte_war, (
-                "Both implementations should agree on WeightAndReduce impls. "
-                f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}"
-            )
-
-        if bdge_war is not None:
-            return bdge_war
-
-        assert bte_war is not None
-        return bte_war
-
-    def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
-        return act_dtype
-
-    def workspace_shapes(
-        self,
-        M: int,
-        N: int,
-        K: int,
-        topk: int,
-        global_num_experts: int,
-        local_num_experts: int,
-        expert_tokens_metadata: mk.ExpertTokensMetadata | None,
-    ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
-        # Note: the deep gemm workspaces are strictly larger than the triton
-        # workspaces so we can be pessimistic here and allocate for DeepGemm
-        # even if we fall back to triton later, e.g. if expert maps are set.
-        if self.allow_deep_gemm:
-            assert self.batched_deep_gemm_experts is not None
-            return self.batched_deep_gemm_experts.workspace_shapes(
-                M,
-                N,
-                K,
-                topk,
-                global_num_experts,
-                local_num_experts,
-                expert_tokens_metadata,
-            )
-        else:
-            assert self.batched_triton_experts is not None
-            return self.batched_triton_experts.workspace_shapes(
-                M,
-                N,
-                K,
-                topk,
-                global_num_experts,
-                local_num_experts,
-                expert_tokens_metadata,
-            )
-
-    def apply(
-        self,
-        output: torch.Tensor,
-        hidden_states: torch.Tensor,
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        activation: str,
-        global_num_experts: int,
-        expert_map: torch.Tensor | None,
-        a1q_scale: torch.Tensor | None,
-        a2_scale: torch.Tensor | None,
-        workspace13: torch.Tensor,
-        workspace2: torch.Tensor,
-        expert_tokens_meta: mk.ExpertTokensMetadata | None,
-        apply_router_weight_on_input: bool,
-    ):
-        experts = (
-            self.batched_deep_gemm_experts
-            if self.allow_deep_gemm
-            else self.batched_triton_experts
-        )
-        assert experts is not None
-        experts.apply(
-            output,
-            hidden_states,
-            w1,
-            w2,
-            topk_weights,
-            topk_ids,
-            activation,
-            global_num_experts,
-            expert_map,
-            a1q_scale,
-            a2_scale,
-            workspace13,
-            workspace2,
-            expert_tokens_meta,
-            apply_router_weight_on_input,
-        )
diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
index c7368bf427fe1..d7fb6d2ca367d 100644
--- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -90,8 +90,10 @@ from vllm.platforms import CpuArchEnum, current_platform
 from vllm.scalar_type import scalar_types
 from vllm.utils.deep_gemm import (
     get_col_major_tma_aligned_tensor,
+    get_mk_alignment_for_contiguous_layout,
     is_deep_gemm_e8m0_used,
 )
+from vllm.utils.import_utils import has_deep_gemm


 logger = init_logger(__name__)
@@ -1088,9 +1090,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):

             return experts

-        # triton path
-        from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
-            BatchedTritonOrDeepGemmExperts,
+        from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
+            BatchedDeepGemmExperts,
+        )
+        from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
+            BatchedTritonExperts,
         )
         from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
             TritonOrDeepGemmExperts,
@@ -1098,6 +1102,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):

         assert not self.rocm_aiter_moe_enabled and not self.use_marlin

+        use_deep_gemm = envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
+
         if (
             prepare_finalize.activation_format
             == FusedMoEActivationFormat.BatchedExperts
@@ -1105,22 +1111,47 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
             assert max_num_tokens_per_rank is not None

-            logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
-            return BatchedTritonOrDeepGemmExperts(
-                max_num_tokens=max_num_tokens_per_rank,
-                num_dispatchers=prepare_finalize.num_dispatchers(),
-                quant_config=self.moe_quant_config,
-                allow_deep_gemm=(
-                    envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
-                ),
+            if use_deep_gemm and not has_deep_gemm():
+                raise RuntimeError(
+                    "DeepGEMM requested for MoE layer but not installed."
+                )
+
+            compatible_with_deep_gemm = (
+                self.moe_quant_config.use_fp8_w8a8
+                and self.moe_quant_config.block_shape
+                == get_mk_alignment_for_contiguous_layout()
             )
+
+            # If DeepGEMM was requested but this MoE layer is not compatible
+            # with it, throw an error.
+            if use_deep_gemm and not compatible_with_deep_gemm:
+                raise RuntimeError(
+                    "MoE layer incompatible with DeepGEMM: expected "
+                    f"use_fp8_w8a8=True (got {self.moe_quant_config.use_fp8_w8a8}) "
+                    f"and block_shape == {get_mk_alignment_for_contiguous_layout()} "
+                    f"(got {self.moe_quant_config.block_shape})."
+                )
+
+            if use_deep_gemm and compatible_with_deep_gemm and has_deep_gemm():
+                logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
+                return BatchedDeepGemmExperts(
+                    max_num_tokens=max_num_tokens_per_rank,
+                    num_dispatchers=prepare_finalize.num_dispatchers(),
+                    quant_config=self.moe_quant_config,
+                )
+            else:
+                logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
+                return BatchedTritonExperts(
+                    max_num_tokens=max_num_tokens_per_rank,
+                    num_dispatchers=prepare_finalize.num_dispatchers(),
+                    quant_config=self.moe_quant_config,
+                )
+
         else:
             logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
             return TritonOrDeepGemmExperts(
                 self.moe_quant_config,
-                allow_deep_gemm=(
-                    envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
-                ),
+                allow_deep_gemm=use_deep_gemm,
             )

     def get_fused_moe_quant_config(