From 39d5d33f8f15e823afa9abcb74265c1c474f4563 Mon Sep 17 00:00:00 2001
From: Tyler Michael Smith
Date: Fri, 20 Jun 2025 15:36:59 +0000
Subject: [PATCH] tweaks

Signed-off-by: Tyler Michael Smith
---
 .../layers/fused_moe/batched_deep_gemm_moe.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
index a92125a6faba9..fae8d3745fef6 100644
--- a/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
+++ b/vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py
@@ -17,14 +17,12 @@ has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
 @triton.jit
 def _silu_mul_fp8_quant_deep_gemm(
     # Pointers ------------------------------------------------------------
-    input_ptr,  # *FP32 activations (E, T, 2*H)
-    y_q_ptr,  # *FP8 quantised activations (E, T, H)
-    y_s_ptr,  # *FP32 scales (E, T, G)
-    counts_ptr,  # *INT32 number of tokens per expert (E)
+    input_ptr,  # 16-bit activations (E, T, 2*H)
+    y_q_ptr,  # fp8 quantized activations (E, T, H)
+    y_s_ptr,  # 16-bit scales (E, T, G)
+    counts_ptr,  # int32 num tokens per expert (E)
 
     # Sizes ---------------------------------------------------------------
-    E: tl.constexpr,  # num_experts
-    T: tl.constexpr,  # max_num_tokens
     H: tl.constexpr,  # hidden dimension (per output)
     GROUP_SIZE: tl.constexpr,  # elements per group (usually 128)
@@ -159,8 +157,6 @@ def silu_mul_fp8_quant_deep_gemm(
         y_q,
         y_s,
         tokens_per_expert,
-        E,
-        T,
         H,
         group_size,
         stride_i_e,
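
Context for the change: the deleted E and T arguments are recoverable on the
host side from the input tensor's shape, so they no longer need to be passed
into the Triton kernel as tl.constexpr parameters. Below is a minimal
host-side sketch of that idea; the names y and tokens_per_expert, the output
allocations, and the grid shape are illustrative assumptions, not the file's
exact code.

import torch

def silu_mul_fp8_quant_deep_gemm_sketch(
    y: torch.Tensor,                  # (E, T, 2*H) 16-bit activations
    tokens_per_expert: torch.Tensor,  # (E,) int32 valid-token counts
    group_size: int = 128,
):
    # E and T fall out of the input shape, so the kernel no longer
    # needs them as compile-time (tl.constexpr) parameters.
    E, T, H2 = y.shape
    H = H2 // 2
    G = H // group_size  # quantization groups per output row

    # Outputs: fp8 quantized activations plus one scale per group.
    y_q = torch.empty((E, T, H), dtype=torch.float8_e4m3fn, device=y.device)
    y_s = torch.empty((E, T, G), dtype=y.dtype, device=y.device)

    # The launch grid is still sized from E and T on the host; only H and
    # GROUP_SIZE remain as constexpr kernel arguments after this patch:
    # _silu_mul_fp8_quant_deep_gemm[(E * T,)](
    #     y, y_q, y_s, tokens_per_expert, H, group_size, ...)
    return y_q, y_s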