[Performance][DeepGEMM] Estimate expected_m (#28694)

Signed-off-by: Varun Sundar Rabindranath <vsundarr@redhat.com>
Co-authored-by: Varun Sundar Rabindranath <vsundarr@redhat.com>

parent c9e665852a
commit 6965ef436f
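
The previous code passed a fixed hint, expected_m = max_num_tokens, to DeepGEMM's masked grouped GEMM. This commit instead derives expected_m from the actual token counts across data-parallel ranks (read from DPMetadata in the forward context), assumes even load balancing across experts, and clamps the result to [16, max_tokens_per_expert]. A minimal standalone sketch of that estimate (the real code is a method on BatchedDeepGemmExperts and reads DPMetadata; the function name and flat signature here are illustrative):

def sketch_expected_m(
    total_num_tokens: int,
    topk: int,
    global_num_experts: int,
    max_tokens_per_expert: int,
) -> int:
    # Each token is routed to `topk` experts; assume even load balancing.
    per_expert = (total_num_tokens * topk) // global_num_experts
    estimate = ((per_expert + 15) // 16) * 16  # round_up(per_expert, 16)
    # Clamp to [16, max_tokens_per_expert].
    return min(max_tokens_per_expert, max(16, estimate))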
File: tests/kernels/moe/test_deepep_deepgemm_moe.py (path inferred; file names were not preserved in this view)

@@ -7,6 +7,7 @@ fp8 block-quantized case.
 """
 
 import dataclasses
+from contextlib import contextmanager
 
 import pytest
 import torch.distributed
@@ -14,6 +15,7 @@ from torch.distributed import ProcessGroup
 from typing_extensions import ParamSpec
 
 from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.forward_context import set_forward_context
 from vllm.model_executor.layers.fused_moe.config import (
     FusedMoEQuantConfig,
     fp8_w8a8_moe_quant_config,
@@ -61,6 +63,23 @@ requires_deep_gemm = pytest.mark.skipif(
 P = ParamSpec("P")
 
 
+@contextmanager
+def with_dp_metadata(M: int, world_size: int):
+    num_tokens_across_dp = torch.tensor([M] * world_size, device="cpu", dtype=torch.int)
+
+    vllm_config = VllmConfig()
+    vllm_config.parallel_config.data_parallel_size = world_size
+    vllm_config.parallel_config.enable_expert_parallel = True
+
+    with set_forward_context(
+        None,
+        vllm_config,
+        num_tokens=M,
+        num_tokens_across_dp=num_tokens_across_dp,
+    ):
+        yield
+
+
 def next_power_of_2(x):
     import math
 
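
The hunk below wraps the test's mk.forward call in with_dp_metadata so that DPMetadata, and with it num_tokens_across_dp, is populated during the test, as it would be during a real forward pass.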
@@ -285,6 +304,9 @@ def deepep_deepgemm_moe_impl(
         quant_config=quant_config,
     )
 
+    with with_dp_metadata(
+        M=test_tensors.rank_tokens.size(0), world_size=pgi.world_size
+    ):
         out = mk.forward(
             hidden_states=test_tensors.rank_tokens,
             w1=w1,
File: vllm/forward_context.py

@@ -221,6 +221,10 @@ def get_forward_context() -> ForwardContext:
     return _forward_context
 
 
+def is_forward_context_available() -> bool:
+    return _forward_context is not None
+
+
 def create_forward_context(
     attn_metadata: Any,
     vllm_config: VllmConfig,
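
This helper lets callers check for a live forward context instead of calling get_forward_context(), which expects one to be set. The guard pattern used later in this commit (a minimal sketch):

from vllm.forward_context import get_forward_context, is_forward_context_available

# Fall back to None when invoked outside of a forward pass.
dp_meta = get_forward_context().dp_metadata if is_forward_context_available() else None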
File: vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py

@@ -5,6 +5,7 @@
 import torch
 
 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
+from vllm.forward_context import get_forward_context, is_forward_context_available
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
@@ -19,7 +20,7 @@ from vllm.utils.deep_gemm import (
     get_mk_alignment_for_contiguous_layout,
     is_deep_gemm_e8m0_used,
 )
-from vllm.utils.math_utils import cdiv
+from vllm.utils.math_utils import cdiv, round_up
 
 logger = init_logger(__name__)
 
@@ -313,6 +314,33 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         output = (num_experts, max_num_tokens * num_dispatchers, K)
         return (workspace13, workspace2, output)
 
+    def estimate_expected_m(
+        self, global_num_experts: int, max_tokens_per_expert: int, topk: int
+    ) -> int:
+        dp_meta = (
+            get_forward_context().dp_metadata
+            if is_forward_context_available()
+            else None
+        )
+        if dp_meta is None:
+            logger.warning_once(
+                "DPMetadata unavailable. Defaulting expected_m to "
+                f"{max_tokens_per_expert}.",
+                scope="local",
+            )
+            return max_tokens_per_expert
+
+        total_num_tokens = dp_meta.num_tokens_across_dp_cpu.sum().item()
+        total_num_tokens_replicated = total_num_tokens * topk
+
+        # Assume even load balancing across experts.
+        assert global_num_experts != 0
+        estimate = round_up(int(total_num_tokens_replicated // global_num_experts), 16)
+        # Clamp the estimate to [16, max_tokens_per_expert].
+        estimate = max(estimate, 16)
+        estimate = min(max_tokens_per_expert, estimate)
+        return estimate
+
     def apply(
         self,
         output: torch.Tensor,
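
A worked example of the estimate with illustrative numbers (not from the commit): 4 DP ranks with 128 tokens each, topk = 8, 256 experts, and max_tokens_per_expert = 512:

total_num_tokens = 4 * 128           # 512 tokens across all DP ranks
replicated = total_num_tokens * 8    # 4096 token-to-expert assignments
per_expert = replicated // 256       # 16 tokens per expert under even load
# round_up(16, 16) = 16; clamping to [16, 512] leaves expected_m = 16,
# versus the previous fixed hint of 512.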
@@ -348,10 +376,12 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
 
         workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
 
-        # (from deepgemm docs) : A value hint (which is a value on CPU)
-        # for the M expectation of each batch, correctly setting this value
-        # may lead to better performance.
-        expected_m = max_num_tokens
+        expected_m = self.estimate_expected_m(
+            global_num_experts=global_num_experts,
+            max_tokens_per_expert=max_num_tokens,
+            topk=topk_ids.size(-1),
+        )
+
         fp8_m_grouped_gemm_nt_masked(
             (a1q, a1q_scale),
             (w1, self.w1_scale),
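
Per the DeepGEMM docs quoted in the removed comment, expected_m is a CPU-side value hint for the expected per-batch M, and setting it accurately may improve kernel performance; the masking itself is driven by the grouped GEMM's other inputs, which suggests the estimate affects performance rather than correctness.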