Mirror of https://git.datalinker.icu/vllm-project/vllm.git, synced 2026-01-14 00:34:37 +08:00
[DO NOT MERGE] Experiments related to MoE kernels

Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>

parent 7ef6052804
commit 2d4215b9a2
vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py (new file, 137 lines)

@@ -0,0 +1,137 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os

os.environ["VLLM_USE_DEEP_GEMM"] = "1"

import math
import random

import torch

from vllm.utils import deep_gemm as vllm_deep_gemm

BLOCK_SIZE = (128, 128)
BLOCK_N, BLOCK_K = BLOCK_SIZE

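# Create random bf16 data and downcast it to fp8 (e4m3) so it can be fed
# directly to the DeepGEMM fp8 kernels; the matching scales are generated
# separately by the callers below.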
def generate_bf16_and_downcast_to_fp8(shape, device="cuda"):
    bf16_weight = torch.randn(shape, dtype=torch.bfloat16, device=device)
    fp8_weight = bf16_weight.to(dtype=torch.float8_e4m3fn)
    return fp8_weight

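# Masked grouped fp8 GEMM: each of the `num_groups` experts gets its own
# (output_size, input_size) weight with per-(128, 128)-block scales, and
# `group_batch_size` marks how many rows of each group's input are valid.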
def run_batched_deepgemm_fp8(
    expected_group_batch_size: int,
    num_groups: int,
    output_size: int,
    input_size: int,
):
    weight = generate_bf16_and_downcast_to_fp8((num_groups, output_size, input_size))
    output_tiles = math.ceil(output_size / BLOCK_N)
    input_tiles = math.ceil(input_size / BLOCK_K)
    weight_scale = torch.randn(
        num_groups,
        output_tiles,
        input_tiles,
        dtype=torch.float32,
        device="cuda",
    )
    group_batch_size = [
        int(expected_group_batch_size * random.uniform(0.7, 1.3))
        for _ in range(num_groups)
    ]
    batch_size = sum(group_batch_size)
    group_batch_size = torch.tensor(
        group_batch_size,
        dtype=torch.int32,
        device="cuda",
    )
    x = generate_bf16_and_downcast_to_fp8((num_groups, batch_size, input_size))
    x_scale = torch.randn(
        num_groups,
        batch_size,
        input_tiles,
        dtype=torch.float32,
        device="cuda",
    )
    output = torch.zeros(
        num_groups,
        batch_size,
        output_size,
        dtype=torch.bfloat16,
        device="cuda",
    )

    vllm_deep_gemm.fp8_m_grouped_gemm_nt_masked(
        (x, x_scale),
        (weight, weight_scale),
        output,
        group_batch_size,
        expected_group_batch_size,
    )
    print(output)

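# Same masked grouped GEMM in bf16, checked against a torch.einsum reference
# on the valid rows of every group.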
def run_batched_deepgemm_bf16(
    expected_group_batch_size: int,
    num_groups: int,
    output_size: int,
    input_size: int,
):
    weight = torch.randn(
        num_groups,
        output_size,
        input_size,
        dtype=torch.bfloat16,
        device="cuda",
    )
    group_batch_size = [
        int(expected_group_batch_size * random.uniform(0.7, 1.3))
        for _ in range(num_groups)
    ]
    batch_size = sum(group_batch_size)
    group_batch_size = torch.tensor(
        group_batch_size,
        dtype=torch.int32,
        device="cuda",
    )
    x = torch.randn(
        num_groups,
        batch_size,
        input_size,
        dtype=torch.bfloat16,
        device="cuda",
    )
    ground_truth_output = torch.einsum("bnk, bmk -> bnm", x, weight)
    output = torch.zeros(
        num_groups,
        batch_size,
        output_size,
        dtype=torch.bfloat16,
        device="cuda",
    )
    vllm_deep_gemm.bf16_m_grouped_gemm_nt_masked(
        x,
        weight,
        output,
        group_batch_size,
        expected_group_batch_size,
    )
    for i in range(num_groups):
        torch.testing.assert_close(
            output[i, : group_batch_size[i]],
            ground_truth_output[i, : group_batch_size[i]],
        )
        print(
            (
                output[i, : group_batch_size[i]]
                - ground_truth_output[i, : group_batch_size[i]]
            )
            .abs()
            .max()
        )

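# Smoke-test both paths: 8 groups, ~512 valid rows per group, K=512 -> N=1024.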
run_batched_deepgemm_fp8(512, 8, 1024, 512)
run_batched_deepgemm_bf16(512, 8, 1024, 512)

vllm/utils/deep_gemm.py
@@ -69,28 +69,37 @@ def _missing(*_: Any, **__: Any) -> NoReturn:
 
 
 _fp8_gemm_nt_impl: Callable[..., Any] | None = None
-_grouped_impl: Callable[..., Any] | None = None
-_grouped_masked_impl: Callable[..., Any] | None = None
+_fp8_grouped_impl: Callable[..., Any] | None = None
+_fp8_grouped_masked_impl: Callable[..., Any] | None = None
 _fp8_mqa_logits_impl: Callable[..., Any] | None = None
 _fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
+_bf16_grouped_impl: Callable[..., Any] | None = None
+_bf16_grouped_masked_impl: Callable[..., Any] | None = None
 _get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
 _get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
 
 
 def _lazy_init() -> None:
     """Import deep_gemm and resolve symbols on first use."""
-    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
-    global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
+    global _fp8_gemm_nt_impl
+    global _fp8_grouped_impl
+    global _fp8_grouped_masked_impl
+    global _fp8_mqa_logits_impl
+    global _fp8_paged_mqa_logits_impl
+    global _bf16_grouped_impl
+    global _bf16_grouped_masked_impl
     global _get_paged_mqa_logits_metadata_impl
     global _get_mn_major_tma_aligned_tensor_impl
 
     # fast path
     if (
         _fp8_gemm_nt_impl is not None
-        or _grouped_impl is not None
-        or _grouped_masked_impl is not None
+        or _fp8_grouped_impl is not None
+        or _fp8_grouped_masked_impl is not None
         or _fp8_mqa_logits_impl is not None
         or _fp8_paged_mqa_logits_impl is not None
+        or _bf16_grouped_impl is not None
+        or _bf16_grouped_masked_impl is not None
         or _get_paged_mqa_logits_metadata_impl is not None
     ):
         return
@@ -108,8 +117,10 @@ def _lazy_init() -> None:
     _dg = importlib.import_module("deep_gemm")
 
     _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
-    _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
-    _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
+    _fp8_grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
+    _fp8_grouped_masked_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_masked", None)
+    _bf16_grouped_impl = getattr(_dg, "m_grouped_bf16_gemm_nt_contiguous", None)
+    _bf16_grouped_masked_impl = getattr(_dg, "m_grouped_bf16_gemm_nt_masked", None)
     _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
     _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
     _get_paged_mqa_logits_metadata_impl = getattr(
@@ -148,22 +159,36 @@ def fp8_gemm_nt(*args, **kwargs):
 
 
 def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
     _lazy_init()
-    if _grouped_impl is None:
+    if _fp8_grouped_impl is None:
         return _missing(*args, **kwargs)
-    return _grouped_impl(
+    return _fp8_grouped_impl(
         *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
     )
 
 
 def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
     _lazy_init()
-    if _grouped_masked_impl is None:
+    if _fp8_grouped_masked_impl is None:
         return _missing(*args, **kwargs)
-    return _grouped_masked_impl(
+    return _fp8_grouped_masked_impl(
         *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
     )
 
 
+def m_grouped_bf16_gemm_nt_contiguous(*args, **kwargs):
+    _lazy_init()
+    if _bf16_grouped_impl is None:
+        return _missing(*args, **kwargs)
+    return _bf16_grouped_impl(*args, **kwargs)
+
+
+def bf16_m_grouped_gemm_nt_masked(*args, **kwargs):
+    _lazy_init()
+    if _bf16_grouped_masked_impl is None:
+        return _missing(*args, **kwargs)
+    return _bf16_grouped_masked_impl(*args, **kwargs)
+
+
 def fp8_mqa_logits(
     q: torch.Tensor,
     kv: tuple[torch.Tensor, torch.Tensor],