mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-07 09:49:37 +08:00
batched_deepgemm_contiguous
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
parent
850876a183
commit
1f4472ba5f
@ -80,6 +80,81 @@ def run_batched_deepgemm_masked_fp8(
|
||||
print(output)
|
||||
|
||||
|
||||
def ceil_div(x: int, y: int) -> int:
|
||||
return (x + y - 1) // y
|
||||
|
||||
|
||||
def align(x: int, y: int) -> int:
|
||||
return ceil_div(x, y) * y
|
||||
|
||||
|
||||
def run_batched_deepgemm_contiguous_bf16(
|
||||
expected_group_batch_size: int,
|
||||
num_groups: int,
|
||||
output_size: int,
|
||||
input_size: int,
|
||||
):
|
||||
actual_ms = [
|
||||
int(expected_group_batch_size * random.uniform(0.7, 1.3))
|
||||
for _ in range(num_groups)
|
||||
]
|
||||
# Magic number in deepseek pacakge
|
||||
# TODO(zhuohan): change it to real deepseek number
|
||||
MK_ALIGNMENT_FOR_CONTIGUOUS_LAYOUT = 128
|
||||
aligned_ms = [
|
||||
align(actual_m, MK_ALIGNMENT_FOR_CONTIGUOUS_LAYOUT) for actual_m in actual_ms
|
||||
]
|
||||
batch_size = sum(aligned_ms)
|
||||
|
||||
weight = torch.randn(
|
||||
num_groups,
|
||||
output_size,
|
||||
input_size,
|
||||
dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
)
|
||||
x = torch.randn(
|
||||
batch_size,
|
||||
input_size,
|
||||
dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
)
|
||||
expert_ids = torch.zeros(
|
||||
batch_size,
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
reference_output = torch.zeros(
|
||||
batch_size,
|
||||
output_size,
|
||||
dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
)
|
||||
start = 0
|
||||
for i in range(num_groups):
|
||||
actual_end = start + actual_ms[i]
|
||||
aligned_end = start + aligned_ms[i]
|
||||
expert_ids[start:actual_end] = i
|
||||
expert_ids[actual_end:aligned_end] = -1
|
||||
reference_output[start:aligned_end] = x[start:aligned_end] @ weight[i].t()
|
||||
start = aligned_end
|
||||
|
||||
output = torch.zeros(
|
||||
batch_size,
|
||||
output_size,
|
||||
dtype=torch.bfloat16,
|
||||
device="cuda",
|
||||
)
|
||||
|
||||
vllm_deep_gemm.m_grouped_bf16_gemm_nt_contiguous(
|
||||
x,
|
||||
weight,
|
||||
output,
|
||||
expert_ids,
|
||||
)
|
||||
torch.testing.assert_close(output, reference_output)
|
||||
|
||||
|
||||
def run_batched_deepgemm_masked_bf16(
|
||||
expected_group_batch_size: int,
|
||||
num_groups: int,
|
||||
@ -275,6 +350,7 @@ def run_triton_group_gemm_masked_bf16(
|
||||
|
||||
|
||||
# run_batched_deepgemm_masked_fp8(512, 8, 1024, 512)
|
||||
run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
|
||||
run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
|
||||
run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)
|
||||
run_batched_deepgemm_contiguous_bf16(512, 8, 1024, 512)
|
||||
# run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
|
||||
# run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
|
||||
# run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user