mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-09 19:15:26 +08:00
[Kernels] Remove BatchedTritonOrDeepGemmExperts and default fallback to Triton (#29929)
Signed-off-by: Bill Nell <bnell@redhat.com> Signed-off-by: bnellnm <49004751+bnellnm@users.noreply.github.com> Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
This commit is contained in:
parent
ac1886588f
commit
2902c34826
@ -90,7 +90,6 @@ To be used with a particular `FusedMoEPrepareAndFinalize` subclass, MoE kernels
|
||||
| cutlass_fp8 | standard,</br>batched | fp8 | A,T | silu, gelu | Y | Y | [`cutlass_moe_fp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.cutlass_moe_fp8],</br>[`CutlassExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassExpertsFp8],</br>[`CutlasBatchedExpertsFp8`][vllm.model_executor.layers.fused_moe.cutlass_moe.CutlassBatchedExpertsFp8] |
|
||||
| flashinfer | standard | nvfp4,</br>fp8 | T | <sup>5</sup> | N | Y | [`flashinfer_cutlass_moe_fp4`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.flashinfer_cutlass_moe_fp4],</br>[`FlashInferExperts`][vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe.FlashInferExperts] |
|
||||
| gpt oss triton | standard | N/A | N/A | <sup>5</sup> | Y | Y | [`triton_kernel_fused_experts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.triton_kernel_fused_experts],</br>[`OAITritonExperts`][vllm.model_executor.layers.fused_moe.gpt_oss_triton_kernels_moe.OAITritonExperts] |
|
||||
| deep gemm+triton<sup>2</sup> | standard,</br>batched | all<sup>1</sup> | G(128),A,T | silu, gelu | <sup>6</sup> | Y | [`TritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe.TritonOrDeepGemmExperts],</br>[`BatchedTritonOrDeepGemmExperts`][vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe.BatchedTritonOrDeepGemmExperts] |
|
||||
| marlin | standard,</br>batched | <sup>3</sup> / N/A | <sup>3</sup> / N/A | silu,</br>swigluoai | Y | Y | [`fused_marlin_moe`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.fused_marlin_moe],</br>[`MarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.MarlinExperts],</br>[`BatchedMarlinExperts`][vllm.model_executor.layers.fused_moe.fused_marlin_moe.BatchedMarlinExperts] |
|
||||
| trtllm | standard | mxfp4,</br>nvfp4 | G(16),G(32) | <sup>5</sup> | N | Y | [`TrtLlmGenExperts`][vllm.model_executor.layers.fused_moe.trtllm_moe.TrtLlmGenExperts] |
|
||||
| pallas | standard | N/A | N/A | silu | N | N | [`fused_moe`][vllm.model_executor.layers.fused_moe.moe_pallas.fused_moe] |
|
||||
@ -114,5 +113,5 @@ The following table shows "families" of modular kernels that are intended to wor
|
||||
| backend | `FusedMoEPrepareAndFinalize` subclasses | `FusedMoEPermuteExpertsUnpermute` subclasses |
|
||||
|---------|-----------------------------------------|----------------------------------------------|
|
||||
| deepep_high_throughput | `DeepEPHTPrepareAndFinalize` | `DeepGemmExperts`,</br>`TritonExperts`,</br>`TritonOrDeepGemmExperts`,</br>`CutlassExpertsFp8`, </br>`MarlinExperts` |
|
||||
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`BatchedTritonOrDeepGemmExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
|
||||
| deepep_low_latency,</br>pplx | `DeepEPLLPrepareAndFinalize`,</br>`PplxPrepareAndFinalize` | `BatchedDeepGemmExperts`,</br>`BatchedTritonExperts`,</br>`CutlassBatchedExpertsFp8`,</br>`BatchedMarlinExperts` |
|
||||
| flashinfer | `FlashInferCutlassMoEPrepareAndFinalize` | `FlashInferExperts` |
|
||||
|
||||
@ -13,9 +13,6 @@ from vllm.model_executor.layers.fused_moe.all2all_utils import (
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (
|
||||
BatchedTritonOrDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig,
|
||||
FusedMoEQuantConfig,
|
||||
@ -286,16 +283,6 @@ if has_deep_gemm() and is_deep_gemm_supported():
|
||||
needs_matching_quant=False,
|
||||
needs_deep_gemm=True,
|
||||
)
|
||||
register_experts(
|
||||
BatchedTritonOrDeepGemmExperts,
|
||||
batched_format,
|
||||
common_float_and_int_types,
|
||||
blocked_quantization_support=True,
|
||||
supports_chunking=False,
|
||||
supports_expert_map=False,
|
||||
needs_matching_quant=True,
|
||||
needs_deep_gemm=True,
|
||||
)
|
||||
register_experts(
|
||||
TritonOrDeepGemmExperts,
|
||||
standard_format,
|
||||
@ -457,10 +444,6 @@ def make_fused_experts(
|
||||
kwargs = batch_kwargs | quant_kwargs
|
||||
print(f"Making BatchedTritonExperts {kwargs} ...")
|
||||
experts = BatchedTritonExperts(**kwargs)
|
||||
elif fused_experts_type == BatchedTritonOrDeepGemmExperts:
|
||||
kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
|
||||
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
|
||||
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
|
||||
elif fused_experts_type == DeepGemmExperts:
|
||||
print(f"Making DeepGemmExperts {quant_config} ...")
|
||||
experts = DeepGemmExperts(quant_config)
|
||||
|
||||
@ -60,9 +60,6 @@ if HAS_TRITON:
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
BatchedTritonOrDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
|
||||
CutlassBatchedExpertsFp8,
|
||||
CutlassExpertsFp8,
|
||||
@ -98,7 +95,6 @@ if HAS_TRITON:
|
||||
"DeepGemmExperts",
|
||||
"BatchedDeepGemmExperts",
|
||||
"TritonOrDeepGemmExperts",
|
||||
"BatchedTritonOrDeepGemmExperts",
|
||||
]
|
||||
else:
|
||||
# Some model classes directly use the custom ops. Add placeholders
|
||||
|
||||
@ -1,180 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import BatchedTritonExperts
|
||||
from vllm.utils.deep_gemm import get_mk_alignment_for_contiguous_layout
|
||||
|
||||
|
||||
class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
def __init__(
|
||||
self,
|
||||
max_num_tokens: int,
|
||||
num_dispatchers: int,
|
||||
quant_config: FusedMoEQuantConfig,
|
||||
allow_deep_gemm: bool = False,
|
||||
):
|
||||
super().__init__(quant_config)
|
||||
|
||||
self.batched_triton_experts = BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=num_dispatchers,
|
||||
quant_config=self.quant_config,
|
||||
)
|
||||
|
||||
self.allow_deep_gemm = (
|
||||
allow_deep_gemm
|
||||
and self.quant_config.use_fp8_w8a8
|
||||
and self.block_shape == get_mk_alignment_for_contiguous_layout()
|
||||
)
|
||||
|
||||
self.batched_deep_gemm_experts = (
|
||||
BatchedDeepGemmExperts(
|
||||
max_num_tokens=max_num_tokens,
|
||||
num_dispatchers=num_dispatchers,
|
||||
quant_config=self.quant_config,
|
||||
)
|
||||
if self.allow_deep_gemm
|
||||
else None
|
||||
)
|
||||
|
||||
assert (
|
||||
self.batched_deep_gemm_experts is not None
|
||||
or self.batched_triton_experts is not None
|
||||
)
|
||||
|
||||
@property
|
||||
def activation_formats(
|
||||
self,
|
||||
) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]:
|
||||
if self.batched_triton_experts is not None:
|
||||
assert (
|
||||
self.batched_deep_gemm_experts is None
|
||||
or self.batched_deep_gemm_experts.activation_formats
|
||||
== self.batched_triton_experts.activation_formats
|
||||
)
|
||||
return self.batched_triton_experts.activation_formats
|
||||
else:
|
||||
assert self.batched_deep_gemm_experts is not None
|
||||
return self.batched_deep_gemm_experts.activation_formats
|
||||
|
||||
def supports_chunking(self) -> bool:
|
||||
bdge = self.batched_deep_gemm_experts
|
||||
bte = self.batched_triton_experts
|
||||
return (bdge is None or bdge.supports_chunking()) and (
|
||||
bte is None or bte.supports_chunking()
|
||||
)
|
||||
|
||||
def supports_expert_map(self) -> bool:
|
||||
bdge = self.batched_deep_gemm_experts
|
||||
bte = self.batched_triton_experts
|
||||
return (bdge is None or bdge.supports_expert_map()) and (
|
||||
bte is None or bte.supports_expert_map()
|
||||
)
|
||||
|
||||
def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
|
||||
bdge = self.batched_deep_gemm_experts
|
||||
bte = self.batched_triton_experts
|
||||
bdge_war = bdge.finalize_weight_and_reduce_impl() if bdge else None
|
||||
bte_war = bte.finalize_weight_and_reduce_impl() if bte else None
|
||||
is_bdge_war = bdge_war is not None
|
||||
is_bte_war = bte_war is not None
|
||||
|
||||
if is_bdge_war and is_bte_war:
|
||||
assert bdge_war == bte_war, (
|
||||
"Both implementations should agree on WeightAndReduce impls. "
|
||||
f"Got bdge_war: {bdge_war}, and bte_war: {bte_war}"
|
||||
)
|
||||
|
||||
if bdge_war is not None:
|
||||
return bdge_war
|
||||
|
||||
assert bte_war is not None
|
||||
return bte_war
|
||||
|
||||
def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype:
|
||||
return act_dtype
|
||||
|
||||
def workspace_shapes(
|
||||
self,
|
||||
M: int,
|
||||
N: int,
|
||||
K: int,
|
||||
topk: int,
|
||||
global_num_experts: int,
|
||||
local_num_experts: int,
|
||||
expert_tokens_metadata: mk.ExpertTokensMetadata | None,
|
||||
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
|
||||
# Note: the deep gemm workspaces are strictly larger than the triton
|
||||
# workspaces so we can be pessimistic here and allocate for DeepGemm
|
||||
# even if we fall back to triton later, e.g. if expert maps are set.
|
||||
if self.allow_deep_gemm:
|
||||
assert self.batched_deep_gemm_experts is not None
|
||||
return self.batched_deep_gemm_experts.workspace_shapes(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
topk,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_metadata,
|
||||
)
|
||||
else:
|
||||
assert self.batched_triton_experts is not None
|
||||
return self.batched_triton_experts.workspace_shapes(
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
topk,
|
||||
global_num_experts,
|
||||
local_num_experts,
|
||||
expert_tokens_metadata,
|
||||
)
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor | None,
|
||||
a1q_scale: torch.Tensor | None,
|
||||
a2_scale: torch.Tensor | None,
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: mk.ExpertTokensMetadata | None,
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
experts = (
|
||||
self.batched_deep_gemm_experts
|
||||
if self.allow_deep_gemm
|
||||
else self.batched_triton_experts
|
||||
)
|
||||
assert experts is not None
|
||||
experts.apply(
|
||||
output,
|
||||
hidden_states,
|
||||
w1,
|
||||
w2,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
activation,
|
||||
global_num_experts,
|
||||
expert_map,
|
||||
a1q_scale,
|
||||
a2_scale,
|
||||
workspace13,
|
||||
workspace2,
|
||||
expert_tokens_meta,
|
||||
apply_router_weight_on_input,
|
||||
)
|
||||
@ -90,8 +90,10 @@ from vllm.platforms import CpuArchEnum, current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.utils.deep_gemm import (
|
||||
get_col_major_tma_aligned_tensor,
|
||||
get_mk_alignment_for_contiguous_layout,
|
||||
is_deep_gemm_e8m0_used,
|
||||
)
|
||||
from vllm.utils.import_utils import has_deep_gemm
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -1088,9 +1090,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
return experts
|
||||
|
||||
# triton path
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
BatchedTritonOrDeepGemmExperts,
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts,
|
||||
)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts,
|
||||
@ -1098,6 +1102,8 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
assert not self.rocm_aiter_moe_enabled and not self.use_marlin
|
||||
|
||||
use_deep_gemm = envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
|
||||
|
||||
if (
|
||||
prepare_finalize.activation_format
|
||||
== FusedMoEActivationFormat.BatchedExperts
|
||||
@ -1105,22 +1111,47 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
|
||||
assert max_num_tokens_per_rank is not None
|
||||
|
||||
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
|
||||
return BatchedTritonOrDeepGemmExperts(
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
quant_config=self.moe_quant_config,
|
||||
allow_deep_gemm=(
|
||||
envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
|
||||
),
|
||||
if use_deep_gemm and not has_deep_gemm():
|
||||
raise RuntimeError(
|
||||
"DeepGEMM requested for MoE layer but not installed."
|
||||
)
|
||||
|
||||
compatible_with_deep_gemm = (
|
||||
self.moe_quant_config.use_fp8_w8a8
|
||||
and self.moe_quant_config.block_shape
|
||||
== get_mk_alignment_for_contiguous_layout()
|
||||
)
|
||||
|
||||
# If this MoE layer is compatible with DeepGEMM, the proper env
|
||||
# vars are set and DeepGEMM is not installed, throw an error.
|
||||
if use_deep_gemm and compatible_with_deep_gemm and not has_deep_gemm():
|
||||
raise RuntimeError(
|
||||
f"MoE layer incompatible with DeepGEMM, expected "
|
||||
f"fp8==True, got {self.moe_quant_config.use_fp8_w8a8}"
|
||||
f"or block_shape {self.moe_quant_config.block_shape}"
|
||||
f"=={get_mk_alignment_for_contiguous_layout()}."
|
||||
)
|
||||
|
||||
if use_deep_gemm and compatible_with_deep_gemm and has_deep_gemm():
|
||||
logger.debug("BatchedDeepGemmExperts(%s)", self.__class__.__name__)
|
||||
return BatchedDeepGemmExperts(
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
else:
|
||||
logger.debug("BatchedTritonExperts(%s)", self.__class__.__name__)
|
||||
return BatchedTritonExperts(
|
||||
max_num_tokens=max_num_tokens_per_rank,
|
||||
num_dispatchers=prepare_finalize.num_dispatchers(),
|
||||
quant_config=self.moe_quant_config,
|
||||
)
|
||||
|
||||
else:
|
||||
logger.debug("TritonOrDeepGemmExperts(%s)", self.__class__.__name__)
|
||||
return TritonOrDeepGemmExperts(
|
||||
self.moe_quant_config,
|
||||
allow_deep_gemm=(
|
||||
envs.VLLM_USE_DEEP_GEMM and envs.VLLM_MOE_USE_DEEP_GEMM
|
||||
),
|
||||
allow_deep_gemm=use_deep_gemm,
|
||||
)
|
||||
|
||||
def get_fused_moe_quant_config(
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user