From dcc6cfb991cd76369aad96e04424f29c8fecdbd8 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Sat, 19 Jul 2025 11:39:51 +0530 Subject: [PATCH] [Kernel][Performance] Tweak MoE Batched silu_mul_fp8_quant_deep_gemm kernel (#21193) Signed-off-by: Varun Sundar Rabindranath Co-authored-by: Varun Sundar Rabindranath --- .../layers/fused_moe/batched_deep_gemm_moe.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py index 628aa5c7bb068..3ccddb52998b2 100644 --- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py @@ -55,6 +55,7 @@ def _silu_mul_fp8_quant_deep_gemm( # Meta --------------------------------------------------------------- BLOCK: tl.constexpr, + NUM_STAGES: tl.constexpr, ): G = H // GROUP_SIZE @@ -73,8 +74,7 @@ def _silu_mul_fp8_quant_deep_gemm( cols = cols.to(tl.int64) mask_h = cols < BLOCK - t = tl.zeros([], tl.int64) - while t < n_tokens: + for t in tl.range(0, n_tokens, num_stages=NUM_STAGES): base_i_offset = (e * stride_i_e + t * stride_i_t + g * GROUP_SIZE * stride_i_h) base_yq_offset = (e * stride_yq_e + t * stride_yq_t + @@ -102,8 +102,6 @@ def _silu_mul_fp8_quant_deep_gemm( tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask) tl.store(y_s_ptr + base_ys_offset, y_s) - t += 1 - def silu_mul_fp8_quant_deep_gemm( y: torch.Tensor, # (E, T, 2*H) float32 @@ -180,7 +178,8 @@ def silu_mul_fp8_quant_deep_gemm( fp8_max, is_blackwell_deep_gemm_used(), BLOCK=group_size, - num_warps=4, + NUM_STAGES=8, + num_warps=1, ) return y_q, y_s