batched_deepgemm_contiguous

Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
Zhuohan Li 2025-10-15 16:54:53 -07:00
parent 850876a183
commit 1f4472ba5f


@@ -80,6 +80,81 @@ def run_batched_deepgemm_masked_fp8(
    print(output)

def ceil_div(x: int, y: int) -> int:
    return (x + y - 1) // y


def align(x: int, y: int) -> int:
    return ceil_div(x, y) * y

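# Quick sanity check of the helpers above (illustrative values, not part of
# the original commit): align() rounds up to the next multiple of y.
assert ceil_div(129, 128) == 2
assert align(129, 128) == 256
assert align(128, 128) == 128
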
def run_batched_deepgemm_contiguous_bf16(
    expected_group_batch_size: int,
    num_groups: int,
    output_size: int,
    input_size: int,
):
    # Draw each group's actual batch size uniformly within +/-30% of the
    # expected size.
    actual_ms = [
        int(expected_group_batch_size * random.uniform(0.7, 1.3))
        for _ in range(num_groups)
    ]
    # Magic number from the DeepSeek package.
    # TODO(zhuohan): change it to the real DeepSeek number.
    MK_ALIGNMENT_FOR_CONTIGUOUS_LAYOUT = 128
    # Pad each group's row count up to the alignment boundary so every group
    # starts at a 128-row block boundary in the contiguous layout.
    aligned_ms = [
        align(actual_m, MK_ALIGNMENT_FOR_CONTIGUOUS_LAYOUT) for actual_m in actual_ms
    ]
    batch_size = sum(aligned_ms)
    weight = torch.randn(
        num_groups,
        output_size,
        input_size,
        dtype=torch.bfloat16,
        device="cuda",
    )
    x = torch.randn(
        batch_size,
        input_size,
        dtype=torch.bfloat16,
        device="cuda",
    )
    expert_ids = torch.zeros(
        batch_size,
        dtype=torch.int32,
        device="cuda",
    )
    reference_output = torch.zeros(
        batch_size,
        output_size,
        dtype=torch.bfloat16,
        device="cuda",
    )
    start = 0
    for i in range(num_groups):
        actual_end = start + actual_ms[i]
        aligned_end = start + aligned_ms[i]
        # Real rows belong to expert i; padding rows are marked -1.
        expert_ids[start:actual_end] = i
        expert_ids[actual_end:aligned_end] = -1
        # The reference also covers the padding rows: this assumes the kernel
        # computes whole aligned blocks with weight[i] rather than skipping
        # the rows marked -1.
        reference_output[start:aligned_end] = x[start:aligned_end] @ weight[i].t()
        start = aligned_end
    output = torch.zeros(
        batch_size,
        output_size,
        dtype=torch.bfloat16,
        device="cuda",
    )
    vllm_deep_gemm.m_grouped_bf16_gemm_nt_contiguous(
        x,
        weight,
        output,
        expert_ids,
    )
    torch.testing.assert_close(output, reference_output)

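# Layout illustration (hypothetical helper, not part of the original commit):
# builds the same contiguous expert_ids layout on CPU so the block structure
# is easy to inspect. Assumes the same 128-row alignment as the test above.
def _demo_contiguous_expert_ids(actual_ms: list[int], alignment: int = 128):
    aligned_ms = [align(m, alignment) for m in actual_ms]
    expert_ids = torch.full((sum(aligned_ms),), -1, dtype=torch.int32)
    start = 0
    for i, (actual_m, aligned_m) in enumerate(zip(actual_ms, aligned_ms)):
        expert_ids[start : start + actual_m] = i  # real rows of group i
        start += aligned_m  # rows past actual_m stay -1 (padding)
    return expert_ids


# e.g. _demo_contiguous_expert_ids([100, 300]) returns 512 entries:
#   rows 0..99 -> 0, 100..127 -> -1, 128..427 -> 1, 428..511 -> -1
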
def run_batched_deepgemm_masked_bf16(
    expected_group_batch_size: int,
    num_groups: int,
@@ -275,6 +350,7 @@ def run_triton_group_gemm_masked_bf16(
# run_batched_deepgemm_masked_fp8(512, 8, 1024, 512)
run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)
run_batched_deepgemm_contiguous_bf16(512, 8, 1024, 512)
# run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
# run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
# run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)