From da364615fcc42306c73401f22f57b0c1f8226efb Mon Sep 17 00:00:00 2001 From: bnellnm <49004751+bnellnm@users.noreply.github.com> Date: Wed, 8 Oct 2025 17:51:52 -0400 Subject: [PATCH] [Kernels] Modular kernel refactor (#24812) Signed-off-by: Bill Nell --- .../moe/modular_kernel_tools/common.py | 37 +- .../make_feature_matrix.py | 2 +- .../moe/modular_kernel_tools/mk_objects.py | 26 +- .../moe/test_modular_kernel_combinations.py | 144 ++-- .../layers/fused_moe/batched_deep_gemm_moe.py | 15 +- .../batched_triton_or_deep_gemm_moe.py | 11 +- .../layers/fused_moe/cutlass_moe.py | 57 +- .../layers/fused_moe/deep_gemm_moe.py | 6 +- .../fused_moe/deepep_ht_prepare_finalize.py | 3 + .../fused_moe/deepep_ll_prepare_finalize.py | 3 + .../fused_moe/flashinfer_cutlass_moe.py | 12 +- .../flashinfer_cutlass_prepare_finalize.py | 8 + .../layers/fused_moe/fused_batched_moe.py | 17 +- .../layers/fused_moe/fused_marlin_moe.py | 6 +- .../layers/fused_moe/fused_moe.py | 6 +- .../fused_moe/gpt_oss_triton_kernels_moe.py | 6 +- vllm/model_executor/layers/fused_moe/layer.py | 90 +-- .../layers/fused_moe/modular_kernel.py | 763 ++++++++++-------- .../layers/fused_moe/pplx_prepare_finalize.py | 3 + .../layers/fused_moe/prepare_finalize.py | 3 + .../layers/fused_moe/triton_deep_gemm_moe.py | 8 +- .../layers/fused_moe/trtllm_moe.py | 12 +- 22 files changed, 665 insertions(+), 573 deletions(-) diff --git a/tests/kernels/moe/modular_kernel_tools/common.py b/tests/kernels/moe/modular_kernel_tools/common.py index 091fa4fafe211..ff12d1fb9a805 100644 --- a/tests/kernels/moe/modular_kernel_tools/common.py +++ b/tests/kernels/moe/modular_kernel_tools/common.py @@ -209,18 +209,18 @@ class Config: info = prepare_finalize_info(self.prepare_finalize_type) return info.backend - def is_valid(self): + def is_valid(self) -> tuple[bool, Optional[str]]: # Check prepare-finalize and fused-experts compatibility if self.is_batched_prepare_finalize(): if not self.is_batched_fused_experts(): - return False + return False, "Mismatched format." else: if not self.is_standard_fused_experts(): - return False + return False, "Mismatched format." use_chunking = self.fused_moe_chunk_size is not None if use_chunking and not self.is_fe_supports_chunking(): - return False + return False, "Chunking not supported." # Check quantization sanity if ( @@ -229,7 +229,7 @@ class Config: + int(self.quant_block_shape is not None) ) > 1: # invalid quant config - return False + return False, f"Bad quant_config {self.quant_config}." # check type support if self.quant_dtype is None: @@ -237,34 +237,43 @@ class Config: self.dtype not in self.pf_supported_types() or self.dtype not in self.fe_supported_types() ): - return False + return False, ( + f"Unsupported type {self.dtype} not in " + f"{self.pf_supported_types()} and " + f"{self.fe_supported_types()}." + ) else: if ( self.quant_dtype not in self.pf_supported_types() or self.quant_dtype not in self.fe_supported_types() ): - return False + return False, ( + f"Unsupported quant type {self.quant_dtype} " + f"not in {self.pf_supported_types()} and " + f"{self.fe_supported_types()}." + ) # Check block quanization support is_block_quatized = self.quant_block_shape is not None if is_block_quatized and self.quant_dtype is None: - return False + return False, "No block quantization support." + if is_block_quatized and not self.is_block_quant_supported(): - return False + return False, "Mismatched block quantization support." # deep_gemm only works with block-quantized if self.needs_deep_gemm() and not is_block_quatized: - return False + return False, "Needs DeepGEMM but not block quantized." # Check dependencies (turn into asserts?) if self.needs_deep_ep() and not has_deep_ep(): - return False + return False, "Needs DeepEP, but DeepEP not available." if self.needs_deep_gemm() and not has_deep_gemm(): - return False + return False, "Needs DeepGEMM, but DeepGEMM not available." if self.needs_pplx() and not has_pplx(): # noqa: SIM103 - return False + return False, "Needs PPLX, but PPLX not available." - return True + return True, None @dataclass diff --git a/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py index 0ef306051c8a4..7d555202afe6a 100644 --- a/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py +++ b/tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py @@ -140,7 +140,7 @@ def make_feature_matrix(csv_file_path: str): ) success = None - if config.is_valid(): + if config.is_valid()[0]: print(f"Running config : {config.describe()} ...") try: weights: WeightTensors = WeightTensors.make(config) diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 566fb1e09d3b0..174b2d1781ae0 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -244,7 +244,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability register_prepare_and_finalize( FlashInferCutlassMoEPrepareAndFinalize, standard_format, - nvfp4_types, + nvfp4_types + fp8_types, blocked_quantization_support=True, backend=None, force_multigpu=True, @@ -254,7 +254,7 @@ if has_flashinfer_cutlass_fused_moe() and current_platform.has_device_capability register_experts( FlashInferExperts, standard_format, - nvfp4_types, + nvfp4_types + fp8_types, blocked_quantization_support=True, supports_chunking=True, # Note: this is a hack to get it to run for now @@ -274,17 +274,15 @@ if has_deep_gemm() and is_deep_gemm_supported(): needs_matching_quant=False, needs_deep_gemm=True, ) - ( - register_experts( - DeepGemmExperts, - standard_format, - fp8_types, - blocked_quantization_support=True, - supports_chunking=True, - supports_expert_map=True, - needs_matching_quant=False, - needs_deep_gemm=True, - ), + register_experts( + DeepGemmExperts, + standard_format, + fp8_types, + blocked_quantization_support=True, + supports_chunking=True, + supports_expert_map=True, + needs_matching_quant=False, + needs_deep_gemm=True, ) register_experts( BatchedTritonOrDeepGemmExperts, @@ -464,7 +462,7 @@ def make_fused_experts( print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...") experts = BatchedTritonOrDeepGemmExperts(**kwargs) elif fused_experts_type == DeepGemmExperts: - print("Making DeepGemmExperts {quant_config} ...") + print(f"Making DeepGemmExperts {quant_config} ...") experts = DeepGemmExperts(quant_config) elif fused_experts_type == TritonExperts: kwargs = quant_kwargs diff --git a/tests/kernels/moe/test_modular_kernel_combinations.py b/tests/kernels/moe/test_modular_kernel_combinations.py index 9c4114523590c..b028e676f086f 100644 --- a/tests/kernels/moe/test_modular_kernel_combinations.py +++ b/tests/kernels/moe/test_modular_kernel_combinations.py @@ -5,7 +5,7 @@ import copy import textwrap import traceback from itertools import product -from typing import Optional +from typing import Any, Optional import pytest import torch @@ -13,10 +13,9 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.config import VllmConfig, set_current_vllm_config from vllm.platforms import current_platform -from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx +from vllm.utils import cuda_device_count_stateless, has_deep_ep, has_deep_gemm, has_pplx from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe -from ...utils import multi_gpu_test from .modular_kernel_tools.common import ( Config, RankTensors, @@ -132,7 +131,8 @@ def rank_worker( def run(config: Config, verbose: bool): - assert config.is_valid() + assert config.is_valid()[0] + assert not is_nyi_config(config) weights: WeightTensors = WeightTensors.make(config) @@ -168,17 +168,77 @@ def is_nyi_config(config: Config) -> bool: return not info.supports_expert_map -@pytest.mark.parametrize("k", Ks) -@pytest.mark.parametrize("n", Ns) -@pytest.mark.parametrize("e", Es) -@pytest.mark.parametrize("dtype", DTYPEs) -@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS) +def generate_valid_test_cases( + world_size: int, prepare_finalize_types +) -> list[tuple[Any, ...]]: + cases = [] + total = 0 + + for k, n, e, dtype, quant_config, combination, chunk_size in product( + Ks, + Ns, + Es, + DTYPEs, + MK_QUANT_CONFIGS, + product(prepare_finalize_types, MK_FUSED_EXPERT_TYPES), + FUSED_MOE_CHUNK_SIZEs, + ): + total = total + 1 + + config = Config( + Ms=Ms, + K=k, + N=n, + E=e, + topks=TOPKs, + dtype=dtype, + quant_config=quant_config, + prepare_finalize_type=combination[0], + fused_experts_type=combination[1], + fused_moe_chunk_size=chunk_size, + world_size=world_size, + ) + + # TODO(bnell): figure out how to get verbose flag here. + verbose = False # pytestconfig.getoption('verbose') > 0 + + valid, reason = config.is_valid() + + if not valid: + if verbose: + print(f"Test config {config} is not valid: {reason}") + continue + + if is_nyi_config(config): + if verbose: + print(f"Test config {config} is nyi.") + continue + + cases.append( + ( + k, + n, + e, + dtype, + quant_config, + combination[0], + combination[1], + chunk_size, + world_size, + ) + ) + + print(f"{len(cases)} of {total} valid configs generated.") + + return cases + + @pytest.mark.parametrize( - "combination", product(MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES) + "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size", + generate_valid_test_cases( + world_size=2, prepare_finalize_types=MK_MULTI_GPU_PREPARE_FINALIZE_TYPES + ), ) -@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) -@pytest.mark.parametrize("world_size", [2]) -@multi_gpu_test(num_gpus=2) @meets_multi_gpu_requirements def test_modular_kernel_combinations_multigpu( k: int, @@ -186,13 +246,19 @@ def test_modular_kernel_combinations_multigpu( e: int, dtype: torch.dtype, quant_config: Optional[TestMoEQuantConfig], - combination: tuple[ - mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute - ], - fused_moe_chunk_size: Optional[int], + prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, + fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, + chunk_size: Optional[int], world_size: int, pytestconfig, ): + if cuda_device_count_stateless() < world_size: + pytest.skip( + f"Not enough GPUs available to run, got " + f"{cuda_device_count_stateless()} exepected " + f"{world_size}." + ) + config = Config( Ms=Ms, K=k, @@ -201,42 +267,30 @@ def test_modular_kernel_combinations_multigpu( topks=TOPKs, dtype=dtype, quant_config=quant_config, - prepare_finalize_type=combination[0], - fused_experts_type=combination[1], - fused_moe_chunk_size=fused_moe_chunk_size, + prepare_finalize_type=prepare_finalize_type, + fused_experts_type=fused_experts_type, + fused_moe_chunk_size=chunk_size, world_size=world_size, ) - - if not config.is_valid(): - pytest.skip(f"Tests config {config} is not valid. Skipping ...") - - if is_nyi_config(config): - pytest.skip(f"Tests config {config} is nyi. Skipping ...") - verbosity = pytestconfig.getoption("verbose") run(config, verbosity > 0) -@pytest.mark.parametrize("k", Ks) -@pytest.mark.parametrize("n", Ns) -@pytest.mark.parametrize("e", Es) -@pytest.mark.parametrize("dtype", DTYPEs) -@pytest.mark.parametrize("quant_config", MK_QUANT_CONFIGS) @pytest.mark.parametrize( - "combination", product(MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES) + "k,n,e,dtype,quant_config,prepare_finalize_type,fused_experts_type,chunk_size,world_size", + generate_valid_test_cases( + world_size=1, prepare_finalize_types=MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES + ), ) -@pytest.mark.parametrize("fused_moe_chunk_size", FUSED_MOE_CHUNK_SIZEs) -@pytest.mark.parametrize("world_size", [1]) def test_modular_kernel_combinations_singlegpu( k: int, n: int, e: int, dtype: torch.dtype, quant_config: Optional[TestMoEQuantConfig], - combination: tuple[ - mk.FusedMoEPrepareAndFinalize, mk.FusedMoEPermuteExpertsUnpermute - ], - fused_moe_chunk_size: Optional[int], + prepare_finalize_type: mk.FusedMoEPrepareAndFinalize, + fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute, + chunk_size: Optional[int], world_size: int, pytestconfig, ): @@ -248,18 +302,12 @@ def test_modular_kernel_combinations_singlegpu( topks=TOPKs, dtype=dtype, quant_config=quant_config, - prepare_finalize_type=combination[0], - fused_experts_type=combination[1], - fused_moe_chunk_size=fused_moe_chunk_size, + prepare_finalize_type=prepare_finalize_type, + fused_experts_type=fused_experts_type, + fused_moe_chunk_size=chunk_size, world_size=world_size, ) - if not config.is_valid(): - pytest.skip(f"Tests config {config} is not valid. Skipping ...") - - if is_nyi_config(config): - pytest.skip(f"Tests config {config} is nyi. Skipping ...") - verbosity = pytestconfig.getoption("verbose") run(config, verbosity > 0) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index f30ebec76c673..94b18e51da963 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -247,29 +247,24 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_metadata: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - assert a.dim() == 2 + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # FIXME (varun): We should be able to dispatch only from the leader # DP ranks in the case of TP > 1. At the moment, all the Ranks # end up sending their tokens. This needs to be fixed. num_dispatchers = self.num_dispatchers num_experts = local_num_experts - max_num_tokens = ( - a.size(0) if self.max_num_tokens is None else self.max_num_tokens - ) + max_num_tokens = M if self.max_num_tokens is None else self.max_num_tokens workspace13 = (num_experts, max_num_tokens * num_dispatchers, max(K, N)) workspace2 = (num_experts, max_num_tokens * num_dispatchers, (N // 2)) output = (num_experts, max_num_tokens * num_dispatchers, K) - return (workspace13, workspace2, output, a.dtype) + return (workspace13, workspace2, output) def apply( self, @@ -300,7 +295,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): assert w2.size(1) == K - E, max_num_tokens, N, K, top_k_num = self.moe_problem_size( + E, max_num_tokens, N, K, _ = self.moe_problem_size( hidden_states, w1, w2, topk_ids ) diff --git a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py index d268f70477f4c..09c4de0f87159 100644 --- a/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py @@ -99,10 +99,11 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): assert bte_war is not None return bte_war + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + return act_dtype + def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -110,15 +111,13 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_metadata: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. if self.allow_deep_gemm: assert self.batched_deep_gemm_experts is not None return self.batched_deep_gemm_experts.workspace_shapes( - a, - aq, M, N, K, @@ -130,8 +129,6 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): else: assert self.batched_triton_experts is not None return self.batched_triton_experts.workspace_shapes( - a, - aq, M, N, K, diff --git a/vllm/model_executor/layers/fused_moe/cutlass_moe.py b/vllm/model_executor/layers/fused_moe/cutlass_moe.py index d3fed93329583..fa158287d418d 100644 --- a/vllm/model_executor/layers/fused_moe/cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/cutlass_moe.py @@ -366,10 +366,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): # topk weights and reduction are fused in moe_unpermute cuda kernel return TopKWeightAndReduceNoOP() + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + return self.out_dtype if self.out_dtype is not None else act_dtype + def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -377,16 +378,11 @@ class CutlassExpertsFp8(CutlassExpertsFp8Base): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: workspace1 = (M * topk, max(N, K)) workspace2 = (M * topk, max(N // 2, K)) output = (M, K) - return ( - workspace1, - workspace2, - output, - self.out_dtype if self.out_dtype is not None else a.dtype, - ) + return (workspace1, workspace2, output) class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): @@ -428,11 +424,11 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): def supports_expert_map(self) -> bool: return False - # TODO(bnell): maybe remove need for passing aq to workspace_shapes + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + return self.out_dtype if self.out_dtype is not None else act_dtype + def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -440,19 +436,13 @@ class CutlassBatchedExpertsFp8(CutlassExpertsFp8Base): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - padded_M = aq.size(1) + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: num_dp = self.num_dispatchers assert num_dp is not None - workspace1 = (self.max_experts_per_worker, padded_M * num_dp, max(N, K)) - workspace2 = (self.max_experts_per_worker, padded_M * num_dp, max(N // 2, K)) - output = (self.max_experts_per_worker, padded_M, K) - return ( - workspace1, - workspace2, - output, - self.out_dtype if self.out_dtype is not None else a.dtype, - ) + workspace1 = (self.max_experts_per_worker, M * num_dp, max(N, K)) + workspace2 = (self.max_experts_per_worker, M * num_dp, max(N // 2, K)) + output = (self.max_experts_per_worker, M, K) + return (workspace1, workspace2, output) def cutlass_moe_fp8( @@ -767,10 +757,11 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: return TopKWeightAndReduceNoOP() + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + return self.out_dtype if self.out_dtype is not None else act_dtype + def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -778,25 +769,19 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: workspace1: tuple[int, ...] = () workspace2: tuple[int, ...] = () output: tuple[int, ...] = () if self.use_batched_format: - padded_M = aq.size(1) - workspace1 = (self.max_experts_per_worker, padded_M, max(N, K)) - workspace2 = (self.max_experts_per_worker, padded_M, (N // 2)) - output = (self.max_experts_per_worker, padded_M, K) + workspace1 = (self.max_experts_per_worker, M, max(N, K)) + workspace2 = (self.max_experts_per_worker, M, (N // 2)) + output = (self.max_experts_per_worker, M, K) else: workspace1 = (M * topk, max(2 * N, K)) workspace2 = (M * topk, N) output = (M, K) - return ( - workspace1, - workspace2, - output, - self.out_dtype if self.out_dtype is not None else a.dtype, - ) + return (workspace1, workspace2, output) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index fec3a7c5d0a9f..fc0cb5c530da6 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -198,8 +198,6 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -207,7 +205,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: assert self.block_shape is not None block_m = self.block_shape[0] M_sum = compute_aligned_M( @@ -218,7 +216,7 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace1 = (M_sum, max(N, K)) workspace2 = (M_sum, max(N // 2, K)) output = (M, K) - return (workspace1, workspace2, output, a.dtype) + return (workspace1, workspace2, output) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py index 9a2844b7d998a..85c4fd90dc6c1 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -70,6 +70,9 @@ class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def num_dispatchers(self) -> int: return self.num_dispatchers_ + def output_is_reduced(self) -> bool: + return True + @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.Standard diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 6712995b52af2..117bfe6e6b4d7 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -73,6 +73,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def num_dispatchers(self) -> int: return self.num_dispatchers_ + def output_is_reduced(self) -> bool: + return True + @property def activation_format(self) -> mk.FusedMoEActivationFormat: return mk.FusedMoEActivationFormat.BatchedExperts diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index a2d8fe0da1544..1b33c7075fb36 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -90,8 +90,6 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -99,7 +97,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # We use global_num_experts due to how moe_align_block_size handles # expert_maps. """ @@ -118,14 +116,12 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): - Note: in order for activation chunking to work, the first dimension of each tuple must be the number of tokens. """ - aq_m, aq_n = aq.shape + workspace1 = (M, K) workspace2 = (0,) - output_shape = (aq_m, aq_n * 2) if self.quant_dtype == "nvfp4" else (aq_m, aq_n) - workspace_dtype = a.dtype - workspace1 = output_shape + output_shape = (M, K * 2 if self.quant_dtype == "nvfp4" else K) # The workspace is determined by `aq`, since it comes after any # potential communication op and is involved in the expert computation. - return (workspace1, workspace2, output_shape, workspace_dtype) + return (workspace1, workspace2, output_shape) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 04bc987d08855..4907b9ff5730b 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -11,6 +11,9 @@ from vllm.distributed.device_communicators.base_device_communicator import ( ) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceNoOP, +) from vllm.model_executor.layers.fused_moe.utils import moe_kernel_quantize_input from vllm.utils.flashinfer import nvfp4_block_scale_interleave @@ -45,6 +48,9 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def num_dispatchers(self) -> int: return self.num_dispatchers_ + def output_is_reduced(self) -> bool: + return False + def _apply_router_weight_on_input( self, a1: torch.Tensor, @@ -194,6 +200,8 @@ class FlashInferAllGatherMoEPrepareAndFinalize(FlashInferCutlassMoEPrepareAndFin apply_router_weight_on_input: bool, weight_and_reduce_impl: mk.TopKWeightAndReduce, ) -> None: + assert isinstance(weight_and_reduce_impl, TopKWeightAndReduceNoOP) + if self.use_dp: fused_expert_output = get_dp_group().reduce_scatterv( fused_expert_output, dim=0, sizes=get_local_sizes() diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 02a935a1dca21..0c31684d23677 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -509,6 +509,9 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def num_dispatchers(self) -> int: return self.num_dispatchers_ + def output_is_reduced(self) -> bool: + return False + def prepare( self, a1: torch.Tensor, @@ -665,8 +668,6 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -674,14 +675,13 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - assert a.dim() == 2 + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: num_dp = self.num_dispatchers num_experts = local_num_experts workspace13 = (num_experts, self.max_num_tokens * num_dp, K) workspace2 = (self.max_num_tokens * num_dp, N) output = workspace13 - return (workspace13, workspace2, output, a.dtype) + return (workspace13, workspace2, output) def dequant(self, t: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: assert self.quant_config.is_quantized @@ -862,8 +862,6 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -871,15 +869,14 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: - assert a.dim() == 2 + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: num_dp = self.num_dispatchers num_experts = local_num_experts max_num_tokens = self.max_num_tokens workspace13 = (num_experts, max_num_tokens * num_dp, max(K, N)) workspace2 = (num_experts, max_num_tokens * num_dp, (N // 2)) output = (num_experts, max_num_tokens * num_dp, K) - return (workspace13, workspace2, output, a.dtype) + return (workspace13, workspace2, output) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py index c46cc016214f2..b0cc83fd2e450 100644 --- a/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_marlin_moe.py @@ -331,8 +331,6 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -340,7 +338,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # Modular Kernel provisions output buffer from workspace1. However in # the fused_marlin_moe() function, the final torch.sum(), is defined # essentially as, @@ -360,7 +358,7 @@ class MarlinExperts(mk.FusedMoEPermuteExpertsUnpermute): workspace2 = (M * topk * max(2 * N, K),) output = (M, K) - return (workspace1, workspace2, output, a.dtype) + return (workspace1, workspace2, output) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 2a3abcaadebd8..da7c4a3c55893 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1954,8 +1954,6 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -1963,11 +1961,11 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: workspace1 = (M, topk, max(N // 2, K)) workspace2 = (M, topk, max(N, K)) output = (M, K) - return (workspace1, workspace2, output, a.dtype) + return (workspace1, workspace2, output) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py index 39faeed5d10f7..283ce80556d26 100644 --- a/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +++ b/vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py @@ -255,8 +255,6 @@ class OAITritonExperts(BaseOAITritonExperts): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -264,12 +262,12 @@ class OAITritonExperts(BaseOAITritonExperts): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # workspace are allocated inside the kernel workspace1 = (M, K) workspace2 = (0, 0) output = (M, K) - return (workspace1, workspace2, output, a.dtype) + return (workspace1, workspace2, output) def apply( self, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 9c8ccc6ec0085..94a733aa03b93 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -283,6 +283,10 @@ class FusedMoEMethodBase(QuantizeMethodBase): ) -> Optional[FusedMoEQuantConfig]: raise NotImplementedError + @property + def using_modular_kernel(self) -> bool: + return self.fused_experts is not None + @abstractmethod def apply( self, @@ -1237,39 +1241,25 @@ class FusedMoE(CustomOp): self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None - # TODO(bnell): flashinfer uses non-batched format. - # Does it really need a batched buffer? - if ( - self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels - or self.moe_config.use_flashinfer_cutlass_kernels - ): + if self.use_dp_chunking: + states_shape: tuple[int, ...] + logits_shape: tuple[int, ...] + + # Note here we use `num_experts` which is logical expert count if vllm_config.parallel_config.enable_dbo: - self.batched_hidden_states = torch.zeros( - (2, moe.max_num_tokens, self.hidden_size), - dtype=moe.in_dtype, - device=torch.cuda.current_device(), - ) - - # Note here we use `num_experts` which is logical expert count - self.batched_router_logits = torch.zeros( - (2, moe.max_num_tokens, num_experts), - dtype=moe.in_dtype, - device=torch.cuda.current_device(), - ) + states_shape = (2, moe.max_num_tokens, self.hidden_size) + logits_shape = (2, moe.max_num_tokens, num_experts) else: - self.batched_hidden_states = torch.zeros( - (moe.max_num_tokens, self.hidden_size), - dtype=moe.in_dtype, - device=torch.cuda.current_device(), - ) + states_shape = (moe.max_num_tokens, self.hidden_size) + logits_shape = (moe.max_num_tokens, num_experts) - # Note here we use `num_experts` which is logical expert count - self.batched_router_logits = torch.zeros( - (moe.max_num_tokens, num_experts), - dtype=moe.in_dtype, - device=torch.cuda.current_device(), - ) + self.batched_hidden_states = torch.zeros( + states_shape, dtype=moe.in_dtype, device=torch.cuda.current_device() + ) + + self.batched_router_logits = torch.zeros( + logits_shape, dtype=moe.in_dtype, device=torch.cuda.current_device() + ) @property def shared_experts(self) -> Optional[torch.nn.Module]: @@ -1323,6 +1313,16 @@ class FusedMoE(CustomOp): and self.moe_config.use_flashinfer_cutlass_kernels ) + @property + def use_dp_chunking(self) -> bool: + # Route to the chunked forward path using the FlashInfer Cutlass kernel + # only when data parallelism (DP) is enabled. + return ( + self.moe_parallel_config.use_pplx_kernels + or self.moe_parallel_config.use_deepep_ll_kernels + or (self.dp_size > 1 and self.use_flashinfer_cutlass_kernels) + ) + def update_expert_map(self): # ep_size and ep_rank should already be updated assert self.expert_map is not None @@ -1987,21 +1987,17 @@ class FusedMoE(CustomOp): Therefore it is required that we reduce the shared_experts output early. """ + assert self.quant_method is not None return ( - self.use_pplx_kernels - or self.use_deepep_ht_kernels - or self.use_deepep_ll_kernels + self.quant_method.fused_experts is not None + and self.quant_method.fused_experts.output_is_reduced() ) def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor): """ - The pplx combine kernel reduces across GPU ranks by default. + Some combine kernels reduce across GPU ranks by default. """ - if ( - self.use_pplx_kernels - or self.use_deepep_ht_kernels - or self.use_deepep_ll_kernels - ): + if self.must_reduce_shared_expert_outputs(): return final_hidden_states else: return tensor_model_parallel_all_reduce(final_hidden_states) @@ -2209,23 +2205,11 @@ class FusedMoE(CustomOp): self.ensure_moe_quant_config() - # Route to the chunked forward path using the FlashInfer Cutlass kernel - # only when data parallelism (DP) is enabled. - _use_flashinfer_cutlass_kernels = ( - self.dp_size > 1 and self.use_flashinfer_cutlass_kernels - ) - - if ( - self.moe_parallel_config.use_pplx_kernels - or self.moe_parallel_config.use_deepep_ll_kernels - or _use_flashinfer_cutlass_kernels - ): + if self.use_dp_chunking: return self.forward_impl_chunked(hidden_states, router_logits) do_naive_dispatch_combine: bool = ( - self.dp_size > 1 - and not self.moe_parallel_config.use_deepep_ht_kernels - and not self.moe_config.use_flashinfer_cutlass_kernels + self.dp_size > 1 and not self.quant_method.using_modular_kernel ) # If there are shared experts but we are not using a modular kernel, the diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 1f6209c9d08e6..62162b6cbae10 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -337,6 +337,14 @@ class FusedMoEPrepareAndFinalize(ABC): def num_dispatchers(self) -> int: raise NotImplementedError + @abstractmethod + def output_is_reduced(self) -> bool: + """ + Indicates whether or not the output of finalize is reduced across all + ranks. + """ + raise NotImplementedError + # TODO: add supported activations method (return string) class FusedMoEPermuteExpertsUnpermute(ABC): @@ -493,11 +501,15 @@ class FusedMoEPermuteExpertsUnpermute(ABC): """ raise NotImplementedError + def workspace_dtype(self, act_dtype: torch.dtype) -> torch.dtype: + """ + Workspace type: The dtype to use for the workspace tensors. + """ + return act_dtype + @abstractmethod def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -505,22 +517,33 @@ class FusedMoEPermuteExpertsUnpermute(ABC): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: """ Compute the shapes for the temporary and final outputs of the two gemms and activation in the fused expert function. Since the gemms are independent, the workspace for the first gemm can be shared with the workspace for the last gemm. + Inputs: + - M: number of tokens. + - N: Row (or column) dimension of expert weights. + - K: hidden dimension + - topk: The number of top-k experts to select. + - global_num_experts: global number of experts. + - local_num_experts: local number of experts due to DP/EP. + - expert_tokens_meta: number of tokens per expert metadata for batched + format. + Returns a tuple of: - workspace13 shape tuple: must be large enough to hold the result of either expert gemm. - workspace2 shape tuple: must be large enough to hold the result of the activation function. - output shape tuple: must be exact size of the final gemm output. - - Workspace type: The dtype to use for the workspace tensors. - - Note: in order for activation chunking to work, the first dimension - of each tuple must be the number of tokens. + - Note: workspace shapes can be 0 if the workspace is not needed. + But in order for activation chunking to work, the first dimension + of each tuple must be the number of tokens when the shape is + not 0. """ raise NotImplementedError @@ -561,7 +584,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): workspace2: torch.Tensor, expert_tokens_meta: Optional[ExpertTokensMetadata], apply_router_weight_on_input: bool, - ): + ) -> None: """ This function computes the intermediate result of a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2. @@ -600,7 +623,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC): raise NotImplementedError -def _chunk_scales( +def _slice_scales( scales: Optional[torch.Tensor], start: int, end: int ) -> Optional[torch.Tensor]: if scales is not None: @@ -615,9 +638,10 @@ class SharedResizableBuffer: def __init__(self): self.buffer = None - def get(self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype): - if shape == () or shape is None: - return None + def get( + self, shape: tuple[int, ...], device: torch.device, dtype: torch.dtype + ) -> torch.Tensor: + assert shape != () shape_numel = prod(shape) if ( self.buffer is None @@ -678,131 +702,63 @@ class FusedMoEModularKernel(torch.nn.Module): f"{fused_experts.activation_formats[0]}" ) - def _do_fused_experts( + def output_is_reduced(self) -> bool: + """ + Indicates whether or not the output of fused MoE kernel + is reduced across all ranks. + """ + return self.prepare_finalize.output_is_reduced() + + def _chunk_info(self, M: int) -> tuple[int, int]: + """ + Compute number of chunks and chunk size for given M. + If chunking is not supported, set the CHUNK_SIZE to M so we + get num_chunks == 1. Take max(M, 1) to avoid divide by zero. + If there are no tokens to process, the number of chunks will be zero. + """ + CHUNK_SIZE = ( + max(M, 1) + if not self.fused_experts.supports_chunking() + else min(M, envs.VLLM_FUSED_MOE_CHUNK_SIZE) + ) + num_chunks = cdiv(M, CHUNK_SIZE) + # If there are no tokens, then there should be no loop iterations. + assert M > 0 or num_chunks == 0 + return num_chunks, CHUNK_SIZE + + def _allocate_buffers( self, - fused_out: Optional[torch.Tensor], - a1: torch.Tensor, - a1q: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, + out_dtype: torch.dtype, + device: torch.device, + M_chunk: int, + M_full: int, + N: int, + K: int, + top_k: int, global_num_experts: int, local_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], expert_tokens_meta: Optional[ExpertTokensMetadata], - apply_router_weight_on_input: bool, - ) -> torch.Tensor: - _, M, N, K, top_k = self.fused_experts.moe_problem_size(a1q, w1, w2, topk_ids) + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Allocate temporary and output buffers for the fused experts op. + Inputs: + - out_dtype: output type of workspace and output tensors. + - device: the device of the workspace and output tensors. + See `workspace_shapes` for a description of the remainder of arguments. + Returns a tuple of (workspace13, workspace2, output) tensors. + """ + assert M_full > 0 and M_chunk > 0 - (workspace13_shape, workspace2_shape, fused_out_shape, workspace_dtype) = ( - self.fused_experts.workspace_shapes( - a1, - a1q, - M, - N, - K, - top_k, - global_num_experts, - local_num_experts, - expert_tokens_meta, - ) - ) + num_chunks, _ = self._chunk_info(M_full) # select per-ubatch buffers to avoid cross-ubatch reuse under DBO ubatch_idx = dbo_current_ubatch_id() buffers = self.shared_buffers[ubatch_idx] + workspace_dtype = self.fused_experts.workspace_dtype(out_dtype) - # We can reuse the memory between cache1 and cache3 because by the - # time we need cache3, we're done with cache1. - workspace13 = buffers.workspace13.get( - workspace13_shape, device=a1.device, dtype=workspace_dtype - ) - workspace2 = buffers.workspace2.get( - workspace2_shape, device=a1.device, dtype=workspace_dtype - ) - - assert fused_out is None or fused_out.shape == fused_out_shape, ( - f"fused_out {fused_out.shape} but expected {fused_out_shape}" - ) - if fused_out is None: - # reuse workspace13 for the output - fused_out = _resize_cache(workspace13, fused_out_shape) - - self.fused_experts.apply( - fused_out, - a1q, - w1, - w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - a1q_scale=a1q_scale, - a2_scale=a2_scale, - workspace13=workspace13, - workspace2=workspace2, - expert_tokens_meta=expert_tokens_meta, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - - return fused_out - - def _maybe_chunk_fused_experts( - self, - a1: torch.Tensor, - a1q: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str, - global_num_experts: int, - local_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - expert_tokens_meta: Optional[ExpertTokensMetadata], - apply_router_weight_on_input: bool, - ) -> torch.Tensor: - _, M, N, K, top_k = self.fused_experts.moe_problem_size(a1q, w1, w2, topk_ids) - - CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE - num_chunks = cdiv(M, CHUNK_SIZE) - - # TODO(bnell): get rid of one level here, update slice functions - # to nops on num_chunks==1 - - if not self.fused_experts.supports_chunking() or num_chunks == 1: - return self._do_fused_experts( - fused_out=None, - a1=a1, - a1q=a1q, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - global_num_experts=global_num_experts, - local_num_experts=local_num_experts, - expert_map=expert_map, - a1q_scale=a1q_scale, - a2_scale=self.fused_experts.a2_scale, - expert_tokens_meta=expert_tokens_meta, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - - # Chunking required case - assert num_chunks > 1 - - # Construct the entire output that can then be processed in chunks. - (_, _, fused_out_shape, _) = self.fused_experts.workspace_shapes( - a1, - a1q, - M, + # Get intermediate workspace shapes based off the chunked M size. + workspace13_shape, workspace2_shape, _ = self.fused_experts.workspace_shapes( + M_chunk, N, K, top_k, @@ -810,102 +766,338 @@ class FusedMoEModularKernel(torch.nn.Module): local_num_experts, expert_tokens_meta, ) - ubatch_idx = dbo_current_ubatch_id() - buffers = self.shared_buffers[ubatch_idx] - fused_out = buffers.fused_out.get( - fused_out_shape, device=a1q.device, dtype=a1.dtype + + # Get final output shape based on the full M size. + _, _, fused_out_shape = self.fused_experts.workspace_shapes( + M_full, + N, + K, + top_k, + global_num_experts, + local_num_experts, + expert_tokens_meta, ) - def slice_input_tensors( - chunk_idx: int, - ) -> tuple[ - torch.Tensor, - Optional[torch.Tensor], - Optional[torch.Tensor], - torch.Tensor, - torch.Tensor, - ]: - s = chunk_idx * CHUNK_SIZE - e = min(s + CHUNK_SIZE, M) - return ( - a1q[s:e], - _chunk_scales(a1q_scale, s, e), - _chunk_scales(self.fused_experts.a2_scale, s, e), - topk_ids[s:e], - topk_weights[s:e], + # We can reuse the memory between cache1 and cache3 because by the + # time we need cache3, we're done with cache1. + workspace13 = buffers.workspace13.get( + workspace13_shape, device=device, dtype=workspace_dtype + ) + workspace2 = buffers.workspace2.get( + workspace2_shape, device=device, dtype=workspace_dtype + ) + + # Construct the entire output that can then be processed in chunks. + # Reuse workspace13 for the output in the non-chunked case as long + # as it is large enough. This will not always be the case for standard + # format experts and with experts that have empty workspaces. + if num_chunks == 1 and prod(workspace13_shape) >= prod(fused_out_shape): + fused_out = _resize_cache(workspace13, fused_out_shape) + else: + fused_out = buffers.fused_out.get( + fused_out_shape, device=device, dtype=out_dtype ) - def slice_output_tensor(chunk_idx: int) -> torch.Tensor: - assert fused_out.size(0) % M == 0, ( - f"fused_out shape {fused_out.shape} vs M {M}" - ) - factor = fused_out.size(0) // M - out_chunk_size = CHUNK_SIZE * factor - s = chunk_idx * out_chunk_size - e = min(s + out_chunk_size, fused_out.size(0)) - return fused_out[s:e] + return workspace13, workspace2, fused_out - def slice_expert_tokens_metadata( - full_expert_tokens_meta: ExpertTokensMetadata, - chunk_topk_ids: torch.Tensor, - local_num_experts: int, - expert_map: Optional[torch.Tensor], - ) -> ExpertTokensMetadata: - # The existing expert_num_tokens is for the entire a1q - # input. Chunking forces recomputation of the number - # of tokens assigned to each expert. - c_expert_num_tokens = count_expert_num_tokens( - chunk_topk_ids, local_num_experts, expert_map + @staticmethod + def _slice_output_tensor( + fused_out: torch.Tensor, + chunk_idx: int, + num_chunks: int, + CHUNK_SIZE: int, + M: int, + ) -> torch.Tensor: + if num_chunks == 1: + return fused_out + + assert fused_out.size(0) % M == 0, f"fused_out shape {fused_out.shape} vs M {M}" + factor = fused_out.size(0) // M + out_chunk_size = CHUNK_SIZE * factor + s = chunk_idx * out_chunk_size + e = min(s + out_chunk_size, fused_out.size(0)) + return fused_out[s:e] + + @staticmethod + def _slice_expert_tokens_metadata( + num_chunks: int, + full_expert_tokens_meta: Optional[ExpertTokensMetadata], + chunk_topk_ids: torch.Tensor, + local_num_experts: int, + expert_map: Optional[torch.Tensor], + ) -> Optional[ExpertTokensMetadata]: + if num_chunks == 1 or full_expert_tokens_meta is None: + return full_expert_tokens_meta + + # The existing expert_num_tokens is for the entire a1q + # input. Chunking forces recomputation of the number + # of tokens assigned to each expert. + c_expert_num_tokens = count_expert_num_tokens( + chunk_topk_ids, local_num_experts, expert_map + ) + + c_expert_num_tokens_cpu = None + need_expert_num_tokens_cpu = ( + full_expert_tokens_meta.expert_num_tokens_cpu is not None + ) + if need_expert_num_tokens_cpu: + # This is blocking as some implementations need the count + # on the CPU to determine appropriate input/out fused-moe + # buffers + c_expert_num_tokens_cpu = c_expert_num_tokens.to("cpu", non_blocking=False) + + return ExpertTokensMetadata( + expert_num_tokens=c_expert_num_tokens, + expert_num_tokens_cpu=c_expert_num_tokens_cpu, + ) + + def _prepare( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + ) -> tuple[ + torch.Tensor, + Optional[torch.Tensor], + Optional[ExpertTokensMetadata], + torch.Tensor, + torch.Tensor, + ]: + """ + The _prepare method is a wrapper around self.prepare_finalize.prepare + that handles DBO and async. + """ + if not self.prepare_finalize.supports_async(): + # We shouldn't be running an a2a kernel that doesn't + # support async prepare/finalize + # TODO(lucas): enable in follow-up + assert not dbo_enabled() + + ( + a1q, + a1q_scale, + expert_tokens_meta, + _expert_topk_ids, + _expert_topk_weights, + ) = self.prepare_finalize.prepare( + hidden_states, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, + ) + else: + # Overlap shared expert compute with all2all dispatch. + dbo_maybe_run_recv_hook() + prepare_ret = self.prepare_finalize.prepare_async( + hidden_states, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, + self.fused_experts.quant_config, ) - c_expert_num_tokens_cpu = None - need_expert_num_tokens_cpu = ( - full_expert_tokens_meta.expert_num_tokens_cpu is not None + # TODO(lucas): refactor this in the alternative schedules followup + # currently unpack if we have hook + receiver pair or just + # receiver (see finalize_async docstring) + hook, receiver = ( + prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret) ) - if need_expert_num_tokens_cpu: - # This is blocking as some implementations need the count - # on the CPU to determine appropriate input/out fused-moe - # buffers - c_expert_num_tokens_cpu = c_expert_num_tokens.to( - "cpu", non_blocking=False - ) - return ExpertTokensMetadata( - expert_num_tokens=c_expert_num_tokens, - expert_num_tokens_cpu=c_expert_num_tokens_cpu, + if hook is not None: + if dbo_enabled(): + # If DBO is being used, register the hook with the ubatch + # context and call it in dbo_maybe_run_recv_hook instead of + # passing it to the receiver. + dbo_register_recv_hook(hook) + dbo_yield() + else: + hook() + + ( + a1q, + a1q_scale, + expert_tokens_meta, + _expert_topk_ids, + _expert_topk_weights, + ) = receiver() + + # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. + topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids + topk_weights = ( + topk_weights if _expert_topk_weights is None else _expert_topk_weights + ) + + return a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights + + def _fused_experts( + self, + in_dtype: torch.dtype, + a1q: torch.Tensor, + a1q_scale: Optional[torch.Tensor], + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + local_num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + expert_tokens_meta: Optional[ExpertTokensMetadata], + ) -> torch.Tensor: + _, M_full, N, K, top_k = self.fused_experts.moe_problem_size( + a1q, w1, w2, topk_ids + ) + + num_chunks, CHUNK_SIZE = self._chunk_info(M_full) + + def input_chunk_range(chunk_idx: int) -> tuple[int, int]: + if num_chunks == 1: + # Use a1q.size(0) here since batched format does not + # keep M in the first dimension. + return 0, a1q.size(0) + else: + s = chunk_idx * CHUNK_SIZE + e = min(s + CHUNK_SIZE, M_full) + return s, e + + # This happens when none of the tokens from the all2all reach this + # EP rank. Also, note that this is only relevant for CUDAGraph + # incompatible all2all kernels like the DeepEP high-throughput + # kernels. CUDAGraph compatible all2all kernels like the pplx + # kernels and the DeepEP low-latency kernels are always batched + # and can never run into the tensor.numel() == 0 case. + if M_full == 0: + assert num_chunks == 0 + workspace13 = None + workspace2 = None + fused_out = torch.empty_like(a1q) + else: + assert num_chunks > 0 + workspace13, workspace2, fused_out = self._allocate_buffers( + in_dtype, + a1q.device, + CHUNK_SIZE, + M_full, + N, + K, + top_k, + global_num_experts, + local_num_experts, + expert_tokens_meta, ) for chunk_idx in range(num_chunks): - c_a1q, c_a1q_scale, c_a2_scale, c_topk_ids, c_topk_weights = ( - slice_input_tensors(chunk_idx) + s, e = input_chunk_range(chunk_idx) + + c_expert_tokens_meta = self._slice_expert_tokens_metadata( + num_chunks, + expert_tokens_meta, + topk_ids[s:e], + local_num_experts, + expert_map, ) - c_expert_tokens_meta = None - if expert_tokens_meta is not None: - c_expert_tokens_meta = slice_expert_tokens_metadata( - expert_tokens_meta, c_topk_ids, local_num_experts, expert_map - ) + c_fused_out = self._slice_output_tensor( + fused_out, chunk_idx, num_chunks, CHUNK_SIZE, M_full + ) - self._do_fused_experts( - fused_out=slice_output_tensor(chunk_idx), - a1=a1, - a1q=c_a1q, + self.fused_experts.apply( + output=c_fused_out, + hidden_states=a1q[s:e], w1=w1, w2=w2, - topk_weights=c_topk_weights, - topk_ids=c_topk_ids, + topk_weights=topk_weights[s:e], + topk_ids=topk_ids[s:e], activation=activation, global_num_experts=global_num_experts, - local_num_experts=local_num_experts, expert_map=expert_map, - a1q_scale=c_a1q_scale, - a2_scale=c_a2_scale, + a1q_scale=_slice_scales(a1q_scale, s, e), + a2_scale=_slice_scales(self.fused_experts.a2_scale, e, e), + workspace13=workspace13, + workspace2=workspace2, expert_tokens_meta=c_expert_tokens_meta, apply_router_weight_on_input=apply_router_weight_on_input, ) return fused_out + def _finalize( + self, + output: torch.Tensor, + fused_out: torch.Tensor, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + """ + The _finalize method is a wrapper around self.prepare_finalize.finalize + that handles DBO, async and shared expert overlap. + """ + shared_output: Optional[torch.Tensor] = None + + if not self.prepare_finalize.supports_async(): + assert not dbo_enabled() + + self.prepare_finalize.finalize( + output, + fused_out, + topk_weights, + topk_ids, + apply_router_weight_on_input, + self.fused_experts.finalize_weight_and_reduce_impl(), + ) + if self.shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + else: + finalize_ret = self.prepare_finalize.finalize_async( + output, + fused_out, + topk_weights, + topk_ids, + apply_router_weight_on_input, + self.fused_experts.finalize_weight_and_reduce_impl(), + ) + + if self.shared_experts is not None: + shared_output = self.shared_experts(hidden_states) + + # TODO(lucas): refactor this in the alternative schedules followup + # currently unpack if we have hook + receiver pair or just + # receiver (see finalize_async docstring) + hook, receiver = ( + finalize_ret + if isinstance(finalize_ret, tuple) + else (None, finalize_ret) + ) + + if hook is not None: + if dbo_enabled(): + # If DBO is being used, register the hook with the ubatch + # context and call it in dbo_maybe_run_recv_hook instead of + # passing it to the receiver. + dbo_register_recv_hook(hook) + dbo_yield() + else: + hook() + + receiver() + + if self.shared_experts is None: + return output + else: + assert shared_output is not None + return shared_output, output + def forward( self, hidden_states: torch.Tensor, @@ -947,156 +1139,45 @@ class FusedMoEModularKernel(torch.nn.Module): - torch.Tensor: The output tensor after applying the MoE layer. """ - a1 = hidden_states - output = a1 if inplace and self.shared_experts is None else torch.zeros_like(a1) + if inplace and self.shared_experts is None: + output = hidden_states + else: + output = torch.zeros_like(hidden_states) local_num_experts = w1.size(0) if global_num_experts == -1: global_num_experts = local_num_experts - if not self.prepare_finalize.supports_async(): - # We shouldn't be running an a2a kernel that doesn't - # support async prepare/finalize - # TODO(lucas): enable in follow-up - assert not dbo_enabled() - - ( - a1q, - a1q_scale, - expert_tokens_meta, - _expert_topk_ids, - _expert_topk_weights, - ) = self.prepare_finalize.prepare( - a1, - topk_weights, - topk_ids, - global_num_experts, - expert_map, - apply_router_weight_on_input, - self.fused_experts.quant_config, - ) - else: - # Overlap shared expert compute with all2all dispatch. - dbo_maybe_run_recv_hook() - prepare_ret = self.prepare_finalize.prepare_async( - a1, - topk_weights, - topk_ids, - global_num_experts, - expert_map, - apply_router_weight_on_input, - self.fused_experts.quant_config, - ) - - # TODO(lucas): refactor this in the alternative schedules followup - # currently unpack if we have hook + receiver pair or just - # receiver (see finalize_async docstring) - hook, receiver = ( - prepare_ret if isinstance(prepare_ret, tuple) else (None, prepare_ret) - ) - - if hook is not None: - if dbo_enabled(): - # If DBO is being used, register the hook with the ubatch - # context and call it in dbo_maybe_run_recv_hook instead of - # passing it to the receiver. - dbo_register_recv_hook(hook) - dbo_yield() - else: - hook() - - ( - a1q, - a1q_scale, - expert_tokens_meta, - _expert_topk_ids, - _expert_topk_weights, - ) = receiver() - - # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. - topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids - topk_weights = ( - topk_weights if _expert_topk_weights is None else _expert_topk_weights + a1q, a1q_scale, expert_tokens_meta, topk_ids, topk_weights = self._prepare( + hidden_states, + topk_weights, + topk_ids, + global_num_experts, + expert_map, + apply_router_weight_on_input, ) - fused_out = None + fused_out = self._fused_experts( + in_dtype=hidden_states.dtype, + a1q=a1q, + a1q_scale=a1q_scale, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + activation=activation, + global_num_experts=global_num_experts, + local_num_experts=local_num_experts, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + expert_tokens_meta=expert_tokens_meta, + ) - if a1q.numel() == 0: - # This happens when none of the tokens from the all2all reach this - # EP rank. Also, note that this is only relevant for CUDAGraph - # incompatible all2all kernels like the DeepEP high-throughput - # kernels. CUDAGraph compatible all2all kernels like the pplx - # kernels and the DeepEP low-latency kernels are always batched - # and can never run into the tensor.numel() == 0 case. - fused_out = torch.empty_like(a1q).to(dtype=a1.dtype) - else: - fused_out = self._maybe_chunk_fused_experts( - a1=a1, - a1q=a1q, - w1=w1, - w2=w2, - topk_weights=topk_weights, - topk_ids=topk_ids, - activation=activation, - global_num_experts=global_num_experts, - local_num_experts=local_num_experts, - expert_map=expert_map, - a1q_scale=a1q_scale, - expert_tokens_meta=expert_tokens_meta, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - - shared_output: Optional[torch.Tensor] = None - - if not self.prepare_finalize.supports_async(): - assert not dbo_enabled() - - self.prepare_finalize.finalize( - output, - fused_out, - topk_weights, - topk_ids, - apply_router_weight_on_input, - self.fused_experts.finalize_weight_and_reduce_impl(), - ) - if self.shared_experts is not None: - shared_output = self.shared_experts(a1) - else: - finalize_ret = self.prepare_finalize.finalize_async( - output, - fused_out, - topk_weights, - topk_ids, - apply_router_weight_on_input, - self.fused_experts.finalize_weight_and_reduce_impl(), - ) - - if self.shared_experts is not None: - shared_output = self.shared_experts(a1) - - # TODO(lucas): refactor this in the alternative schedules followup - # currently unpack if we have hook + receiver pair or just - # receiver (see finalize_async docstring) - hook, receiver = ( - finalize_ret - if isinstance(finalize_ret, tuple) - else (None, finalize_ret) - ) - - if hook is not None: - if dbo_enabled(): - # If DBO is being used, register the hook with the ubatch - # context and call it in dbo_maybe_run_recv_hook instead of - # passing it to the receiver. - dbo_register_recv_hook(hook) - dbo_yield() - else: - hook() - - receiver() - - if self.shared_experts is None: - return output - else: - assert shared_output is not None - return shared_output, output + return self._finalize( + output, + fused_out, + hidden_states, + topk_weights, + topk_ids, + apply_router_weight_on_input, + ) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 79212c2b689dd..e87953e34eaf2 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -91,6 +91,9 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def num_dispatchers(self) -> int: return self.num_dispatchers_ + def output_is_reduced(self) -> bool: + return True + def supports_async(self) -> bool: return True diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index be6939a3f62fd..1e572d2394781 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -27,6 +27,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): def num_dispatchers(self) -> int: return 1 + def output_is_reduced(self) -> bool: + return False + def prepare( self, a1: torch.Tensor, diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 9c35d7d2fe120..94a3ba74e47fd 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -83,8 +83,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -92,7 +90,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # Note: the deep gemm workspaces are strictly larger than the triton # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. @@ -101,8 +99,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ): assert self.deep_gemm_expert is not None return self.deep_gemm_expert.workspace_shapes( - a, - aq, M, N, K, @@ -113,8 +109,6 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ) else: return self.triton_expert.workspace_shapes( - a, - aq, M, N, K, diff --git a/vllm/model_executor/layers/fused_moe/trtllm_moe.py b/vllm/model_executor/layers/fused_moe/trtllm_moe.py index 8eb724a7435f9..c84d1afeb1f97 100644 --- a/vllm/model_executor/layers/fused_moe/trtllm_moe.py +++ b/vllm/model_executor/layers/fused_moe/trtllm_moe.py @@ -52,8 +52,6 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): def workspace_shapes( self, - a: torch.Tensor, - aq: torch.Tensor, M: int, N: int, K: int, @@ -61,14 +59,12 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute): global_num_experts: int, local_num_experts: int, expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], torch.dtype]: + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # The workspaces for this implementation are managed by flashinfer. - # TODO(varun) : workspace1 is could be used as the output tensor. This - # is error-prone. Allow the `workspace_shapes` to return None workspaces - workspace1 = (M, K) - workspace2 = (0, 0) + workspace1 = (0,) + workspace2 = (0,) output = (M, K) - return (workspace1, workspace2, output, a.dtype) + return (workspace1, workspace2, output) def _get_tile_tokens_dim(self, x: torch.Tensor, top_k: int, local_num_experts: int): # Number of tokens in the input tensor.