From 9cfebf51ba24721bb826b4ffcce5508ddf41edf6 Mon Sep 17 00:00:00 2001 From: Bill Nell Date: Wed, 21 May 2025 23:20:54 +0000 Subject: [PATCH] basic working test Signed-off-by: Bill Nell --- tests/kernels/moe/test_batched_moe.py | 9 +- .../layers/fused_moe/fused_batched_moe.py | 96 +++++++++---------- 2 files changed, 54 insertions(+), 51 deletions(-) diff --git a/tests/kernels/moe/test_batched_moe.py b/tests/kernels/moe/test_batched_moe.py index f225d0ecc313e..87ace11486e4e 100644 --- a/tests/kernels/moe/test_batched_moe.py +++ b/tests/kernels/moe/test_batched_moe.py @@ -198,9 +198,11 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, #B_scale = torch.ones((N, K), dtype=torch.float32, device=tensors.A.device) A_scale = torch.ones(1, dtype=torch.float32, device=tensors.A.device) B_scale = torch.ones(1, dtype=torch.float32, device=tensors.B.device) + quant_block_shape = [1, 1] else: A_scale = None B_scale = None + quant_block_shape = None invoke_moe_batched_triton_kernel( tensors.A, @@ -220,7 +222,9 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, "BLOCK_SIZE_M": block_shape[0], "BLOCK_SIZE_N": block_shape[1], "BLOCK_SIZE_K": block_shape[2], - }) + }, + block_shape=quant_block_shape, + ) ref_output = ref_output.to(dtype=out_dtype) ref_output = ref_impl(tensors.A.to(dtype=out_dtype), @@ -246,5 +250,4 @@ def test_batched_mm(num_experts: int, max_tokens_per_expert: int, K: int, }[test_output.dtype] torch.testing.assert_close(ref_output, ref_output2, atol=atol, rtol=rtol) - if not use_fp8_w8a8: - torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol) + torch.testing.assert_close(test_output, ref_output2, atol=atol, rtol=rtol) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 8695bd357d315..5047e1afd7a51 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -15,38 +15,38 @@ from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, @triton.jit def moe_mmk( - a_ptrs, - b_ptrs, - K, - expert_id, - a_scale_ptr, - b_scale_ptr, - # The stride variables represent how much to increase the ptr by when - # moving by 1 element in a particular dimension. E.g. `stride_am` is - # how much to increase `a_ptr` by to get the element one row down - # (A has M rows). - stride_ak, - stride_bk, - stride_asm, - stride_ask, - stride_bse, - stride_bsk, - stride_bsn, - # Offsets and masks - offs_m, - offs_n, - mask_m, - # Block size for block-wise quantization - group_n: tl.constexpr, - group_k: tl.constexpr, - # Meta-parameters - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - compute_type: tl.constexpr, - use_w8a8: tl.constexpr, - use_w8a16: tl.constexpr): - + a_ptrs, + b_ptrs, + K, + expert_id, + a_scale_ptr, + b_scale_ptr, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_ak, + stride_bk, + stride_asm, + stride_ask, + stride_bse, + stride_bsk, + stride_bsn, + # Offsets and masks + offs_m, + offs_n, + mask_m, + # Block size for block-wise quantization + group_n: tl.constexpr, + group_k: tl.constexpr, + # Meta-parameters + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + compute_type: tl.constexpr, + use_w8a8: tl.constexpr, + use_w8a16: tl.constexpr +): offs_k = tl.arange(0, BLOCK_K) if use_w8a16: @@ -310,22 +310,22 @@ def batched_triton_kernel( def invoke_moe_batched_triton_kernel( - A: torch.Tensor, # [E, max_tokens, K] - B: torch.Tensor, # [E, K, N] - C: torch.Tensor, # [E, max_tokens, N] - expert_num_tokens: torch.Tensor, # [E] - compute_type: tl.dtype, - # Quantization data - A_scale: Optional[torch.Tensor], - B_scale: Optional[torch.Tensor], - B_zp: torch.Tensor, - # Quantization schemes - use_fp8_w8a8: bool, - use_int8_w8a16: bool, - use_int4_w4a16: bool, - config: dict[str, int], - block_shape: Optional[list[int]] = None): - + A: torch.Tensor, # [E, max_tokens, K] + B: torch.Tensor, # [E, K, N] + C: torch.Tensor, # [E, max_tokens, N] + expert_num_tokens: torch.Tensor, # [E] + compute_type: tl.dtype, + # Quantization data + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + B_zp: torch.Tensor, + # Quantization schemes + use_fp8_w8a8: bool, + use_int8_w8a16: bool, + use_int4_w4a16: bool, + config: dict[str, int], + block_shape: Optional[list[int]] = None +): assert not use_int4_w4a16 max_num_tokens = A.size(1) K = A.size(2)