From 6965ef436fb398bfbbdce5b6f88dd842c5944771 Mon Sep 17 00:00:00 2001
From: Varun Sundar Rabindranath
Date: Sat, 15 Nov 2025 00:52:14 -0500
Subject: [PATCH] [Performance][DeepGEMM] Estimate expected_m (#28694)

Signed-off-by: Varun Sundar Rabindranath
Co-authored-by: Varun Sundar Rabindranath
---
 tests/kernels/moe/test_deepep_deepgemm_moe.py | 46 ++++++++++++++-----
 vllm/forward_context.py                       |  4 ++
 .../layers/fused_moe/batched_deep_gemm_moe.py | 40 ++++++++++++++--
 3 files changed, 73 insertions(+), 17 deletions(-)

diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py
index 0faf8bc95d2e..455ecacef5ec 100644
--- a/tests/kernels/moe/test_deepep_deepgemm_moe.py
+++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py
@@ -7,6 +7,7 @@ fp8 block-quantized case.
 """
 
 import dataclasses
+from contextlib import contextmanager
 
 import pytest
 import torch.distributed
@@ -14,6 +15,7 @@ from torch.distributed import ProcessGroup
 from typing_extensions import ParamSpec
 
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
     fp8_w8a8_moe_quant_config,
@@ -61,6 +63,23 @@ requires_deep_gemm = pytest.mark.skipif(
 P = ParamSpec("P")
 
 
+@contextmanager
+def with_dp_metadata(M: int, world_size: int):
+    num_tokens_across_dp = torch.tensor([M] * world_size, device="cpu", dtype=torch.int)
+
+    vllm_config = VllmConfig()
+    vllm_config.parallel_config.data_parallel_size = world_size
+    vllm_config.parallel_config.enable_expert_parallel = True
+
+    with set_forward_context(
+        None,
+        vllm_config,
+        num_tokens=M,
+        num_tokens_across_dp=num_tokens_across_dp,
+    ):
+        yield
+
+
 def next_power_of_2(x):
     import math
 
@@ -285,18 +304,21 @@ def deepep_deepgemm_moe_impl(
         quant_config=quant_config,
     )
 
-    out = mk.forward(
-        hidden_states=test_tensors.rank_tokens,
-        w1=w1,
-        w2=w2,
-        topk_weights=test_tensors.topk_weights,
-        topk_ids=test_tensors.topk,
-        inplace=False,
-        activation="silu",
-        global_num_experts=num_experts,
-        expert_map=build_expert_map(),
-        apply_router_weight_on_input=False,
-    )
+    with with_dp_metadata(
+        M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
+    ):
+        out = mk.forward(
+            hidden_states=test_tensors.rank_tokens,
+            w1=w1,
+            w2=w2,
+            topk_weights=test_tensors.topk_weights,
+            topk_ids=test_tensors.topk,
+            inplace=False,
+            activation="silu",
+            global_num_experts=num_experts,
+            expert_map=build_expert_map(),
+            apply_router_weight_on_input=False,
+        )
 
     return out
 
diff --git a/vllm/forward_context.py b/vllm/forward_context.py
index 44bc2a4cda31..25fb7181a8f2 100644
--- a/vllm/forward_context.py
+++ b/vllm/forward_context.py
@@ -221,6 +221,10 @@ def get_forward_context() -> ForwardContext:
     return _forward_context
 
 
+def is_forward_context_available() -> bool:
+    return _forward_context is not None
+
+
 def create_forward_context(
     attn_metadata: Any,
     vllm_config: VllmConfig,
diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index 79c92eb48612..53362277dae8 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -5,6 +5,7 @@
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.forward_context import get_forward_context, is_forward_context_available
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
@@ -19,7 +20,7 @@ from vllm.utils.deep_gemm import (
     get_mk_alignment_for_contiguous_layout,
     is_deep_gemm_e8m0_used,
 )
-from vllm.utils.math_utils import cdiv
+from vllm.utils.math_utils import cdiv, round_up
 
 logger = init_logger(__name__)
 
@@ -313,6 +314,33 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         output = (num_experts, max_num_tokens * num_dispatchers, K)
         return (workspace13, workspace2, output)
 
+    def estimate_expected_m(
+        self, global_num_experts: int, max_tokens_per_expert: int, topk: int
+    ) -> int:
+        dp_meta = (
+            get_forward_context().dp_metadata
+            if is_forward_context_available()
+            else None
+        )
+        if dp_meta is None:
+            logger.warning_once(
+                "DPMetadata unavailable. Defaulting expected_m to "
+                f"{max_tokens_per_expert}.",
+                scope="local",
+            )
+            return max_tokens_per_expert
+
+        total_num_tokens = dp_meta.num_tokens_across_dp_cpu.sum().item()
+        total_num_tokens_replicated = total_num_tokens * topk
+
+        # Assume even load balancing
+        assert global_num_experts != 0
+        estimate = round_up(int(total_num_tokens_replicated // global_num_experts), 16)
+        # clamp estimate
+        estimate = max(estimate, 16)
+        estimate = min(max_tokens_per_expert, estimate)
+        return estimate
+
     def apply(
         self,
         output: torch.Tensor,
@@ -348,10 +376,12 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
         workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
 
-        # (from deepgemm docs) : A value hint (which is a value on CPU)
-        # for the M expectation of each batch, correctly setting this value
-        # may lead to better performance.
-        expected_m = max_num_tokens
+        expected_m = self.estimate_expected_m(
+            global_num_experts=global_num_experts,
+            max_tokens_per_expert=max_num_tokens,
+            topk=topk_ids.size(-1),
+        )
+
        fp8_m_grouped_gemm_nt_masked(
             (a1q, a1q_scale),
             (w1, self.w1_scale),
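
Illustrative note (not part of the patch): estimate_expected_m() replaces the
fixed expected_m = max_num_tokens hint passed to fp8_m_grouped_gemm_nt_masked
with an even-load-balancing estimate derived from the total token count across
DP ranks. A minimal standalone Python sketch of the same arithmetic follows;
the helper names here are local to the sketch and independent of vLLM's
forward context.

    def round_up(x: int, multiple: int) -> int:
        # Round x up to the nearest multiple of `multiple`
        # (mirrors vllm.utils.math_utils.round_up).
        return ((x + multiple - 1) // multiple) * multiple

    def estimate_expected_m_sketch(
        total_num_tokens: int,    # tokens summed over all DP ranks
        topk: int,                # experts selected per token
        global_num_experts: int,
        max_tokens_per_expert: int,
    ) -> int:
        assert global_num_experts != 0
        # Each token is replicated to `topk` experts; assume even load balancing.
        estimate = round_up((total_num_tokens * topk) // global_num_experts, 16)
        # Clamp to [16, max_tokens_per_expert].
        return min(max_tokens_per_expert, max(estimate, 16))

    # Example: 4 DP ranks x 128 tokens each, top-8 routing over 256 experts,
    # with room for 512 tokens per expert -> the hint drops from 512 to 16.
    assert estimate_expected_m_sketch(4 * 128, 8, 256, 512) == 16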