[DO NOT MERGE] Experiments related to MoE kernels

Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
Zhuohan Li 2025-10-14 17:27:11 -07:00
parent 7ef6052804
commit 2d4215b9a2
2 changed files with 174 additions and 12 deletions

View File

@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
os.environ["VLLM_USE_DEEP_GEMM"] = "1"
import math
import random
import torch
from vllm.utils import deep_gemm as vllm_deep_gemm
BLOCK_SIZE = (128, 128)
BLOCK_N, BLOCK_K = BLOCK_SIZE
def generate_bf16_and_downcast_to_fp8(shape, device="cuda"):
bf16_weight = torch.randn(shape, dtype=torch.bfloat16, device=device)
fp8_weight = bf16_weight.to(dtype=torch.float8_e4m3fn)
return fp8_weight
def run_batched_deepgemm_fp8(
expected_group_batch_size: int,
num_groups: int,
output_size: int,
input_size: int,
):
weight = generate_bf16_and_downcast_to_fp8((num_groups, output_size, input_size))
output_tiles = math.ceil(output_size / BLOCK_N)
input_tiles = math.ceil(input_size / BLOCK_K)
weight_scale = torch.randn(
num_groups,
output_tiles,
input_tiles,
dtype=torch.float32,
device="cuda",
)
group_batch_size = [
int(expected_group_batch_size * random.uniform(0.7, 1.3))
for _ in range(num_groups)
]
batch_size = sum(group_batch_size)
group_batch_size = torch.tensor(
group_batch_size,
dtype=torch.int32,
device="cuda",
)
x = generate_bf16_and_downcast_to_fp8((num_groups, batch_size, input_size))
x_scale = torch.randn(
num_groups,
batch_size,
input_tiles,
dtype=torch.float32,
device="cuda",
)
output = torch.zeros(
num_groups,
batch_size,
output_size,
dtype=torch.bfloat16,
device="cuda",
)
vllm_deep_gemm.fp8_m_grouped_gemm_nt_masked(
(x, x_scale),
(weight, weight_scale),
output,
group_batch_size,
expected_group_batch_size,
)
print(output)
def run_batched_deepgemm_bf16(
expected_group_batch_size: int,
num_groups: int,
output_size: int,
input_size: int,
):
weight = torch.randn(
num_groups,
output_size,
input_size,
dtype=torch.bfloat16,
device="cuda",
)
group_batch_size = [
int(expected_group_batch_size * random.uniform(0.7, 1.3))
for _ in range(num_groups)
]
batch_size = sum(group_batch_size)
group_batch_size = torch.tensor(
group_batch_size,
dtype=torch.int32,
device="cuda",
)
x = torch.randn(
num_groups,
batch_size,
input_size,
dtype=torch.bfloat16,
device="cuda",
)
ground_truth_output = torch.einsum("bnk, bmk -> bnm", x, weight)
output = torch.zeros(
num_groups,
batch_size,
output_size,
dtype=torch.bfloat16,
device="cuda",
)
vllm_deep_gemm.bf16_m_grouped_gemm_nt_masked(
x,
weight,
output,
group_batch_size,
expected_group_batch_size,
)
for i in range(num_groups):
torch.testing.assert_close(
output[i, : group_batch_size[i]],
ground_truth_output[i, : group_batch_size[i]],
)
print(
(
output[i, : group_batch_size[i]]
- ground_truth_output[i, : group_batch_size[i]]
)
.abs()
.max()
)
run_batched_deepgemm_fp8(512, 8, 1024, 512)
run_batched_deepgemm_bf16(512, 8, 1024, 512)

View File

@ -69,28 +69,37 @@ def _missing(*_: Any, **__: Any) -> NoReturn:
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
_grouped_impl: Callable[..., Any] | None = None
_grouped_masked_impl: Callable[..., Any] | None = None
_fp8_grouped_impl: Callable[..., Any] | None = None
_fp8_grouped_masked_impl: Callable[..., Any] | None = None
_fp8_mqa_logits_impl: Callable[..., Any] | None = None
_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
_bf16_grouped_impl: Callable[..., Any] | None = None
_bf16_grouped_masked_impl: Callable[..., Any] | None = None
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
def _lazy_init() -> None:
"""Import deep_gemm and resolve symbols on first use."""
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
global _fp8_gemm_nt_impl
global _fp8_grouped_impl
global _fp8_grouped_masked_impl
global _fp8_mqa_logits_impl
global _fp8_paged_mqa_logits_impl
global _bf16_grouped_impl
global _bf16_grouped_masked_impl
global _get_paged_mqa_logits_metadata_impl
global _get_mn_major_tma_aligned_tensor_impl
# fast path
if (
_fp8_gemm_nt_impl is not None
or _grouped_impl is not None
or _grouped_masked_impl is not None
or _fp8_grouped_impl is not None
or _fp8_grouped_masked_impl is not None
or _fp8_mqa_logits_impl is not None
or _fp8_paged_mqa_logits_impl is not None
or _bf16_grouped_impl is not None
or _bf16_grouped_masked_impl is not None
or _get_paged_mqa_logits_metadata_impl is not None
):
return
@ -108,8 +117,10 @@ def _lazy_init() -> None:
_dg = importlib.import_module("deep_gemm")
_fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
_grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
_grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
_fp8_grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
_fp8_grouped_masked_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_masked", None)
_bf16_grouped_impl = getattr(_dg, "m_grouped_bf16_gemm_nt_contiguous", None)
_bf16_grouped_masked_impl = getattr(_dg, "m_grouped_bf16_gemm_nt_masked", None)
_fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
_fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
_get_paged_mqa_logits_metadata_impl = getattr(
@ -148,22 +159,36 @@ def fp8_gemm_nt(*args, **kwargs):
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
_lazy_init()
if _grouped_impl is None:
if _fp8_grouped_impl is None:
return _missing(*args, **kwargs)
return _grouped_impl(
return _fp8_grouped_impl(
*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
)
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
_lazy_init()
if _grouped_masked_impl is None:
if _fp8_grouped_masked_impl is None:
return _missing(*args, **kwargs)
return _grouped_masked_impl(
return _fp8_grouped_masked_impl(
*args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
)
def m_grouped_bf16_gemm_nt_contiguous(*args, **kwargs):
_lazy_init()
if _bf16_grouped_impl is None:
return _missing(*args, **kwargs)
return _bf16_grouped_impl(*args, **kwargs)
def bf16_m_grouped_gemm_nt_masked(*args, **kwargs):
_lazy_init()
if _bf16_grouped_masked_impl is None:
return _missing(*args, **kwargs)
return _bf16_grouped_masked_impl(*args, **kwargs)
def fp8_mqa_logits(
q: torch.Tensor,
kv: tuple[torch.Tensor, torch.Tensor],