From 1f4472ba5fe9ceb8868f87a21f1474ea987b3bdb Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Wed, 15 Oct 2025 16:54:53 -0700
Subject: [PATCH] batched_deepgemm_contiguous

Signed-off-by: Zhuohan Li
---
 .../layers/moe/grouped_gemm_no_abstraction.py | 82 ++++++++++++++++++-
 1 file changed, 79 insertions(+), 3 deletions(-)

diff --git a/vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py b/vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py
index c23d70bda23e0..128201f6f34da 100644
--- a/vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py
+++ b/vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py
@@ -80,6 +80,81 @@ def run_batched_deepgemm_masked_fp8(
     print(output)
 
 
+def ceil_div(x: int, y: int) -> int:
+    return (x + y - 1) // y
+
+
+def align(x: int, y: int) -> int:
+    return ceil_div(x, y) * y
+
+
+def run_batched_deepgemm_contiguous_bf16(
+    expected_group_batch_size: int,
+    num_groups: int,
+    output_size: int,
+    input_size: int,
+):
+    actual_ms = [
+        int(expected_group_batch_size * random.uniform(0.7, 1.3))
+        for _ in range(num_groups)
+    ]
+    # Magic number in the DeepSeek package
+    # TODO(zhuohan): change it to the real DeepSeek number
+    MK_ALIGNMENT_FOR_CONTIGUOUS_LAYOUT = 128
+    aligned_ms = [
+        align(actual_m, MK_ALIGNMENT_FOR_CONTIGUOUS_LAYOUT) for actual_m in actual_ms
+    ]
+    batch_size = sum(aligned_ms)
+
+    weight = torch.randn(
+        num_groups,
+        output_size,
+        input_size,
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+    x = torch.randn(
+        batch_size,
+        input_size,
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+    expert_ids = torch.zeros(
+        batch_size,
+        dtype=torch.int32,
+        device="cuda",
+    )
+    reference_output = torch.zeros(
+        batch_size,
+        output_size,
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+    start = 0
+    for i in range(num_groups):
+        actual_end = start + actual_ms[i]
+        aligned_end = start + aligned_ms[i]
+        expert_ids[start:actual_end] = i
+        expert_ids[actual_end:aligned_end] = -1
+        reference_output[start:aligned_end] = x[start:aligned_end] @ weight[i].t()
+        start = aligned_end
+
+    output = torch.zeros(
+        batch_size,
+        output_size,
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+
+    vllm_deep_gemm.m_grouped_bf16_gemm_nt_contiguous(
+        x,
+        weight,
+        output,
+        expert_ids,
+    )
+    torch.testing.assert_close(output, reference_output)
+
+
 def run_batched_deepgemm_masked_bf16(
     expected_group_batch_size: int,
     num_groups: int,
@@ -275,6 +350,7 @@ def run_triton_group_gemm_masked_bf16(
 
 
 # run_batched_deepgemm_masked_fp8(512, 8, 1024, 512)
-run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
-run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
-run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)
+run_batched_deepgemm_contiguous_bf16(512, 8, 1024, 512)
+# run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
+# run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
+# run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)
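
Reviewer note: below is a minimal, CPU-only sketch of the contiguous layout that run_batched_deepgemm_contiguous_bf16 builds in the patch above. It reproduces only the alignment/padding and expert_ids bookkeeping (no DeepGEMM call, no GPU needed). The helper name build_contiguous_layout and the default alignment of 128 are illustrative assumptions mirroring MK_ALIGNMENT_FOR_CONTIGUOUS_LAYOUT in the patch, not part of the DeepGEMM API.

import torch


def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y


def align(x: int, y: int) -> int:
    return ceil_div(x, y) * y


def build_contiguous_layout(actual_ms: list[int], alignment: int = 128):
    # Each expert's rows are laid out contiguously and padded up to a multiple
    # of `alignment`; padded rows are tagged with expert id -1, matching the
    # convention used in the test above.
    aligned_ms = [align(m, alignment) for m in actual_ms]
    expert_ids = torch.full((sum(aligned_ms),), -1, dtype=torch.int32)
    start = 0
    for expert, (actual_m, aligned_m) in enumerate(zip(actual_ms, aligned_ms)):
        expert_ids[start : start + actual_m] = expert
        start += aligned_m
    return aligned_ms, expert_ids


aligned_ms, expert_ids = build_contiguous_layout([100, 300], alignment=128)
print(aligned_ms)           # [128, 384]
print(expert_ids.tolist())  # 100 zeros, 28 x -1, then 300 ones, 84 x -1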