[Bug] Fix DeepGemm for EP low latency case (#20833)

Signed-off-by: yewentao256 <zhyanwentao@126.com>
Author: Wentao Ye
Date: 2025-07-12 02:05:12 -04:00 (committed by GitHub)
Parent: f56d2996ca
Commit: 0d4891cd03

@@ -11,7 +11,8 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceDelegate)
 from vllm.model_executor.layers.fused_moe.utils import _resize_cache
 from vllm.triton_utils import tl, triton
-from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked
+from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
+                                  is_blackwell_deep_gemm_used)
 
 logger = init_logger(__name__)
@@ -50,6 +51,7 @@ def _silu_mul_fp8_quant_deep_gemm(
     eps: tl.constexpr,
     fp8_min: tl.constexpr,
     fp8_max: tl.constexpr,
+    use_ue8m0: tl.constexpr,
 
     # Meta ---------------------------------------------------------------
     BLOCK: tl.constexpr,
@@ -92,7 +94,9 @@ def _silu_mul_fp8_quant_deep_gemm(
         y = x * y2
 
         _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
-        y_s = _absmax / fp8_max
+        scale_raw = _absmax / fp8_max
+        y_s = tl.math.exp2(tl.ceil(
+            tl.log2(scale_raw))) if use_ue8m0 else scale_raw
         y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
 
         tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
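Note: the hunk above rounds the quantization scale to a power of two when the Blackwell DeepGEMM backend is active, since UE8M0 encodes a scale as a pure exponent. Rounding up (ceil rather than nearest) guarantees y_s >= scale_raw, so y / y_s still fits inside the fp8 range. A minimal plain-Python sketch of the same rounding; the helper name is illustrative, not part of vLLM:

    import math

    def ue8m0_scale(scale_raw: float) -> float:
        # Round a positive scale up to the next power of two, mirroring
        # tl.math.exp2(tl.ceil(tl.log2(scale_raw))) in the Triton kernel.
        # Rounding up keeps y_s >= scale_raw, hence |y / y_s| <= fp8_max.
        return 2.0 ** math.ceil(math.log2(scale_raw))

    assert ue8m0_scale(0.3) == 0.5   # rounds up to 2**-1
    assert ue8m0_scale(1.0) == 1.0   # exact powers of two are unchanged
    assert ue8m0_scale(1.5) == 2.0   # rounds up to 2**1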
@@ -174,6 +178,7 @@ def silu_mul_fp8_quant_deep_gemm(
         eps,
         fp8_min,
         fp8_max,
+        is_blackwell_deep_gemm_used(),
         BLOCK=group_size,
         num_warps=4,
     )
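Note: because `use_ue8m0` is declared `tl.constexpr`, the host-side value of `is_blackwell_deep_gemm_used()` is baked in when the kernel is compiled, so the UE8M0 branch adds no runtime cost on non-Blackwell hardware. A self-contained sketch of this specialization pattern with a toy kernel; names and shapes are illustrative, not vLLM's:

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def _toy_quant(x_ptr, y_ptr, n, use_ue8m0: tl.constexpr,
                   BLOCK: tl.constexpr):
        offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
        mask = offs < n
        x = tl.load(x_ptr + offs, mask=mask, other=0.0)
        # 448.0 is the fp8 e4m3 max, as in the real kernel's fp8_max.
        s = tl.maximum(tl.max(tl.abs(x)), 1e-10) / 448.0
        if use_ue8m0:  # resolved at compile time, not per element
            s = tl.math.exp2(tl.ceil(tl.log2(s)))
        tl.store(y_ptr + offs, x / s, mask=mask)

    x = torch.randn(1024, device="cuda")
    y = torch.empty_like(x)
    _toy_quant[(triton.cdiv(1024, 256),)](x, y, 1024, True, BLOCK=256)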
@@ -290,14 +295,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
         # may lead to better performance.
         expected_m = max_num_tokens
         fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
-                                     out=workspace1,
-                                     masked_m=expert_num_tokens,
-                                     expected_m=expected_m)
+                                     workspace1, expert_num_tokens, expected_m)
 
         a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
                                                       expert_num_tokens)
 
-        fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale),
-                                     out=output,
-                                     masked_m=expert_num_tokens,
-                                     expected_m=expected_m)
+        fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output,
+                                     expert_num_tokens, expected_m)