mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-28 21:17:05 +08:00
[DO NOT MERGE] Experiments related to MoE kernels
Signed-off-by: Zhuohan Li <zhuohan123@gmail.com>
This commit is contained in:
parent
7ef6052804
commit
2d4215b9a2
137
vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py
Normal file
137
vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py
Normal file
@ -0,0 +1,137 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["VLLM_USE_DEEP_GEMM"] = "1"
|
||||||
|
|
||||||
|
import math
|
||||||
|
import random
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from vllm.utils import deep_gemm as vllm_deep_gemm
|
||||||
|
|
||||||
|
# Scale-tile shape (N, K): the FP8 experiment allocates one weight scale per
# BLOCK_N x BLOCK_K tile of each group's (output_size, input_size) weight.
BLOCK_SIZE = (128, 128)
BLOCK_N = BLOCK_SIZE[0]
BLOCK_K = BLOCK_SIZE[1]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_bf16_and_downcast_to_fp8(shape, device="cuda"):
|
||||||
|
bf16_weight = torch.randn(shape, dtype=torch.bfloat16, device=device)
|
||||||
|
fp8_weight = bf16_weight.to(dtype=torch.float8_e4m3fn)
|
||||||
|
return fp8_weight
|
||||||
|
|
||||||
|
|
||||||
|
def run_batched_deepgemm_fp8(
    expected_group_batch_size: int,
    num_groups: int,
    output_size: int,
    input_size: int,
):
    """Run the masked grouped FP8 GEMM once on random data and print the output.

    Random FP8 weights and activations are generated together with float32
    scale tensors, per-group row counts are jittered around
    ``expected_group_batch_size``, and everything is handed to
    ``vllm_deep_gemm.fp8_m_grouped_gemm_nt_masked``. No reference comparison
    is performed; the output tensor is only printed.

    Args:
        expected_group_batch_size: Nominal number of rows per group; each
            group's actual count is this value scaled by a factor drawn
            uniformly from [0.7, 1.3] and truncated to an int.
        num_groups: Number of groups (leading dimension of every tensor).
        output_size: Output feature dimension of each group's weight.
        input_size: Input feature dimension of each group's weight.
    """
    device = "cuda"
    n_tiles = math.ceil(output_size / BLOCK_N)
    k_tiles = math.ceil(input_size / BLOCK_K)

    # Tensor creation order is kept (weight, weight scale, activations,
    # activation scale) so the RNG streams are consumed in the same sequence.
    w_fp8 = generate_bf16_and_downcast_to_fp8((num_groups, output_size, input_size))
    w_scale = torch.randn(
        num_groups, n_tiles, k_tiles, dtype=torch.float32, device=device
    )

    rows_per_group = []
    for _ in range(num_groups):
        jitter = random.uniform(0.7, 1.3)
        rows_per_group.append(int(expected_group_batch_size * jitter))
    # Every group's slab is allocated with the *sum* of all row counts.
    # NOTE(review): presumably only an upper bound on the largest group is
    # required by the masked kernel -- confirm before shrinking this.
    padded_rows = sum(rows_per_group)
    masked_rows = torch.tensor(rows_per_group, dtype=torch.int32, device=device)

    act_fp8 = generate_bf16_and_downcast_to_fp8((num_groups, padded_rows, input_size))
    act_scale = torch.randn(
        num_groups, padded_rows, k_tiles, dtype=torch.float32, device=device
    )
    result = torch.zeros(
        num_groups, padded_rows, output_size, dtype=torch.bfloat16, device=device
    )

    vllm_deep_gemm.fp8_m_grouped_gemm_nt_masked(
        (act_fp8, act_scale),
        (w_fp8, w_scale),
        result,
        masked_rows,
        expected_group_batch_size,
    )
    print(result)
|
||||||
|
|
||||||
|
|
||||||
|
def run_batched_deepgemm_bf16(
    expected_group_batch_size: int,
    num_groups: int,
    output_size: int,
    input_size: int,
):
    """Run the masked grouped BF16 GEMM and check it against a torch einsum.

    Random bfloat16 weights and activations are generated, per-group row
    counts are jittered around ``expected_group_batch_size``, and
    ``vllm_deep_gemm.bf16_m_grouped_gemm_nt_masked`` writes into a zeroed
    output. For every group, the first ``group_batch_size[i]`` rows are
    compared with the einsum reference via ``torch.testing.assert_close`` and
    the maximum absolute difference is printed.

    Args:
        expected_group_batch_size: Nominal number of rows per group; each
            group's actual count is this value scaled by a factor drawn
            uniformly from [0.7, 1.3] and truncated to an int.
        num_groups: Number of groups (leading dimension of every tensor).
        output_size: Output feature dimension of each group's weight.
        input_size: Input feature dimension of each group's weight.

    Raises:
        AssertionError: If a group's checked rows differ from the reference
            beyond ``torch.testing.assert_close`` default tolerances.
    """
    device = "cuda"
    w = torch.randn(
        num_groups, output_size, input_size, dtype=torch.bfloat16, device=device
    )

    rows_per_group = []
    for _ in range(num_groups):
        jitter = random.uniform(0.7, 1.3)
        rows_per_group.append(int(expected_group_batch_size * jitter))
    # Each group's slab is allocated with the sum of all per-group row counts.
    padded_rows = sum(rows_per_group)
    masked_rows = torch.tensor(rows_per_group, dtype=torch.int32, device=device)

    acts = torch.randn(
        num_groups, padded_rows, input_size, dtype=torch.bfloat16, device=device
    )
    # Reference: per-group activations @ weight^T over the full padded slab.
    reference = torch.einsum("bnk, bmk -> bnm", acts, w)
    result = torch.zeros(
        num_groups, padded_rows, output_size, dtype=torch.bfloat16, device=device
    )

    vllm_deep_gemm.bf16_m_grouped_gemm_nt_masked(
        acts,
        w,
        result,
        masked_rows,
        expected_group_batch_size,
    )

    for group in range(num_groups):
        # Only the first masked_rows[group] rows of each group are checked.
        valid = masked_rows[group]
        got = result[group, :valid]
        want = reference[group, :valid]
        torch.testing.assert_close(got, want)
        print((got - want).abs().max())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Launch the GPU experiments only when this file is executed as a script.
    # The module lives inside the package tree, so an unguarded call would run
    # the kernels as a side effect of merely importing it.
    # Arguments: expected_group_batch_size=512, num_groups=8,
    # output_size=1024, input_size=512.
    run_batched_deepgemm_fp8(512, 8, 1024, 512)
    run_batched_deepgemm_bf16(512, 8, 1024, 512)
|
||||||
@ -69,28 +69,37 @@ def _missing(*_: Any, **__: Any) -> NoReturn:
|
|||||||
|
|
||||||
|
|
||||||
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
|
_fp8_gemm_nt_impl: Callable[..., Any] | None = None
|
||||||
_grouped_impl: Callable[..., Any] | None = None
|
_fp8_grouped_impl: Callable[..., Any] | None = None
|
||||||
_grouped_masked_impl: Callable[..., Any] | None = None
|
_fp8_grouped_masked_impl: Callable[..., Any] | None = None
|
||||||
_fp8_mqa_logits_impl: Callable[..., Any] | None = None
|
_fp8_mqa_logits_impl: Callable[..., Any] | None = None
|
||||||
_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
|
_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
|
||||||
|
_bf16_grouped_impl: Callable[..., Any] | None = None
|
||||||
|
_bf16_grouped_masked_impl: Callable[..., Any] | None = None
|
||||||
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
|
_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
|
||||||
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
|
_get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
def _lazy_init() -> None:
|
def _lazy_init() -> None:
|
||||||
"""Import deep_gemm and resolve symbols on first use."""
|
"""Import deep_gemm and resolve symbols on first use."""
|
||||||
global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
|
global _fp8_gemm_nt_impl
|
||||||
global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
|
global _fp8_grouped_impl
|
||||||
|
global _fp8_grouped_masked_impl
|
||||||
|
global _fp8_mqa_logits_impl
|
||||||
|
global _fp8_paged_mqa_logits_impl
|
||||||
|
global _bf16_grouped_impl
|
||||||
|
global _bf16_grouped_masked_impl
|
||||||
global _get_paged_mqa_logits_metadata_impl
|
global _get_paged_mqa_logits_metadata_impl
|
||||||
global _get_mn_major_tma_aligned_tensor_impl
|
global _get_mn_major_tma_aligned_tensor_impl
|
||||||
|
|
||||||
# fast path
|
# fast path
|
||||||
if (
|
if (
|
||||||
_fp8_gemm_nt_impl is not None
|
_fp8_gemm_nt_impl is not None
|
||||||
or _grouped_impl is not None
|
or _fp8_grouped_impl is not None
|
||||||
or _grouped_masked_impl is not None
|
or _fp8_grouped_masked_impl is not None
|
||||||
or _fp8_mqa_logits_impl is not None
|
or _fp8_mqa_logits_impl is not None
|
||||||
or _fp8_paged_mqa_logits_impl is not None
|
or _fp8_paged_mqa_logits_impl is not None
|
||||||
|
or _bf16_grouped_impl is not None
|
||||||
|
or _bf16_grouped_masked_impl is not None
|
||||||
or _get_paged_mqa_logits_metadata_impl is not None
|
or _get_paged_mqa_logits_metadata_impl is not None
|
||||||
):
|
):
|
||||||
return
|
return
|
||||||
@ -108,8 +117,10 @@ def _lazy_init() -> None:
|
|||||||
_dg = importlib.import_module("deep_gemm")
|
_dg = importlib.import_module("deep_gemm")
|
||||||
|
|
||||||
_fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
|
_fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
|
||||||
_grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
|
_fp8_grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
|
||||||
_grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
|
_fp8_grouped_masked_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_masked", None)
|
||||||
|
_bf16_grouped_impl = getattr(_dg, "m_grouped_bf16_gemm_nt_contiguous", None)
|
||||||
|
_bf16_grouped_masked_impl = getattr(_dg, "m_grouped_bf16_gemm_nt_masked", None)
|
||||||
_fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
|
_fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
|
||||||
_fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
|
_fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
|
||||||
_get_paged_mqa_logits_metadata_impl = getattr(
|
_get_paged_mqa_logits_metadata_impl = getattr(
|
||||||
@ -148,22 +159,36 @@ def fp8_gemm_nt(*args, **kwargs):
|
|||||||
|
|
||||||
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Forward to the lazily resolved FP8 contiguous grouped GEMM.

    Triggers ``_lazy_init()`` and then calls ``_fp8_grouped_impl`` (resolved
    from ``deep_gemm.m_grouped_fp8_gemm_nt_contiguous``) with the caller's
    arguments plus ``disable_ue8m0_cast``, which is set to the negation of
    ``is_deep_gemm_e8m0_used()``. When the symbol could not be resolved, the
    call is routed to ``_missing`` instead.
    """
    _lazy_init()
    impl = _fp8_grouped_impl
    if impl is not None:
        return impl(
            *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
        )
    return _missing(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Forward to the lazily resolved FP8 masked grouped GEMM.

    Triggers ``_lazy_init()`` and then calls ``_fp8_grouped_masked_impl``
    (resolved from ``deep_gemm.m_grouped_fp8_gemm_nt_masked``) with the
    caller's arguments plus ``disable_ue8m0_cast``, which is set to the
    negation of ``is_deep_gemm_e8m0_used()``. When the symbol could not be
    resolved, the call is routed to ``_missing`` instead.
    """
    _lazy_init()
    impl = _fp8_grouped_masked_impl
    if impl is not None:
        return impl(
            *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
        )
    return _missing(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def m_grouped_bf16_gemm_nt_contiguous(*args, **kwargs):
    """Forward to the lazily resolved BF16 contiguous grouped GEMM.

    Triggers ``_lazy_init()`` and then calls ``_bf16_grouped_impl`` (resolved
    from ``deep_gemm.m_grouped_bf16_gemm_nt_contiguous``) with the caller's
    arguments unchanged. When the symbol could not be resolved, the call is
    routed to ``_missing`` instead.
    """
    _lazy_init()
    impl = _bf16_grouped_impl
    if impl is not None:
        return impl(*args, **kwargs)
    return _missing(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def bf16_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Forward to the lazily resolved BF16 masked grouped GEMM.

    Triggers ``_lazy_init()`` and then calls ``_bf16_grouped_masked_impl``
    (resolved from ``deep_gemm.m_grouped_bf16_gemm_nt_masked``) with the
    caller's arguments unchanged. When the symbol could not be resolved, the
    call is routed to ``_missing`` instead.
    """
    _lazy_init()
    impl = _bf16_grouped_masked_impl
    if impl is not None:
        return impl(*args, **kwargs)
    return _missing(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def fp8_mqa_logits(
|
def fp8_mqa_logits(
|
||||||
q: torch.Tensor,
|
q: torch.Tensor,
|
||||||
kv: tuple[torch.Tensor, torch.Tensor],
|
kv: tuple[torch.Tensor, torch.Tensor],
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user