mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-06-08 00:49:08 +08:00
batched_deepgemm_contiguous
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
parent
850876a183
commit
1f4472ba5f
@ -80,6 +80,81 @@ def run_batched_deepgemm_masked_fp8(
|
|||||||
print(output)
|
print(output)
|
||||||
|
|
||||||
|
|
||||||
|
def ceil_div(x: int, y: int) -> int:
    """Return ``x`` divided by ``y``, rounded up to an integer.

    The rounding-up reading holds for positive ``y``; the result is the
    floor of ``(x + y - 1) / y``. Raises ``ZeroDivisionError`` when
    ``y`` is zero.
    """
    quotient, _ = divmod(x + y - 1, y)
    return quotient
|
||||||
|
|
||||||
|
|
||||||
|
def align(x: int, y: int) -> int:
    """Round ``x`` up to a multiple of ``y`` (for positive ``y``)."""
    num_chunks = ceil_div(x, y)
    return num_chunks * y
|
||||||
|
|
||||||
|
|
||||||
|
def run_batched_deepgemm_contiguous_bf16(
    expected_group_batch_size: int,
    num_groups: int,
    output_size: int,
    input_size: int,
):
    """Check the contiguous-layout grouped bf16 GEMM against a torch reference.

    Builds ``num_groups`` row groups laid out back to back in one input
    tensor, runs ``vllm_deep_gemm.m_grouped_bf16_gemm_nt_contiguous`` on
    them, and compares the result with a per-group ``x @ weight.T``.

    Args:
        expected_group_batch_size: Nominal rows per group; each group's
            actual row count is this value scaled by a random factor in
            [0.7, 1.3] and truncated to an int.
        num_groups: Number of groups (one weight matrix per group).
        output_size: Second dimension of each group's weight matrix.
        input_size: Last dimension of the input and of each weight matrix.

    Raises:
        AssertionError: From ``torch.testing.assert_close`` when the
            kernel output and the reference differ.

    All tensors are allocated on the ``"cuda"`` device.
    """
    # Magic number in the deepseek package.
    # TODO(zhuohan): change it to real deepseek number
    MK_ALIGNMENT_FOR_CONTIGUOUS_LAYOUT = 128

    # Draw the per-group row counts one at a time, in group order.
    actual_ms = []
    for _ in range(num_groups):
        scale = random.uniform(0.7, 1.3)
        actual_ms.append(int(expected_group_batch_size * scale))

    # Pad every group up to the alignment so groups start on aligned rows.
    aligned_ms = [
        align(m, MK_ALIGNMENT_FOR_CONTIGUOUS_LAYOUT) for m in actual_ms
    ]
    batch_size = sum(aligned_ms)

    bf16_on_gpu = {"dtype": torch.bfloat16, "device": "cuda"}
    weight = torch.randn(num_groups, output_size, input_size, **bf16_on_gpu)
    x = torch.randn(batch_size, input_size, **bf16_on_gpu)
    expert_ids = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
    reference_output = torch.zeros(batch_size, output_size, **bf16_on_gpu)

    group_start = 0
    for group, (real_m, padded_m) in enumerate(zip(actual_ms, aligned_ms)):
        real_stop = group_start + real_m
        padded_stop = group_start + padded_m
        # Real rows carry their group index; padding rows are marked -1.
        expert_ids[group_start:real_stop] = group
        expert_ids[real_stop:padded_stop] = -1
        # NOTE(review): the reference is computed for the padding rows too.
        # Whether the kernel also writes those rows is not visible here --
        # confirm, otherwise the comparison below can fail on padding.
        padded_rows = slice(group_start, padded_stop)
        reference_output[padded_rows] = torch.matmul(
            x[padded_rows], weight[group].t()
        )
        group_start = padded_stop

    output = torch.zeros(batch_size, output_size, **bf16_on_gpu)

    vllm_deep_gemm.m_grouped_bf16_gemm_nt_contiguous(
        x, weight, output, expert_ids
    )
    torch.testing.assert_close(output, reference_output)
|
||||||
|
|
||||||
|
|
||||||
def run_batched_deepgemm_masked_bf16(
|
def run_batched_deepgemm_masked_bf16(
|
||||||
expected_group_batch_size: int,
|
expected_group_batch_size: int,
|
||||||
num_groups: int,
|
num_groups: int,
|
||||||
@ -275,6 +350,7 @@ def run_triton_group_gemm_masked_bf16(
|
|||||||
|
|
||||||
|
|
||||||
# run_batched_deepgemm_masked_fp8(512, 8, 1024, 512)
|
# run_batched_deepgemm_masked_fp8(512, 8, 1024, 512)
|
||||||
run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
|
run_batched_deepgemm_contiguous_bf16(512, 8, 1024, 512)
|
||||||
run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
|
# run_batched_deepgemm_masked_bf16(512, 8, 1024, 512)
|
||||||
run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)
|
# run_triton_group_gemm_contiguous_bf16(512, 8, 1024, 512, 4)
|
||||||
|
# run_triton_group_gemm_masked_bf16(512, 8, 1024, 512)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user