diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index 628aa5c7bb06..3ccddb52998b 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -55,6 +55,7 @@ def _silu_mul_fp8_quant_deep_gemm(
 
     # Meta ---------------------------------------------------------------
     BLOCK: tl.constexpr,
+    NUM_STAGES: tl.constexpr,
 ):
     G = H // GROUP_SIZE
 
@@ -73,8 +74,7 @@ def _silu_mul_fp8_quant_deep_gemm(
     cols = cols.to(tl.int64)
     mask_h = cols < BLOCK
 
-    t = tl.zeros([], tl.int64)
-    while t < n_tokens:
+    for t in tl.range(0, n_tokens, num_stages=NUM_STAGES):
         base_i_offset = (e * stride_i_e + t * stride_i_t +
                          g * GROUP_SIZE * stride_i_h)
         base_yq_offset = (e * stride_yq_e + t * stride_yq_t +
@@ -102,8 +102,6 @@ def _silu_mul_fp8_quant_deep_gemm(
         tl.store(y_q_ptr + base_yq_offset + cols * stride_yq_h, y_q, mask=mask)
         tl.store(y_s_ptr + base_ys_offset, y_s)
 
-        t += 1
-
 
 def silu_mul_fp8_quant_deep_gemm(
     y: torch.Tensor,  # (E, T, 2*H) float32
@@ -180,7 +178,8 @@ def silu_mul_fp8_quant_deep_gemm(
         fp8_max,
         is_blackwell_deep_gemm_used(),
         BLOCK=group_size,
-        num_warps=4,
+        NUM_STAGES=8,
+        num_warps=1,
     )
 
     return y_q, y_s
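
For context on the change above: `tl.range(..., num_stages=N)` asks the Triton compiler to software-pipeline the token loop, so loads for later iterations are issued while earlier iterations are still computing and storing; the hand-written `while` loop gave the compiler no such hint. The drop from `num_warps=4` to `num_warps=1` likely reflects that the pipelined loop already hides memory latency, so fewer warps are needed per program. Below is a minimal, self-contained sketch of the same pattern on a toy row-wise copy kernel; everything in it (`_copy_rows_kernel`, the grid sizing, `NUM_STAGES=4`) is an illustrative assumption, not part of this diff or of vLLM.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _copy_rows_kernel(
    x_ptr,
    y_ptr,
    n_rows,
    n_cols,
    BLOCK: tl.constexpr,
    NUM_STAGES: tl.constexpr,
):
    # Each program owns one column tile and walks all rows, mirroring the
    # per-(expert, group) token loop in the kernel above.
    col = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = col < n_cols
    # num_stages lets the compiler overlap the load for row t+1 with the
    # store for row t, instead of executing iterations strictly in order.
    for t in tl.range(0, n_rows, num_stages=NUM_STAGES):
        v = tl.load(x_ptr + t * n_cols + col, mask=mask)
        tl.store(y_ptr + t * n_cols + col, v, mask=mask)


x = torch.randn(64, 256, device="cuda")
y = torch.empty_like(x)
BLOCK = 128
grid = (triton.cdiv(x.shape[1], BLOCK),)
_copy_rows_kernel[grid](x, y, x.shape[0], x.shape[1],
                        BLOCK=BLOCK, NUM_STAGES=4, num_warps=1)
torch.testing.assert_close(x, y)
```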