[Performance][DeepGEMM] Estimate expected_m (#28694)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Author: Varun Sundar Rabindranath
Date: 2025-11-15 00:52:14 -05:00, committed by GitHub
commit 6965ef436f
parent c9e665852a
3 changed files with 73 additions and 17 deletions

File: tests/kernels/moe/test_deepep_deepgemm_moe.py

@@ -7,6 +7,7 @@ fp8 block-quantized case.
 """
 import dataclasses
+from contextlib import contextmanager
 import pytest
 import torch.distributed
@@ -14,6 +15,7 @@ from torch.distributed import ProcessGroup
 from typing_extensions import ParamSpec
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
     fp8_w8a8_moe_quant_config,
@@ -61,6 +63,23 @@ requires_deep_gemm = pytest.mark.skipif(
 P = ParamSpec("P")


+@contextmanager
+def with_dp_metadata(M: int, world_size: int):
+    num_tokens_across_dp = torch.tensor(
+        [M] * world_size, device="cpu", dtype=torch.int
+    )
+    vllm_config = VllmConfig()
+    vllm_config.parallel_config.data_parallel_size = world_size
+    vllm_config.parallel_config.enable_expert_parallel = True
+    with set_forward_context(
+        None,
+        vllm_config,
+        num_tokens=M,
+        num_tokens_across_dp=num_tokens_across_dp,
+    ):
+        yield
+
+
 def next_power_of_2(x):
     import math
@@ -285,18 +304,21 @@ def deepep_deepgemm_moe_impl(
         quant_config=quant_config,
     )

-    out = mk.forward(
-        hidden_states=test_tensors.rank_tokens,
-        w1=w1,
-        w2=w2,
-        topk_weights=test_tensors.topk_weights,
-        topk_ids=test_tensors.topk,
-        inplace=False,
-        activation="silu",
-        global_num_experts=num_experts,
-        expert_map=build_expert_map(),
-        apply_router_weight_on_input=False,
-    )
+    with with_dp_metadata(
+        M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
+    ):
+        out = mk.forward(
+            hidden_states=test_tensors.rank_tokens,
+            w1=w1,
+            w2=w2,
+            topk_weights=test_tensors.topk_weights,
+            topk_ids=test_tensors.topk,
+            inplace=False,
+            activation="silu",
+            global_num_experts=num_experts,
+            expert_map=build_expert_map(),
+            apply_router_weight_on_input=False,
+        )

     return out
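
Note: the context manager above matters because the new estimate_expected_m
(third file below) reads DPMetadata off the ambient forward context. A minimal
sketch of what the test now provides, assuming it runs in the same module as
with_dp_metadata and that num_tokens_across_dp_cpu behaves as it is used later
in this diff:

from vllm.forward_context import get_forward_context

with with_dp_metadata(M=64, world_size=4):
    dp_meta = get_forward_context().dp_metadata
    # Each of the 4 DP ranks reports 64 tokens, so the kernel sees 256
    # tokens in total instead of falling back to its worst-case default.
    assert dp_meta.num_tokens_across_dp_cpu.sum().item() == 64 * 4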

File: vllm/forward_context.py

@@ -221,6 +221,10 @@ def get_forward_context() -> ForwardContext:
     return _forward_context


+def is_forward_context_available() -> bool:
+    return _forward_context is not None
+
+
 def create_forward_context(
     attn_metadata: Any,
     vllm_config: VllmConfig,
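
Note: get_forward_context() asserts that a context has been set, so code that
may run outside a model forward pass (standalone kernel tests, for example)
needs this predicate to degrade gracefully. A sketch of the guard pattern the
helper enables, mirroring its use in the third file below:

from vllm.forward_context import get_forward_context, is_forward_context_available

# Probe before dereferencing rather than catching the assertion error.
dp_meta = (
    get_forward_context().dp_metadata if is_forward_context_available() else None
)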

File: vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

@@ -5,6 +5,7 @@
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.forward_context import get_forward_context, is_forward_context_available
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
@@ -19,7 +20,7 @@ from vllm.utils.deep_gemm import (
     get_mk_alignment_for_contiguous_layout,
     is_deep_gemm_e8m0_used,
 )
-from vllm.utils.math_utils import cdiv
+from vllm.utils.math_utils import cdiv, round_up

 logger = init_logger(__name__)
@@ -313,6 +314,33 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         output = (num_experts, max_num_tokens * num_dispatchers, K)
         return (workspace13, workspace2, output)

+    def estimate_expected_m(
+        self, global_num_experts: int, max_tokens_per_expert: int, topk: int
+    ) -> int:
+        dp_meta = (
+            get_forward_context().dp_metadata
+            if is_forward_context_available()
+            else None
+        )
+        if dp_meta is None:
+            logger.warning_once(
+                "DPMetadata unavailable. Defaulting expected_m to "
+                f"{max_tokens_per_expert}.",
+                scope="local",
+            )
+            return max_tokens_per_expert
+
+        total_num_tokens = dp_meta.num_tokens_across_dp_cpu.sum().item()
+        total_num_tokens_replicated = total_num_tokens * topk
+        # Assume even load balancing
+        assert global_num_experts != 0
+        estimate = round_up(int(total_num_tokens_replicated // global_num_experts), 16)
+        # clamp estimate
+        estimate = max(estimate, 16)
+        estimate = min(max_tokens_per_expert, estimate)
+        return estimate
+
     def apply(
         self,
         output: torch.Tensor,
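
Note: a worked example of the estimate under the even-load-balancing
assumption (illustrative numbers, not taken from this PR): 4 DP ranks with
100 tokens each, topk=6, 64 global experts, max_tokens_per_expert=512.

from vllm.utils.math_utils import round_up

total_num_tokens = 4 * 100              # sum of num_tokens_across_dp_cpu
replicated = total_num_tokens * 6       # 2400 token-to-expert assignments
per_expert = replicated // 64           # 37 tokens per expert on average
estimate = round_up(per_expert, 16)     # 48, aligned to a multiple of 16
estimate = min(max(estimate, 16), 512)  # clamp to [16, max_tokens_per_expert]
assert estimate == 48

Where the old code always passed max_tokens_per_expert (512 here) as the
hint, the estimate hands DeepGEMM a much tighter M expectation of 48.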
@@ -348,10 +376,12 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))

         # (from deepgemm docs) : A value hint (which is a value on CPU)
         # for the M expectation of each batch, correctly setting this value
         # may lead to better performance.
-        expected_m = max_num_tokens
+        expected_m = self.estimate_expected_m(
+            global_num_experts=global_num_experts,
+            max_tokens_per_expert=max_num_tokens,
+            topk=topk_ids.size(-1),
+        )

         fp8_m_grouped_gemm_nt_masked(
             (a1q, a1q_scale),
             (w1, self.w1_scale),