mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 05:15:42 +08:00
[Performance][DeepGEMM] Estimate expected_m (#28694)
Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com> Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
This commit is contained in:
parent
c9e665852a
commit
6965ef436f
@ -7,6 +7,7 @@ fp8 block-quantized case.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytest
|
||||
import torch.distributed
|
||||
@ -14,6 +15,7 @@ from torch.distributed import ProcessGroup
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEQuantConfig,
|
||||
fp8_w8a8_moe_quant_config,
|
||||
@ -61,6 +63,23 @@ requires_deep_gemm = pytest.mark.skipif(
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def with_dp_metadata(M: int, world_size: int):
|
||||
num_tokens_across_dp = torch.tensor([M] * world_size, device="cpu", dtype=torch.int)
|
||||
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.parallel_config.data_parallel_size = world_size
|
||||
vllm_config.parallel_config.enable_expert_parallel = True
|
||||
|
||||
with set_forward_context(
|
||||
None,
|
||||
vllm_config,
|
||||
num_tokens=M,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
):
|
||||
yield
|
||||
|
||||
|
||||
def next_power_of_2(x):
|
||||
import math
|
||||
|
||||
@ -285,18 +304,21 @@ def deepep_deepgemm_moe_impl(
|
||||
quant_config=quant_config,
|
||||
)
|
||||
|
||||
out = mk.forward(
|
||||
hidden_states=test_tensors.rank_tokens,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
topk_ids=test_tensors.topk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
with with_dp_metadata(
|
||||
M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
|
||||
):
|
||||
out = mk.forward(
|
||||
hidden_states=test_tensors.rank_tokens,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
topk_weights=test_tensors.topk_weights,
|
||||
topk_ids=test_tensors.topk,
|
||||
inplace=False,
|
||||
activation="silu",
|
||||
global_num_experts=num_experts,
|
||||
expert_map=build_expert_map(),
|
||||
apply_router_weight_on_input=False,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
@ -221,6 +221,10 @@ def get_forward_context() -> ForwardContext:
|
||||
return _forward_context
|
||||
|
||||
|
||||
def is_forward_context_available() -> bool:
|
||||
return _forward_context is not None
|
||||
|
||||
|
||||
def create_forward_context(
|
||||
attn_metadata: Any,
|
||||
vllm_config: VllmConfig,
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.forward_context import get_forward_context, is_forward_context_available
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
@ -19,7 +20,7 @@ from vllm.utils.deep_gemm import (
|
||||
get_mk_alignment_for_contiguous_layout,
|
||||
is_deep_gemm_e8m0_used,
|
||||
)
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.math_utils import cdiv, round_up
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -313,6 +314,33 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
output = (num_experts, max_num_tokens * num_dispatchers, K)
|
||||
return (workspace13, workspace2, output)
|
||||
|
||||
def estimate_expected_m(
|
||||
self, global_num_experts: int, max_tokens_per_expert: int, topk: int
|
||||
) -> int:
|
||||
dp_meta = (
|
||||
get_forward_context().dp_metadata
|
||||
if is_forward_context_available()
|
||||
else None
|
||||
)
|
||||
if dp_meta is None:
|
||||
logger.warning_once(
|
||||
"DPMetadata unavailable. Defaulting expected_m to "
|
||||
f"{max_tokens_per_expert}.",
|
||||
scope="local",
|
||||
)
|
||||
return max_tokens_per_expert
|
||||
|
||||
total_num_tokens = dp_meta.num_tokens_across_dp_cpu.sum().item()
|
||||
total_num_tokens_replicated = total_num_tokens * topk
|
||||
|
||||
# Assume even load balancing
|
||||
assert global_num_experts != 0
|
||||
estimate = round_up(int(total_num_tokens_replicated // global_num_experts), 16)
|
||||
# clamp estimate
|
||||
estimate = max(estimate, 16)
|
||||
estimate = min(max_tokens_per_expert, estimate)
|
||||
return estimate
|
||||
|
||||
def apply(
|
||||
self,
|
||||
output: torch.Tensor,
|
||||
@ -348,10 +376,12 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
|
||||
workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
|
||||
|
||||
# (from deepgemm docs) : A value hint (which is a value on CPU)
|
||||
# for the M expectation of each batch, correctly setting this value
|
||||
# may lead to better performance.
|
||||
expected_m = max_num_tokens
|
||||
expected_m = self.estimate_expected_m(
|
||||
global_num_experts=global_num_experts,
|
||||
max_tokens_per_expert=max_num_tokens,
|
||||
topk=topk_ids.size(-1),
|
||||
)
|
||||
|
||||
fp8_m_grouped_gemm_nt_masked(
|
||||
(a1q, a1q_scale),
|
||||
(w1, self.w1_scale),
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user