Signed-off-by: Tyler Michael Smith <tysmith@redhat.com>
This commit is contained in:
Tyler Michael Smith 2025-06-20 15:36:59 +00:00
parent 7a821f0e7f
commit 39d5d33f8f

View File

@ -17,14 +17,12 @@ has_deep_gemm = importlib.util.find_spec("deep_gemm") is not None
@triton.jit @triton.jit
def _silu_mul_fp8_quant_deep_gemm( def _silu_mul_fp8_quant_deep_gemm(
# Pointers ------------------------------------------------------------ # Pointers ------------------------------------------------------------
input_ptr, # *FP32 activations (E, T, 2*H) input_ptr, # 16-bit activations (E, T, 2*H)
y_q_ptr, # *FP8 quantised activations (E, T, H) y_q_ptr, # fp88 quantized activations (E, T, H)
y_s_ptr, # *FP32 scales (E, T, G) y_s_ptr, # 16-bit scales (E, T, G)
counts_ptr, # *INT32 number of tokens per expert (E) counts_ptr, # int32 num tokens per expert (E)
# Sizes --------------------------------------------------------------- # Sizes ---------------------------------------------------------------
E: tl.constexpr, # num_experts
T: tl.constexpr, # max_num_tokens
H: tl.constexpr, # hidden dimension (per output) H: tl.constexpr, # hidden dimension (per output)
GROUP_SIZE: tl.constexpr, # elements per group (usually 128) GROUP_SIZE: tl.constexpr, # elements per group (usually 128)
@ -159,8 +157,6 @@ def silu_mul_fp8_quant_deep_gemm(
y_q, y_q,
y_s, y_s,
tokens_per_expert, tokens_per_expert,
E,
T,
H, H,
group_size, group_size,
stride_i_e, stride_i_e,