[Kernels] Remove BatchedTritonOrDeepGemmExperts and default fallback to Triton (#29929)

Signed-off-by: Bill Nell <bnell@redhat.com>
Signed-off-by: bnellnm <49004751+bnellnm@users.noreply.github.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Authored by bnellnm on 2025-12-03 15:49:00 -05:00; committed by GitHub
parent ac1886588f
commit 2902c34826
5 changed files with 46 additions and 217 deletions
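In short: the `BatchedTritonOrDeepGemmExperts` wrapper, which deferred the DeepGEMM-vs-Triton choice until `apply()` time, is deleted; callers now pick a concrete batched experts class once, up front, with the plain Triton kernels as the default fallback. Below is a minimal sketch of that selection using the vLLM-internal names that appear in this diff; the helper `select_batched_experts` is hypothetical, and the real logic lives in `CompressedTensorsW8A8Fp8MoEMethod` (last file in this diff), which additionally raises if DeepGEMM is requested but unavailable or incompatible.

```python
# Hedged sketch, not the PR's exact code: choose the batched experts class at
# construction time instead of wrapping both kernels.
import vllm.envs as envs
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts,
)
from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout
from vllm.utils.import_utils import has_deep_gemm


def select_batched_experts(quant_config, prepare_finalize):
    # DeepGEMM is used only when requested via env vars, installed, and
    # supported by this layer's quantization (fp8 w8a8 with the expected
    # block shape); everything else falls back to batched Triton.
    use_deep_gemm = envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
    compatible = (
        quant_config.use_fp8_w8a8
        and quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
    )
    cls = (
        BatchedDeepGemmExperts
        if use_deep_gemm and compatible and has_deep_gemm()
        else BatchedTritonExperts
    )
    return cls(
        max_num_tokens=prepare_finalize.max_num_tokens_per_rank(),
        num_dispatchers=prepare_finalize.num_dispatchers(),
        quant_config=quant_config,
    )
```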

@@ -90,7 +90,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
| cutlass_fp8 | standard,</br>batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],</br>[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],</br>[`CutlassBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
| deep gemm+triton<sup>2</sup> | standard,</br>batched | all<sup>1</sup> | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],</br>[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
| marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
@@ -114,5 +113,5 @@ The following table shows "families" of modular kernels that are intended to wor
| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
|---------|-----------------------------------------|----------------------------------------------|
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
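For context on how a row of this table is consumed, here is a hedged sketch of wiring the deepep_low_latency/pplx family together, assuming `FusedMoEModularKernel` composes one `FusedMoEPrepareAndFinalize` with one `FusedMoEPermuteExpertsUnpermute` (constructor details may differ). With `BatchedTritonOrDeepGemmExperts` gone, `BatchedTritonExperts` is the generic batched fallback in this family.

```python
# Illustrative composition only; assumes the modular-kernel API named in the
# table above.
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts,
)


def build_batched_kernel(prepare_finalize, quant_config):
    # prepare_finalize is e.g. a DeepEPLLPrepareAndFinalize or
    # PplxPrepareAndFinalize instance from the row above.
    experts = BatchedTritonExperts(
        max_num_tokens=prepare_finalize.max_num_tokens_per_rank(),
        num_dispatchers=prepare_finalize.num_dispatchers(),
        quant_config=quant_config,
    )
    # Assumed split of responsibilities: prepare/finalize handles token
    # dispatch and combine; the experts object runs permute -> expert GEMMs
    # -> unpermute.
    return mk.FusedMoEModularKernel(prepare_finalize, experts)
```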

@@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (
BatchedTritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.config import (
FusedMoEConfig,
FusedMoEQuantConfig,
@@ -286,16 +283,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
needs_matching_quant=False,
needs_deep_gemm=True,
)
register_experts(
BatchedTritonOrDeepGemmExperts,
batched_format,
common_float_and_int_types,
blocked_quantization_support=True,
supports_chunking=False,
supports_expert_map=False,
needs_matching_quant=True,
needs_deep_gemm=True,
)
register_experts(
TritonOrDeepGemmExperts,
standard_format,
@@ -457,10 +444,6 @@ def make_fused_experts(
kwargs = batch_kwargs | quant_kwargs
print(f"Making BatchedTritonExperts {kwargs} ...")
experts = BatchedTritonExperts(**kwargs)
elif fused_experts_type == BatchedTritonOrDeepGemmExperts:
kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
elif fused_experts_type == DeepGemmExperts:
print(f"Making DeepGemmExperts {quant_config} ...")
experts = DeepGemmExperts(quant_config)

@@ -60,9 +60,6 @@ if HAS_TRITON:
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
CutlassBatchedExpertsFp8,
CutlassExpertsFp8,
@@ -98,7 +95,6 @@ if HAS_TRITON:
"DeepGemmExperts",
"BatchedDeepGemmExperts",
"TritonOrDeepGemmExperts",
"BatchedTritonOrDeepGemmExperts",
]
else:
# Some model classes directly use the custom ops. Add placeholders

@@ -1,180 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout
class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
def __init__(
self,
max_num_tokens: int,
num_dispatchers: int,
quant_config: FusedMoEQuantConfig,
allow_deep_gemm: bool = False,
):
super().__init__(quant_config)
self.batched_triton_experts = BatchedTritonExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
quant_config=self.quant_config,
)
self.allow_deep_gemm = (
allow_deep_gemm
and self.quant_config.use_fp8_w8a8
and self.block_shape == get_mk_alignment_for_contiguous_layout()
)
self.batched_deep_gemm_experts = (
BatchedDeepGemmExperts(
max_num_tokens=max_num_tokens,
num_dispatchers=num_dispatchers,
quant_config=self.quant_config,
)
if self.allow_deep_gemm
else None
)
assert (
self.batched_deep_gemm_experts is not None
or self.batched_triton_experts is not None
)
@property
def activation_formats(
self,
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
if self.batched_triton_experts is not None:
assert (
self.batched_deep_gemm_experts is None
or self.batched_deep_gemm_experts.activation_formats
== self.batched_triton_experts.activation_formats
)
return self.batched_triton_experts.activation_formats
else:
assert self.batched_deep_gemm_experts is not None
return self.batched_deep_gemm_experts.activation_formats
def supports_chunking(self) -> bool:
bdge = self.batched_deep_gemm_experts
bte = self.batched_triton_experts
return (bdge is None or bdge.supports_chunking()) and (
bte is None or bte.supports_chunking()
)
def supports_expert_map(self) -> bool:
bdge = self.batched_deep_gemm_experts
bte = self.batched_triton_experts
return (bdge is None or bdge.supports_expert_map()) and (
bte is None or bte.supports_expert_map()
)
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
bdge = self.batched_deep_gemm_experts
bte = self.batched_triton_experts
bdge_war = bdge.finalize_weight_and_reduce_impl() if bdge else None
bte_war = bte.finalize_weight_and_reduce_impl() if bte else None
is_bdge_war = bdge_war is not None
is_bte_war = bte_war is not None
if is_bdge_war and is_bte_war:
assert bdge_war == bte_war, (
"Both implementations should agree on WeightAndReduce impls. "
f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}"
)
if bdge_war is not None:
return bdge_war
assert bte_war is not None
return bte_war
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
return act_dtype
def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_metadata: mk.ExpertTokensMetadata | None,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
# Note: the deep gemm workspaces are strictly larger than the triton
# workspaces so we can be pessimistic here and allocate for DeepGemm
# even if we fall back to triton later, e.g. if expert maps are set.
if self.allow_deep_gemm:
assert self.batched_deep_gemm_experts is not None
return self.batched_deep_gemm_experts.workspace_shapes(
M,
N,
K,
topk,
global_num_experts,
local_num_experts,
expert_tokens_metadata,
)
else:
assert self.batched_triton_experts is not None
return self.batched_triton_experts.workspace_shapes(
M,
N,
K,
topk,
global_num_experts,
local_num_experts,
expert_tokens_metadata,
)
def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
experts = (
self.batched_deep_gemm_experts
if self.allow_deep_gemm
else self.batched_triton_experts
)
assert experts is not None
experts.apply(
output,
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
activation,
global_num_experts,
expert_map,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_tokens_meta,
apply_router_weight_on_input,
)

@@ -90,8 +90,10 @@ from vllm.platforms import CpuArchEnum, current_platform
from vllm.scalar_type import scalar_types
from vllm.utils.deep_gemm import (
get_col_major_tma_aligned_tensor,
get_mk_alignment_for_contiguous_layout,
is_deep_gemm_e8m0_used,
)
from vllm.utils.import_utils import has_deep_gemm
logger = init_logger(__name__)
@@ -1088,9 +1090,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
return experts
# triton path
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
BatchedTritonOrDeepGemmExperts,
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
BatchedDeepGemmExperts,
)
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
BatchedTritonExperts,
)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
TritonOrDeepGemmExperts,
@@ -1098,6 +1102,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
assert not self.rocm_aiter_moe_enabled and not self.use_marlin
use_deep_gemm = envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
if (
prepare_finalize.activation_format
== FusedMoEActivationFormat.BatchedExperts
@@ -1105,22 +1111,47 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
assert max_num_tokens_per_rank is not None
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
return BatchedTritonOrDeepGemmExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
allow_deep_gemm=(
envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
),
if use_deep_gemm and not has_deep_gemm():
raise RuntimeError(
"DeepGEMM requested for MoE layer but not installed."
)
compatible_with_deep_gemm = (
self.moe_quant_config.use_fp8_w8a8
and self.moe_quant_config.block_shape
== get_mk_alignment_for_contiguous_layout()
)
# If DeepGEMM is requested but this MoE layer is not compatible with
# it, throw an error.
if use_deep_gemm and not compatible_with_deep_gemm:
raise RuntimeError(
f"MoE layer incompatible with DeepGEMM: expected "
f"use_fp8_w8a8=True (got {self.moe_quant_config.use_fp8_w8a8}) and "
f"block_shape == {get_mk_alignment_for_contiguous_layout()} "
f"(got {self.moe_quant_config.block_shape})."
)
if use_deep_gemm and compatible_with_deep_gemm and has_deep_gemm():
logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
return BatchedDeepGemmExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
)
else:
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
return BatchedTritonExperts(
max_num_tokens=max_num_tokens_per_rank,
num_dispatchers=prepare_finalize.num_dispatchers(),
quant_config=self.moe_quant_config,
)
else:
logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
return TritonOrDeepGemmExperts(
self.moe_quant_config,
allow_deep_gemm=(
envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
),
allow_deep_gemm=use_deep_gemm,
)
def get_fused_moe_quant_config(