From 2d4215b9a2a16fb2bdc1cc904d1ea100f031081b Mon Sep 17 00:00:00 2001
From: Zhuohan Li
Date: Tue, 14 Oct 2025 17:27:11 -0700
Subject: [PATCH] [DO NOT MERGE] Experiments related to MoE kernels

Signed-off-by: Zhuohan Li
---
 .../layers/moe/grouped_gemm_no_abstraction.py | 137 ++++++++++++++++++
 vllm/utils/deep_gemm.py                       |  49 +++++--
 2 files changed, 174 insertions(+), 12 deletions(-)
 create mode 100644 vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py

diff --git a/vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py b/vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py
new file mode 100644
index 0000000000000..3e91e1a1175ef
--- /dev/null
+++ b/vllm/model_executor/layers/moe/grouped_gemm_no_abstraction.py
@@ -0,0 +1,137 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+
+os.environ["VLLM_USE_DEEP_GEMM"] = "1"
+
+import math
+import random
+
+import torch
+
+from vllm.utils import deep_gemm as vllm_deep_gemm
+
+BLOCK_SIZE = (128, 128)
+BLOCK_N, BLOCK_K = BLOCK_SIZE
+
+
+def generate_bf16_and_downcast_to_fp8(shape, device="cuda"):
+    bf16_weight = torch.randn(shape, dtype=torch.bfloat16, device=device)
+    fp8_weight = bf16_weight.to(dtype=torch.float8_e4m3fn)
+    return fp8_weight
+
+
+def run_batched_deepgemm_fp8(
+    expected_group_batch_size: int,
+    num_groups: int,
+    output_size: int,
+    input_size: int,
+):
+    weight = generate_bf16_and_downcast_to_fp8((num_groups, output_size, input_size))
+    output_tiles = math.ceil(output_size / BLOCK_N)
+    input_tiles = math.ceil(input_size / BLOCK_K)
+    weight_scale = torch.randn(
+        num_groups,
+        output_tiles,
+        input_tiles,
+        dtype=torch.float32,
+        device="cuda",
+    )
+    group_batch_size = [
+        int(expected_group_batch_size * random.uniform(0.7, 1.3))
+        for _ in range(num_groups)
+    ]
+    batch_size = sum(group_batch_size)
+    group_batch_size = torch.tensor(
+        group_batch_size,
+        dtype=torch.int32,
+        device="cuda",
+    )
+    x = generate_bf16_and_downcast_to_fp8((num_groups, batch_size, input_size))
+    x_scale = torch.randn(
+        num_groups,
+        batch_size,
+        input_tiles,
+        dtype=torch.float32,
+        device="cuda",
+    )
+    output = torch.zeros(
+        num_groups,
+        batch_size,
+        output_size,
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+
+    vllm_deep_gemm.fp8_m_grouped_gemm_nt_masked(
+        (x, x_scale),
+        (weight, weight_scale),
+        output,
+        group_batch_size,
+        expected_group_batch_size,
+    )
+    print(output)
+
+
+def run_batched_deepgemm_bf16(
+    expected_group_batch_size: int,
+    num_groups: int,
+    output_size: int,
+    input_size: int,
+):
+    weight = torch.randn(
+        num_groups,
+        output_size,
+        input_size,
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+    group_batch_size = [
+        int(expected_group_batch_size * random.uniform(0.7, 1.3))
+        for _ in range(num_groups)
+    ]
+    batch_size = sum(group_batch_size)
+    group_batch_size = torch.tensor(
+        group_batch_size,
+        dtype=torch.int32,
+        device="cuda",
+    )
+    x = torch.randn(
+        num_groups,
+        batch_size,
+        input_size,
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+    ground_truth_output = torch.einsum("bnk, bmk -> bnm", x, weight)
+    output = torch.zeros(
+        num_groups,
+        batch_size,
+        output_size,
+        dtype=torch.bfloat16,
+        device="cuda",
+    )
+    vllm_deep_gemm.bf16_m_grouped_gemm_nt_masked(
+        x,
+        weight,
+        output,
+        group_batch_size,
+        expected_group_batch_size,
+    )
+    for i in range(num_groups):
+        torch.testing.assert_close(
+            output[i, : group_batch_size[i]],
+            ground_truth_output[i, : group_batch_size[i]],
+        )
+        print(
+            (
+                output[i, : group_batch_size[i]]
+                - ground_truth_output[i, : group_batch_size[i]]
+            )
+            .abs()
+            .max()
+        )
+
+
+run_batched_deepgemm_fp8(512, 8, 1024, 512)
+run_batched_deepgemm_bf16(512, 8, 1024, 512)
diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py
index 39ffba3137df8..12a8edbce0335 100644
--- a/vllm/utils/deep_gemm.py
+++ b/vllm/utils/deep_gemm.py
@@ -69,28 +69,37 @@ def _missing(*_: Any, **__: Any) -> NoReturn:
 
 
 _fp8_gemm_nt_impl: Callable[..., Any] | None = None
-_grouped_impl: Callable[..., Any] | None = None
-_grouped_masked_impl: Callable[..., Any] | None = None
+_fp8_grouped_impl: Callable[..., Any] | None = None
+_fp8_grouped_masked_impl: Callable[..., Any] | None = None
 _fp8_mqa_logits_impl: Callable[..., Any] | None = None
 _fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None
+_bf16_grouped_impl: Callable[..., Any] | None = None
+_bf16_grouped_masked_impl: Callable[..., Any] | None = None
 _get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None
 _get_mn_major_tma_aligned_tensor_impl: Callable[..., Any] | None = None
 
 
 def _lazy_init() -> None:
     """Import deep_gemm and resolve symbols on first use."""
-    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
-    global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
+    global _fp8_gemm_nt_impl
+    global _fp8_grouped_impl
+    global _fp8_grouped_masked_impl
+    global _fp8_mqa_logits_impl
+    global _fp8_paged_mqa_logits_impl
+    global _bf16_grouped_impl
+    global _bf16_grouped_masked_impl
     global _get_paged_mqa_logits_metadata_impl
     global _get_mn_major_tma_aligned_tensor_impl
 
     # fast path
     if (
         _fp8_gemm_nt_impl is not None
-        or _grouped_impl is not None
-        or _grouped_masked_impl is not None
+        or _fp8_grouped_impl is not None
+        or _fp8_grouped_masked_impl is not None
         or _fp8_mqa_logits_impl is not None
         or _fp8_paged_mqa_logits_impl is not None
+        or _bf16_grouped_impl is not None
+        or _bf16_grouped_masked_impl is not None
         or _get_paged_mqa_logits_metadata_impl is not None
     ):
         return
@@ -108,8 +117,10 @@ def _lazy_init() -> None:
     _dg = importlib.import_module("deep_gemm")
 
     _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
-    _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
-    _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
+    _fp8_grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
+    _fp8_grouped_masked_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_masked", None)
+    _bf16_grouped_impl = getattr(_dg, "m_grouped_bf16_gemm_nt_contiguous", None)
+    _bf16_grouped_masked_impl = getattr(_dg, "m_grouped_bf16_gemm_nt_masked", None)
     _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
     _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
     _get_paged_mqa_logits_metadata_impl = getattr(
@@ -148,22 +159,36 @@ def fp8_gemm_nt(*args, **kwargs):
 
 def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
     _lazy_init()
-    if _grouped_impl is None:
+    if _fp8_grouped_impl is None:
         return _missing(*args, **kwargs)
-    return _grouped_impl(
+    return _fp8_grouped_impl(
         *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
     )
 
 
 def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
     _lazy_init()
-    if _grouped_masked_impl is None:
+    if _fp8_grouped_masked_impl is None:
         return _missing(*args, **kwargs)
-    return _grouped_masked_impl(
+    return _fp8_grouped_masked_impl(
         *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs
     )
 
 
+def m_grouped_bf16_gemm_nt_contiguous(*args, **kwargs):
+    _lazy_init()
+    if _bf16_grouped_impl is None:
+        return _missing(*args, **kwargs)
+    return _bf16_grouped_impl(*args, **kwargs)
+
+
+def bf16_m_grouped_gemm_nt_masked(*args, **kwargs):
+    _lazy_init()
+    if _bf16_grouped_masked_impl is None:
+        return _missing(*args, **kwargs)
+    return _bf16_grouped_masked_impl(*args, **kwargs)
+
+
 def fp8_mqa_logits(
     q: torch.Tensor,
     kv: tuple[torch.Tensor, torch.Tensor],