mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-27 15:41:54 +08:00
[Bug] Fix DeepGemm for EP low latency case (#20833)
Signed-off-by: yewentao256 <zhyanwentao@126.com>
This commit is contained in:
parent
f56d2996ca
commit
0d4891cd03
@ -11,7 +11,8 @@ from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
|
||||
TopKWeightAndReduceDelegate)
|
||||
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
|
||||
from vllm.triton_utils import tl, triton
|
||||
from vllm.utils.deep_gemm import fp8_m_grouped_gemm_nt_masked
|
||||
from vllm.utils.deep_gemm import (fp8_m_grouped_gemm_nt_masked,
|
||||
is_blackwell_deep_gemm_used)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -50,6 +51,7 @@ def _silu_mul_fp8_quant_deep_gemm(
|
||||
eps: tl.constexpr,
|
||||
fp8_min: tl.constexpr,
|
||||
fp8_max: tl.constexpr,
|
||||
use_ue8m0: tl.constexpr,
|
||||
|
||||
# Meta ---------------------------------------------------------------
|
||||
BLOCK: tl.constexpr,
|
||||
@ -92,7 +94,9 @@ def _silu_mul_fp8_quant_deep_gemm(
|
||||
y = x * y2
|
||||
|
||||
_absmax = tl.maximum(tl.max(tl.abs(y)), eps)
|
||||
y_s = _absmax / fp8_max
|
||||
scale_raw = _absmax / fp8_max
|
||||
y_s = tl.math.exp2(tl.ceil(
|
||||
tl.log2(scale_raw))) if use_ue8m0 else scale_raw
|
||||
y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
|
||||
|
||||
tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
|
||||
@ -174,6 +178,7 @@ def silu_mul_fp8_quant_deep_gemm(
|
||||
eps,
|
||||
fp8_min,
|
||||
fp8_max,
|
||||
is_blackwell_deep_gemm_used(),
|
||||
BLOCK=group_size,
|
||||
num_warps=4,
|
||||
)
|
||||
@ -290,14 +295,10 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
# may lead to better performance.
|
||||
expected_m = max_num_tokens
|
||||
fp8_m_grouped_gemm_nt_masked((a1q, a1q_scale), (w1, w1_scale),
|
||||
out=workspace1,
|
||||
masked_m=expert_num_tokens,
|
||||
expected_m=expected_m)
|
||||
workspace1, expert_num_tokens, expected_m)
|
||||
|
||||
a2q, a2q_scale = silu_mul_fp8_quant_deep_gemm(workspace1,
|
||||
expert_num_tokens)
|
||||
|
||||
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale),
|
||||
out=output,
|
||||
masked_m=expert_num_tokens,
|
||||
expected_m=expected_m)
|
||||
fp8_m_grouped_gemm_nt_masked((a2q, a2q_scale), (w2, w2_scale), output,
|
||||
expert_num_tokens, expected_m)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user