From fa98d77773c649de05a4bda9847682c80287aa36 Mon Sep 17 00:00:00 2001 From: Varun Sundar Rabindranath Date: Tue, 3 Jun 2025 15:30:02 -0400 Subject: [PATCH] [Kernel] DeepEP dispatch-combine kernel integration (#18434) Signed-off-by: Varun Co-authored-by: Varun Sundar Rabindranath --- csrc/moe/topk_softmax_kernels.cu | 16 +- tests/kernels/moe/__init__.py | 0 tests/kernels/moe/deepep_utils.py | 188 +++++++ tests/kernels/moe/test_deepep_deepgemm_moe.py | 371 ++++++++++++++ tests/kernels/moe/test_deepep_moe.py | 459 ++++++++++++++++++ vllm/config.py | 2 + .../device_communicators/all2all.py | 146 +++++- .../device_communicators/cuda_communicator.py | 8 + vllm/envs.py | 2 + .../layers/fused_moe/deep_gemm_moe.py | 32 +- .../fused_moe/deepep_ht_prepare_finalize.py | 236 +++++++++ .../fused_moe/deepep_ll_prepare_finalize.py | 152 ++++++ .../layers/fused_moe/fused_batched_moe.py | 57 ++- .../layers/fused_moe/fused_moe.py | 2 +- vllm/model_executor/layers/fused_moe/layer.py | 148 ++++-- .../layers/fused_moe/modular_kernel.py | 158 ++++-- .../layers/fused_moe/moe_permute_unpermute.py | 5 +- .../layers/fused_moe/pplx_prepare_finalize.py | 11 +- .../layers/fused_moe/prepare_finalize.py | 12 +- .../layers/fused_moe/triton_deep_gemm_moe.py | 7 +- vllm/model_executor/layers/fused_moe/utils.py | 4 +- .../model_executor/layers/quantization/fp8.py | 41 +- vllm/platforms/cuda.py | 15 + 23 files changed, 1950 insertions(+), 122 deletions(-) create mode 100644 tests/kernels/moe/__init__.py create mode 100644 tests/kernels/moe/deepep_utils.py create mode 100644 tests/kernels/moe/test_deepep_deepgemm_moe.py create mode 100644 tests/kernels/moe/test_deepep_moe.py create mode 100644 vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py create mode 100644 vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py diff --git a/csrc/moe/topk_softmax_kernels.cu b/csrc/moe/topk_softmax_kernels.cu index a9379032245d9..10be47966f611 100644 --- a/csrc/moe/topk_softmax_kernels.cu +++ b/csrc/moe/topk_softmax_kernels.cu @@ -516,9 +516,8 @@ void topk_softmax( topk, stream); } - else + else if (topk_indices.scalar_type() == at::ScalarType::UInt32) { - assert(topk_indices.scalar_type() == at::ScalarType::UInt32); vllm::moe::topkGatingSoftmaxKernelLauncher( gating_output.data_ptr(), topk_weights.data_ptr(), @@ -530,4 +529,17 @@ void topk_softmax( topk, stream); } + else { + assert(topk_indices.scalar_type() == at::ScalarType::Int64); + vllm::moe::topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + stream); + } } diff --git a/tests/kernels/moe/__init__.py b/tests/kernels/moe/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/kernels/moe/deepep_utils.py b/tests/kernels/moe/deepep_utils.py new file mode 100644 index 0000000000000..2bc9b657da859 --- /dev/null +++ b/tests/kernels/moe/deepep_utils.py @@ -0,0 +1,188 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +DeepEP test utilities +""" +import dataclasses +import importlib +import traceback +from typing import Callable, Optional + +import torch +from torch.distributed import ProcessGroup +from torch.multiprocessing import ( + spawn) # pyright: ignore[reportPrivateImportUsage] +from typing_extensions import Concatenate, ParamSpec + +has_deep_ep = importlib.util.find_spec("deep_ep") is not None +if has_deep_ep: + from 
vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 + DeepEPHTPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 + DeepEPLLPrepareAndFinalize) + +## Parallel Processes Utils + +P = ParamSpec("P") + + +@dataclasses.dataclass +class ProcessGroupInfo: + world_size: int + world_local_size: int + rank: int + node_rank: int + local_rank: int + device: torch.device + + +def _worker_parallel_launch( + local_rank: int, + world_size: int, + world_local_size: int, + node_rank: int, + init_method: str, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + rank = node_rank * world_local_size + local_rank + torch.cuda.set_device(local_rank) + device = torch.device("cuda", local_rank) + torch.distributed.init_process_group( + backend="cpu:gloo,cuda:nccl", + init_method=init_method, + rank=rank, + world_size=world_size, + device_id=device, + ) + barrier = torch.tensor([rank], device=device) + torch.distributed.all_reduce(barrier) + + try: + worker( + ProcessGroupInfo( + world_size=world_size, + world_local_size=world_local_size, + rank=rank, + node_rank=node_rank, + local_rank=local_rank, + device=device, + ), + *args, + **kwargs, + ) + except Exception as ex: + print(ex) + traceback.print_exc() + raise + finally: + torch.distributed.destroy_process_group() + + +def parallel_launch( + world_size: int, + worker: Callable[Concatenate[ProcessGroupInfo, P], None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: + assert not kwargs + spawn( + _worker_parallel_launch, + args=( + world_size, + world_size, + 0, + "tcp://localhost:29500", + worker, + ) + args, + nprocs=world_size, + join=True, + ) + + +## DeepEP specific utils + + +@dataclasses.dataclass +class DeepEPHTArgs: + num_local_experts: int + + +@dataclasses.dataclass +class DeepEPLLArgs: + max_tokens_per_rank: int + hidden_size: int + num_experts: int + use_fp8_dispatch: bool + + +def make_deepep_ht_a2a(pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + ht_args: DeepEPHTArgs, + q_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None): + + import deep_ep + + # high throughput a2a + num_nvl_bytes = 1024 * 1024 * 1024 # 1GB + num_rdma_bytes, low_latency_mode, num_qps_per_rank = 0, False, 1 + buffer = deep_ep.Buffer(group=pg, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=low_latency_mode, + num_qps_per_rank=num_qps_per_rank) + return DeepEPHTPrepareAndFinalize(buffer=buffer, + world_size=pgi.world_size, + rank=pgi.rank, + dp_size=dp_size, + rank_expert_offset=pgi.rank * + ht_args.num_local_experts, + quant_dtype=q_dtype, + block_shape=block_shape) + + +def make_deepep_ll_a2a(pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + deepep_ll_args: DeepEPLLArgs, + q_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None): + + import deep_ep + + # low-latency a2a + num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( + deepep_ll_args.max_tokens_per_rank, deepep_ll_args.hidden_size, + pgi.world_size, deepep_ll_args.num_experts) + + buffer = deep_ep.Buffer(group=pg, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=deepep_ll_args.num_experts // + pgi.world_size) + return DeepEPLLPrepareAndFinalize( + buffer=buffer, + world_size=pgi.world_size, + dp_size=dp_size, + max_tokens_per_rank=deepep_ll_args.max_tokens_per_rank, + quant_dtype=q_dtype, + 
use_fp8_dispatch=deepep_ll_args.use_fp8_dispatch, + ) + + +def make_deepep_a2a(pg: ProcessGroup, + pgi: ProcessGroupInfo, + dp_size: int, + deepep_ht_args: Optional[DeepEPHTArgs], + deepep_ll_args: Optional[DeepEPLLArgs], + q_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None): + if deepep_ht_args is not None: + assert deepep_ll_args is None + return make_deepep_ht_a2a(pg, pgi, dp_size, deepep_ht_args, q_dtype, + block_shape) + + assert deepep_ll_args is not None + return make_deepep_ll_a2a(pg, pgi, dp_size, deepep_ll_args, q_dtype) diff --git a/tests/kernels/moe/test_deepep_deepgemm_moe.py b/tests/kernels/moe/test_deepep_deepgemm_moe.py new file mode 100644 index 0000000000000..a1fdc1d5ff47b --- /dev/null +++ b/tests/kernels/moe/test_deepep_deepgemm_moe.py @@ -0,0 +1,371 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Test DeepEP + DeepGEMM integration +""" + +import dataclasses +import importlib +from typing import Optional + +import pytest +import torch.distributed +from torch.distributed import ProcessGroup +from typing_extensions import ParamSpec + +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.fused_moe.fused_moe import fused_experts +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +from vllm.platforms import current_platform + +from .deepep_utils import ProcessGroupInfo, parallel_launch + +has_deep_ep = importlib.util.find_spec("deep_ep") is not None + +try: + import deep_gemm + has_deep_gemm = True +except ImportError: + has_deep_gemm = False + +if has_deep_ep: + from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 + DeepEPHTPrepareAndFinalize) + + from .deepep_utils import DeepEPHTArgs, make_deepep_a2a + +if has_deep_gemm: + from vllm.model_executor.layers.fused_moe.deep_gemm_moe import ( + DeepGemmExperts) + +requires_deep_ep = pytest.mark.skipif( + not has_deep_ep, + reason="Requires deep_ep kernels", +) + +requires_deep_gemm = pytest.mark.skipif( + not has_deep_gemm, + reason="Requires deep_gemm kernels", +) + +P = ParamSpec("P") + + +def per_block_cast_to_fp8( + x: torch.Tensor, + block_size_n: int = 128) -> tuple[torch.Tensor, torch.Tensor]: + assert x.dim() == 2 + m, n = x.shape + x_padded = torch.zeros( + (deep_gemm.ceil_div(m, 128) * 128, + deep_gemm.ceil_div(n, block_size_n) * block_size_n), + dtype=x.dtype, + device=x.device) + x_padded[:m, :n] = x + x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n) + x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4) + x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn) + x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous() + scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2)) + return x_scaled_sub, scales + + +def make_block_quant_fp8_weights( + e: int, + n: int, + k: int, + block_size: list[int], +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Return weights w1, w2, w1q, w2q, w1_scale, w2_scale + """ + dtype = torch.bfloat16 + + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + w1_bf16 = torch.randn((e, 2 * n, k), dtype=dtype) / 10 + w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype) + + w2_bf16 = torch.randn((e, k, n), dtype=dtype) / 10 + w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype) + + block_n, 
block_k = block_size[0], block_size[1] + n_tiles_w1 = ((2 * n) + block_n - 1) // block_n + k_tiles_w1 = (k + block_k - 1) // block_k + n_tiles_w2 = (k + block_n - 1) // block_n + k_tiles_w2 = (n + block_k - 1) // block_k + + w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn) + w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn) + + w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1), + device="cuda", + dtype=torch.float32) + w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2), + device="cuda", + dtype=torch.float32) + + assert w1_s.shape == (e, (2 * n + 127) // 128, (k + 127) // 128) + assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2] + + for i in range(e): + w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i]) + w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i]) + + return w1, w2, w1_s, w2_s + + +@dataclasses.dataclass +class TestConfig: + topk: int + m: int + k: int + n: int + num_experts: int + block_size: list[int] + + +@dataclasses.dataclass +class TestTensors: + rank_tokens: torch.Tensor # all ranks make this many tokens + rank_token_scales: Optional[torch.Tensor] + topk: torch.Tensor + topk_weights: torch.Tensor + config: TestConfig + + @staticmethod + def make(config: TestConfig, rank) -> "TestTensors": + + dtype = torch.bfloat16 + topk, m, k, block_size = (config.topk, config.m, config.k, + config.block_size) + + fp8_info = torch.finfo(torch.float8_e4m3fn) + fp8_max, fp8_min = fp8_info.max, fp8_info.min + + rank_tokens = torch.randn( + (m, k), device=torch.cuda.current_device(), dtype=dtype) / 10.0 + rank_tokens = rank_tokens.clamp(min=fp8_min, max=fp8_max) + + block_k = block_size[1] + _, rank_token_scales = per_token_group_quant_fp8(rank_tokens, block_k) + + topk_ids = torch.randint( + low=0, + high=config.num_experts, + size=(m, topk), + device=torch.cuda.current_device()).to(dtype=torch.int64) + + topk_weights = torch.randn(topk_ids.shape, + dtype=torch.float32, + device=torch.cuda.current_device()) + + return TestTensors(rank_tokens=rank_tokens, + rank_token_scales=rank_token_scales, + topk=topk_ids, + topk_weights=topk_weights, + config=config) + + +def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, + num_local_experts: int, q_dtype: Optional[torch.dtype], + block_shape: list[int]) -> FusedMoEModularKernel: + + a2a: DeepEPHTPrepareAndFinalize = make_deepep_a2a( + pg=pg, + pgi=pgi, + dp_size=dp_size, + deepep_ht_args=DeepEPHTArgs(num_local_experts=num_local_experts), + deepep_ll_args=None, + q_dtype=q_dtype, + block_shape=block_shape) + + fused_experts = DeepGemmExperts() + mk = FusedMoEModularKernel(prepare_finalize=a2a, + fused_experts=fused_experts) + return mk + + +def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, dp_size: int, + test_tensors: TestTensors, w1: torch.Tensor, + w2: torch.Tensor, w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + num_experts: int) -> torch.Tensor: + + num_local_experts = w1.size(0) + + def build_expert_map(): + num_local_experts = w1.size(0) + expert_map = torch.full((num_experts, ), + fill_value=-1, + dtype=torch.int32) + s = pgi.rank * num_local_experts + e = s + num_local_experts + expert_map[s:e] = torch.tensor(list(range(num_local_experts))) + return expert_map.to(device=torch.cuda.current_device(), + dtype=torch.int32) + + q_dtype = torch.float8_e4m3fn + + # Make modular kernel + mk: FusedMoEModularKernel = make_modular_kernel( + pg, pgi, dp_size, num_local_experts, q_dtype, + test_tensors.config.block_size) + + a1_scale = test_tensors.rank_token_scales + + 
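# Annotation sketch (not part of the patch; names are illustrative): what
# build_expert_map() above computes. Global expert ids owned by this rank map
# to local ids 0..num_local_experts-1, everything else maps to -1.
def _example_expert_map(num_experts: int = 8,
                        world_size: int = 2,
                        rank: int = 1) -> torch.Tensor:
    num_local = num_experts // world_size
    emap = torch.full((num_experts, ), -1, dtype=torch.int32)
    emap[rank * num_local:(rank + 1) * num_local] = torch.arange(
        num_local, dtype=torch.int32)
    return emap  # for the defaults above: [-1, -1, -1, -1, 0, 1, 2, 3]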
out = mk.forward(hidden_states=test_tensors.rank_tokens, + w1=w1, + w2=w2, + topk_weights=test_tensors.topk_weights, + topk_ids=test_tensors.topk, + inplace=False, + activation="silu", + global_num_experts=num_experts, + expert_map=build_expert_map(), + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=None, + w2_zp=None, + a1_scale=a1_scale, + a2_scale=None, + apply_router_weight_on_input=False) + return out + + +def triton_impl(a: torch.Tensor, topk_ids: torch.Tensor, + topk_weights: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, + w1_scale: torch.Tensor, w2_scale: torch.Tensor, + a1_scale: torch.Tensor, block_shape: list[int]): + + return fused_experts( + hidden_states=a, + w1=w1, + w2=w2, + topk_weights=topk_weights, + topk_ids=topk_ids, + inplace=False, + use_fp8_w8a8=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + block_shape=block_shape, + # Make sure this is set to False so we + # dont end up comparing the same implementation. + allow_deep_gemm=False) + + +def _deep_ep_moe( + pgi: ProcessGroupInfo, + dp_size: int, + config: TestConfig, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, +): + current_platform.seed_everything(pgi.rank) + + w1 = w1.to(device=torch.cuda.current_device()) + w2 = w2.to(device=torch.cuda.current_device()) + w1_scale = w1_scale.to(device=torch.cuda.current_device()) + w2_scale = w2_scale.to(device=torch.cuda.current_device()) + + pg = torch.distributed.new_group(list(range(pgi.world_size))) + test_tensors = TestTensors.make(config, pgi.rank) + block_shape = [ + w1.size(1) // w1_scale.size(1), + w1.size(2) // w1_scale.size(2) + ] + + with set_current_vllm_config(VllmConfig()): + # Reference + triton_moe = triton_impl(a=test_tensors.rank_tokens, + topk_ids=test_tensors.topk, + topk_weights=test_tensors.topk_weights, + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=test_tensors.rank_token_scales, + block_shape=block_shape) + + # Slice experts for this rank. 
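# With expert parallelism the experts are partitioned contiguously: rank r
# owns global experts [r * E // world_size, (r + 1) * E // world_size). The
# Triton reference above ran with the full weights; the DeepEP path below is
# only given this rank's shard.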
+ num_local_experts = config.num_experts // pgi.world_size + e_start = num_local_experts * pgi.rank + e_end = e_start + num_local_experts + w1_ep = w1[e_start:e_end] + w2_ep = w2[e_start:e_end] + w1_scale_ep = w1_scale[e_start:e_end] + w2_scale_ep = w2_scale[e_start:e_end] + + deepep_moe = deep_ep_moe_impl( + pg, + pgi, + dp_size, + test_tensors, + w1_ep, + w2_ep, + w1_scale_ep, + w2_scale_ep, + config.num_experts, + ) + + torch.testing.assert_close( + triton_moe, + deepep_moe, + atol=6e-2, + rtol=6e-2, + ) + + +MNKs = [ + (8, 128, 128), + (8, 128, 512), + (8, 512, 512), + (3, 1024, 2048), + (32, 128, 1024), + (45, 512, 2048), + (64, 1024, 1024), + (129, 128, 256), + (129, 1024, 2048), + (222, 1024, 2048), +] + + +@pytest.mark.parametrize("mnk", MNKs) +@pytest.mark.parametrize("num_experts", [32]) +@pytest.mark.parametrize("topk", [2, 6]) +@pytest.mark.parametrize("world_dp_size", [(2, 1)]) +@requires_deep_ep +@requires_deep_gemm +def test_deep_ep_moe(mnk: tuple[int, int, int], num_experts: int, topk: int, + world_dp_size: tuple[int, int]): + + m, n, k = mnk + current_platform.seed_everything(7) + + if topk > num_experts: + pytest.skip(f"Skipping test: topk={topk} > E={num_experts}") + + block_m = deep_gemm.get_m_alignment_for_contiguous_layout() + block_size = [block_m, block_m] + + world_size, dp_size = world_dp_size + config = TestConfig( + topk=topk, + m=m, + k=k, + n=n, + num_experts=num_experts, + block_size=block_size, + ) + + w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights( + num_experts, n, k, block_size) + + parallel_launch(world_size, _deep_ep_moe, dp_size, config, w1, w2, + w1_scale, w2_scale) diff --git a/tests/kernels/moe/test_deepep_moe.py b/tests/kernels/moe/test_deepep_moe.py new file mode 100644 index 0000000000000..7e029ea950555 --- /dev/null +++ b/tests/kernels/moe/test_deepep_moe.py @@ -0,0 +1,459 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Test deepep dispatch-combine logic +""" + +import dataclasses +import importlib +from typing import Optional, Union + +import pytest +import torch.distributed +from torch.distributed import ProcessGroup + +from vllm import _custom_ops as ops +from vllm.config import VllmConfig, set_current_vllm_config +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe import TritonExperts +from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts) +from vllm.model_executor.layers.fused_moe.modular_kernel import ( + FusedMoEModularKernel) +from vllm.model_executor.layers.quantization.utils.fp8_utils import ( + per_token_group_quant_fp8) +from vllm.platforms import current_platform + +from .deepep_utils import ProcessGroupInfo, parallel_launch + +has_deep_ep = importlib.util.find_spec("deep_ep") is not None + +if has_deep_ep: + from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501 + DeepEPHTPrepareAndFinalize) + from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501 + DeepEPLLPrepareAndFinalize) + + from .deepep_utils import DeepEPHTArgs, DeepEPLLArgs, make_deepep_a2a + +requires_deep_ep = pytest.mark.skipif( + not has_deep_ep, + reason="Requires deep_ep kernels", +) + +MAX_TOKENS_PER_RANK = 64 + + +def make_weights( + e, n, k, dtype +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Return weights w1, w2, w1_scale, w2_scale + """ + if dtype in [torch.float16, torch.bfloat16]: + w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 + w2 = 
torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 + return w1, w2, None, None + + # per-out-channel weight quantization + assert dtype == torch.float8_e4m3fn + w1 = torch.empty((e, 2 * n, k), device="cuda", dtype=torch.float16) + w2 = torch.empty((e, k, n), device="cuda", dtype=torch.float16) + + n_b_scales = 2 * n + k_b_scales = k + w1_q = torch.empty_like(w1, dtype=dtype) + w2_q = torch.empty_like(w2, dtype=dtype) + w1_scale = torch.empty((e, n_b_scales, 1), + device="cuda", + dtype=torch.float32) + w2_scale = torch.empty((e, k_b_scales, 1), + device="cuda", + dtype=torch.float32) + for expert in range(e): + w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant( + w1[expert], use_per_token_if_dynamic=True) + w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant( + w2[expert], use_per_token_if_dynamic=True) + return w1_q, w2_q, w1_scale, w2_scale + + +@dataclasses.dataclass +class TestConfig: + dtype: torch.dtype + topk: int + m: int + k: int + n: int + num_experts: int + + +@dataclasses.dataclass +class TestTensors: + rank_tokens: torch.Tensor # all ranks make this many tokens + rank_token_scales: Optional[torch.Tensor] + topk: torch.Tensor + topk_weights: torch.Tensor + config: TestConfig + + @staticmethod + def make(config: TestConfig, low_latency_mode: bool) -> "TestTensors": + # TODO (varun) - check that float16 works ? + assert config.dtype in [torch.bfloat16, torch.float8_e4m3fn] + token_dtype = (torch.bfloat16 if config.dtype == torch.float8_e4m3fn + else config.dtype) + rank_tokens = torch.randn( + (config.m, config.k), device="cuda", dtype=token_dtype) / 10 + rank_token_scales = None + if config.dtype == torch.float8_e4m3fn: + # low_latency_mode kernels dont support per-token quant. + _, rank_token_scales = ops.scaled_fp8_quant( + rank_tokens, use_per_token_if_dynamic=not low_latency_mode) + + topk = torch.randint(low=0, + high=config.num_experts, + size=(config.m, config.topk), + device="cuda").to(dtype=torch.int64) + topk_weights = torch.randn(topk.shape, + dtype=torch.float32, + device="cuda") + return TestTensors(rank_tokens=rank_tokens, + rank_token_scales=rank_token_scales, + topk=topk, + topk_weights=topk_weights, + config=config) + + +def make_modular_kernel(pg: ProcessGroup, pgi: ProcessGroupInfo, + low_latency_mode: bool, hidden_size: int, dp_size: int, + num_experts: int, num_local_experts: int, + q_dtype: Optional[torch.dtype], + use_fp8_dispatch: bool) -> FusedMoEModularKernel: + + is_quantized = q_dtype is not None + + ht_args: Optional[DeepEPHTArgs] = None + ll_args: Optional[DeepEPLLArgs] = None + + if low_latency_mode: + ll_args = DeepEPLLArgs(max_tokens_per_rank=MAX_TOKENS_PER_RANK, + hidden_size=hidden_size, + num_experts=num_experts, + use_fp8_dispatch=use_fp8_dispatch) + else: + assert not use_fp8_dispatch, ( + "FP8 Dispatch is valid only for low-latency kernels") + ht_args = DeepEPHTArgs(num_local_experts=num_local_experts) + + a2a : Union[DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize] = \ + make_deepep_a2a(pg = pg, + pgi = pgi, + dp_size = dp_size, + q_dtype = q_dtype, + block_shape = None, + deepep_ht_args = ht_args, + deepep_ll_args = ll_args) + + if low_latency_mode: + fused_experts = BatchedTritonExperts( + max_num_tokens=MAX_TOKENS_PER_RANK, + world_size=pgi.world_size, + dp_size=dp_size, + use_fp8_w8a8=is_quantized, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False) + else: + fused_experts = TritonExperts(use_fp8_w8a8=is_quantized, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + 
per_channel_quant=False) + + mk = FusedMoEModularKernel(prepare_finalize=a2a, + fused_experts=fused_experts) + return mk + + +def deep_ep_moe_impl(pg: ProcessGroup, pgi: ProcessGroupInfo, + low_latency_mode: bool, dp_size: int, + test_tensors: TestTensors, w1: torch.Tensor, + w2: torch.Tensor, w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], num_experts: int, + use_fp8_dispatch: bool) -> torch.Tensor: + + num_local_experts = w1.size(0) + + def build_expert_map(): + num_local_experts = w1.size(0) + expert_map = torch.full((num_experts, ), + fill_value=-1, + dtype=torch.int32) + s = pgi.rank * num_local_experts + e = s + num_local_experts + expert_map[s:e] = torch.tensor(list(range(num_local_experts))) + return expert_map.to(device=torch.cuda.current_device(), + dtype=torch.int32) + + hidden_size = test_tensors.rank_tokens.size(1) + is_quantized = w1.dtype == torch.float8_e4m3fn + q_dtype = None + if is_quantized: + q_dtype = torch.float8_e4m3fn + + # Make modular kernel + mk: FusedMoEModularKernel = make_modular_kernel(pg, pgi, low_latency_mode, + hidden_size, dp_size, + num_experts, + num_local_experts, q_dtype, + use_fp8_dispatch) + + out_hidden_states = torch.empty_like(test_tensors.rank_tokens) + total_num_tokens = test_tensors.rank_tokens.size(0) + + def process_chunk(chunk_start, chunk_end, skip_result_store=False): + rank_tokens_chunk = test_tensors.rank_tokens[chunk_start:chunk_end] + topk_weights_chunk = test_tensors.topk_weights[chunk_start:chunk_end] + topk_chunk = test_tensors.topk[chunk_start:chunk_end] + rank_token_scales_chunk = test_tensors.rank_token_scales + if rank_token_scales_chunk is not None and rank_token_scales_chunk.size( + 0) == total_num_tokens: + # per act token + rank_token_scales_chunk = rank_token_scales_chunk[ + chunk_start:chunk_end] + + out = mk.forward(hidden_states=rank_tokens_chunk, + w1=w1, + w2=w2, + topk_weights=topk_weights_chunk, + topk_ids=topk_chunk, + inplace=False, + activation="silu", + global_num_experts=num_experts, + expert_map=build_expert_map(), + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=None, + w2_zp=None, + a1_scale=rank_token_scales_chunk, + a2_scale=None, + apply_router_weight_on_input=False) + + if not skip_result_store: + out_hidden_states[chunk_start:chunk_end, :].copy_( + out, non_blocking=True) + + max_num_tokens_per_dp = (MAX_TOKENS_PER_RANK + if low_latency_mode else total_num_tokens) + + for chunk_start_ in range(0, total_num_tokens, max_num_tokens_per_dp): + chunk_start = chunk_start_ + chunk_end = min(chunk_start + max_num_tokens_per_dp, total_num_tokens) + # clamp start and end + chunk_start = min(chunk_start, total_num_tokens - 1) + chunk_end = min(chunk_end, total_num_tokens) + + process_chunk(chunk_start, + chunk_end, + skip_result_store=chunk_start_ >= total_num_tokens) + + return out_hidden_states + + +def torch_moe_impl(test_tensors: TestTensors, w1: torch.Tensor, + w2: torch.Tensor, w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], using_fp8_dispatch: bool): + + a, topk_ids, topk_weights = (test_tensors.rank_tokens, test_tensors.topk, + test_tensors.topk_weights) + if using_fp8_dispatch: + # The DeepEP implementation is requested to dispatch using FP8. + # For numerical stability for testing, emulate the fp8 dispatch by + # blockwise quant and de-quant. 
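# That is, the reference input takes the same lossy round trip the kernel
# path sees: quantize to fp8 in groups of 128 elements, then dequantize back
# to the original dtype before running the dense reference computation.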
+ a = test_tensors.rank_tokens + aq, aq_scale = per_token_group_quant_fp8(a, 128) + a = (aq.view(-1, 128).to(torch.float32) * aq_scale.view(-1, 1)).view( + a.shape).to(a.dtype) + + is_quantized = w1.dtype == torch.float8_e4m3fn + a_dtype = a.dtype + if is_quantized: + w1 = w1.to(dtype=torch.float32) * w1_scale + w2 = w2.to(dtype=torch.float32) * w2_scale + a = a.to(dtype=torch.float32) + + m, _ = a.shape + topk = topk_ids.size(1) + out = torch.zeros_like(a) + + for i in range(m): + a_i = a[i] + o_i = out[i] + for j in range(topk): + e = topk_ids[i][j] + e_w = topk_weights[i][j] + w1_e = w1[e] + w2_e = w2[e] + o_i += (SiluAndMul() + (a_i @ w1_e.transpose(0, 1)) @ w2_e.transpose(0, 1)) * e_w + + if is_quantized: + out = out.to(dtype=a_dtype) + + return out + + +def _deep_ep_moe( + pgi: ProcessGroupInfo, + low_latency_mode: bool, + dp_size: int, + config: TestConfig, + w1: torch.Tensor, + w2: torch.Tensor, + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + use_fp8_dispatch: bool, +): + + if not low_latency_mode: + assert not use_fp8_dispatch, ( + "FP8 dispatch interface is available only in low-latency mode") + + is_quantized = w1.dtype == torch.float8_e4m3fn + w1 = w1.to(device=torch.cuda.current_device()) + w2 = w2.to(device=torch.cuda.current_device()) + if is_quantized: + w1_scale = w1_scale.to( # type: ignore + device=torch.cuda.current_device()) + w2_scale = w2_scale.to( # type: ignore + device=torch.cuda.current_device()) + + pg = torch.distributed.new_group(list(range(pgi.world_size))) + test_tensors = TestTensors.make(config, low_latency_mode) + + with set_current_vllm_config(VllmConfig()): + # Reference + torch_combined = torch_moe_impl(test_tensors, w1, w2, w1_scale, + w2_scale, use_fp8_dispatch) + + # Splice experts for this rank. 
+ num_local_experts = config.num_experts // pgi.world_size + e_start = num_local_experts * pgi.rank + e_end = e_start + num_local_experts + w1_ep = w1[e_start:e_end] + w2_ep = w2[e_start:e_end] + + w1_scale_ep, w2_scale_ep = None, None + if is_quantized: + w1_scale_ep = w1_scale[e_start:e_end] # type: ignore + w2_scale_ep = w2_scale[e_start:e_end] # type: ignore + deepep_combined = deep_ep_moe_impl( + pg, + pgi, + low_latency_mode, + dp_size, + test_tensors, + w1_ep, + w2_ep, + w1_scale_ep, + w2_scale_ep, + config.num_experts, + use_fp8_dispatch, + ) + + torch.testing.assert_close( + torch_combined, + deepep_combined, + atol=6e-2, + rtol=6e-2, + ) + + +MNKs = [ + (1, 128, 128), + (2, 128, 512), + (3, 1024, 2048), + (32, 128, 1024), + (45, 512, 2048), + (64, 1024, 1024), + (222, 1024, 2048), +] + +DTYPES = [torch.bfloat16, torch.float8_e4m3fn] + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("mnk", MNKs) +@pytest.mark.parametrize("num_experts", [32]) +@pytest.mark.parametrize("topk", [6]) +@pytest.mark.parametrize("world_dp_size", [(2, 1)]) +@requires_deep_ep +def test_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], + num_experts: int, topk: int, world_dp_size: tuple[int, + int]): + low_latency_mode = False + use_fp8_dispatch = False + m, n, k = mnk + + current_platform.seed_everything(7) + world_size, dp_size = world_dp_size + config = TestConfig(dtype=dtype, + topk=topk, + m=m, + k=k, + n=n, + num_experts=num_experts) + + w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype) + + parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size, + config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch) + + +MNKs = [ + (1, 128, 2560), + (2, 128, 2560), + (3, 1024, 2560), + (32, 128, 2560), + (45, 512, 2560), + (64, 1024, 2560), + (222, 1024, 2560), +] +DTYPES = [torch.float8_e4m3fn, torch.bfloat16] +USE_FP8_DISPATCH = [True, False] + + +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("mnk", MNKs) +@pytest.mark.parametrize("num_experts", [32]) +@pytest.mark.parametrize("topk", [6]) +@pytest.mark.parametrize("world_dp_size", [(2, 1)]) +@pytest.mark.parametrize("use_fp8_dispatch", USE_FP8_DISPATCH) +@requires_deep_ep +def test_low_latency_deep_ep_moe(dtype: torch.dtype, mnk: tuple[int, int, int], + num_experts: int, topk: int, + world_dp_size: tuple[int, int], + use_fp8_dispatch: bool): + + low_latency_mode = True + m, n, k = mnk + + if (low_latency_mode + and k not in DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES): + pytest.skip( + f"Skipping test as hidden size {k} is not in list of supported " + f"hidden sizes {DeepEPLLPrepareAndFinalize.SUPPORTED_HIDDEN_SIZES}" + ) + + current_platform.seed_everything(7) + world_size, dp_size = world_dp_size + config = TestConfig(dtype=dtype, + topk=topk, + m=m, + k=k, + n=n, + num_experts=num_experts) + + w1, w2, w1_scale, w2_scale = make_weights(num_experts, n, k, dtype) + + parallel_launch(world_size, _deep_ep_moe, low_latency_mode, dp_size, + config, w1, w2, w1_scale, w2_scale, use_fp8_dispatch) diff --git a/vllm/config.py b/vllm/config.py index d99e501ca279a..f6ca9328b8a19 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1856,6 +1856,8 @@ class ParallelConfig: factors.append(self.pipeline_parallel_size) factors.append(self.tensor_parallel_size) factors.append(self.enable_expert_parallel) + factors.append(self.data_parallel_size) + factors.append(envs.VLLM_ALL2ALL_BACKEND) return hashlib.sha256(str(factors).encode()).hexdigest() def __post_init__(self) -> None: diff --git 
a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index ae75902994423..2ab3779ece056 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -1,7 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import importlib.util -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import torch import torch.distributed as dist @@ -129,3 +129,147 @@ class PPLXAll2AllManager(All2AllManagerBase): from pplx_kernels.nvshmem import nvshmem_finalize logger.debug("PPLX NVSHMEM finalize") nvshmem_finalize() + + +class DeepEPAll2AllManagerBase(All2AllManagerBase): + """ + All2All communication based on DeepEP High-Throughput kernels. + """ + + def __init__(self, cpu_group): + has_deepep = importlib.util.find_spec("deep_ep") is not None + assert has_deepep, "DeepEP kernels not found. Please follow https://github.com/vllm-project/vllm/blob/main/tools/ep_kernels/README.md to install DeepEP kernels." # noqa + super().__init__(cpu_group) + self.handle_cache = Cache() + + # This is the DeepEP default. Stick to it till we can establish + # reasonable defaults based on profiling. + self.num_sms = 20 + + def get_handle(self, kwargs): + raise NotImplementedError + + def dispatch(self, hidden_states: torch.Tensor, + router_logits: torch.Tensor): + raise NotImplementedError + + def combine(self, hidden_states: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + def destroy(self): + pass + + +class DeepEPHTAll2AllManager(DeepEPAll2AllManagerBase): + """ + All2All communication based on DeepEP High-Throughput kernels. + """ + + def __init__(self, cpu_group): + super().__init__(cpu_group) + + def _make_all2all_kwargs(self) -> dict[Any, Any]: + # Defaults for internode and intranode are taken from DeepEP tests. + num_nvl_bytes = 1024 * 1024 * 1024 + num_rdma_bytes = None + num_qps_per_rank = None + + if self.internode: + num_rdma_bytes = 1024 * 1024 * 1024 + num_qps_per_rank = self.num_sms // 2 + else: + assert self.intranode + num_rdma_bytes = 0 + num_qps_per_rank = 1 + + assert num_rdma_bytes is not None + assert num_qps_per_rank is not None + return dict(group=self.cpu_group, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=False, + num_qps_per_rank=num_qps_per_rank) + + def get_handle(self, kwargs): + + assert len(kwargs) == 0, ( + "DeepEPHTAll2AllManager expects no arguments. All the required " + "args are computed in the Manager itself.") + + import deep_ep + buffer_kwargs = self._make_all2all_kwargs() + logger.debug("DeepEP all2all args %s", buffer_kwargs) + handle: deep_ep.Buffer = self.handle_cache.get_or_create( + buffer_kwargs, deep_ep.Buffer) + # It is dangerous to set num sms outside this function. num_sms is not + # a part of the hash-key that identifies this object. If we are in a + # situation where we make objects with different num_sms, the hash key + # in get_or_create must be updated. + handle.set_num_sms(self.num_sms) + return handle + + +class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): + """ + All2All communication based on DeepEP Low-Latency kernels. 
+ """ + + def __init__(self, cpu_group): + super().__init__(cpu_group) + + def _make_all2all_kwargs( + self, + max_num_tokens_per_dp_rank: int, + token_hidden_size: int, + num_ep_ranks: int, + num_global_experts: int, + num_local_experts: int, + ) -> dict[Any, Any]: + """ + max_num_tokens_per_dp_rank : the maximum number of tokens a DP rank + can dispatch all the ranks must hold the same value. + token_hidden_size: the hidden dimension of each token. + num_ep_ranks: the number of EP group ranks. + num_global_experts: Number of experts in the model. + num_local_experts: Number of experts in an EP rank. + """ + import deep_ep + + # Defaults for internode and intranode are taken from DeepEP tests. + num_nvl_bytes = 1024 * 1024 * 1024 + num_qps_per_rank = num_local_experts + num_rdma_bytes = None + + if self.internode: + num_rdma_bytes = 1024 * 1024 * 1024 + else: + assert self.intranode + num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( + num_max_dispatch_tokens_per_rank=max_num_tokens_per_dp_rank, + hidden=token_hidden_size, + num_ranks=num_ep_ranks, + num_experts=num_global_experts) + + assert num_rdma_bytes is not None + return dict(group=self.cpu_group, + num_nvl_bytes=num_nvl_bytes, + num_rdma_bytes=num_rdma_bytes, + low_latency_mode=True, + num_qps_per_rank=num_qps_per_rank) + + def get_handle(self, kwargs): + """ + The kwargs for DeepEPLLAll2AllManager is dictated by + _make_all2all_kwargs. + """ + import deep_ep + buffer_kwargs = self._make_all2all_kwargs(**kwargs) + logger.debug("DeepEP all2all args %s", buffer_kwargs) + handle: deep_ep.Buffer = self.handle_cache.get_or_create( + buffer_kwargs, deep_ep.Buffer) + # It is dangerous to set num sms outside this function. num_sms is not + # a part of the hash-key that identifies this object. If we are in a + # situation where we make objects with different num_sms, the hash key + # in get_or_create must be updated. 
+ handle.set_num_sms(self.num_sms) + return handle diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index 0eebdf8736ce2..055d91690e676 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -67,6 +67,14 @@ class CudaCommunicator(DeviceCommunicatorBase): from .all2all import PPLXAll2AllManager self.all2all_manager = PPLXAll2AllManager(self.cpu_group) logger.info("Using PPLX all2all manager.") + elif all2all_backend == "deepep_high_throughput": + from .all2all import DeepEPHTAll2AllManager + self.all2all_manager = DeepEPHTAll2AllManager(self.cpu_group) + logger.info("Using DeepEP High-Throughput all2all manager.") + elif all2all_backend == "deepep_low_latency": + from .all2all import DeepEPLLAll2AllManager + self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) + logger.info("Using DeepEP Low-Latency all2all manager.") else: raise ValueError(f"Unknown all2all backend: {all2all_backend}") diff --git a/vllm/envs.py b/vllm/envs.py index 2e3d6eeb57e8a..08bf2dad44554 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -826,6 +826,8 @@ environment_variables: dict[str, Callable[[], Any]] = { # Available options: # - "naive": naive all2all implementation using all-reduce # - "pplx": use pplx kernels + # - "deepep_high_throughput", use deepep high-throughput kernels + # - "deepep_low_latency", use deepep low-latency kernels "VLLM_ALL2ALL_BACKEND": lambda: os.getenv("VLLM_ALL2ALL_BACKEND", "naive"), diff --git a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py index 331544d64ff83..97b4a49c064eb 100644 --- a/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/deep_gemm_moe.py @@ -12,8 +12,8 @@ from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import ( _moe_permute) from vllm.model_executor.layers.fused_moe.prepare_finalize import ( MoEPrepareAndFinalizeNoEP) -from vllm.model_executor.layers.fused_moe.utils import (_fp8_quantize, - _resize_cache) +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, per_token_group_quant_fp8) from vllm.utils import round_up logger = init_logger(__name__) @@ -34,10 +34,8 @@ def _valid_deep_gemm_shape(M: int, N: int, K: int): return align <= M and N % align == 0 and K % align == 0 -def _valid_deep_gemm(hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - expert_map: Optional[torch.Tensor] = None) -> bool: +def _valid_deep_gemm(hidden_states: torch.Tensor, w1: torch.Tensor, + w2: torch.Tensor) -> bool: """ Check if the given problem size is supported by the DeepGemm grouped gemm kernel. 
All of M, N, K and the quantization block_shape must be @@ -47,10 +45,6 @@ def _valid_deep_gemm(hidden_states: torch.Tensor, logger.debug("DeepGemm disabled: deep_gemm not available.") return False - if expert_map is not None: - logger.debug("DeepGemm disabled: expert map NYI.") - return False - M = hidden_states.size(0) _, K, N = w2.size() if not _valid_deep_gemm_shape(M, N, K): @@ -116,7 +110,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): a1q = hidden_states _, N, K = w1.size() - assert global_num_experts != -1 + if global_num_experts == -1: + global_num_experts = w1.size(0) + assert w2.size(1) == K a1q, a1q_scale, _, expert_ids, inv_perm = _moe_permute( @@ -128,6 +124,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): self.block_shape[0], ) + if expert_map is not None: + # DeepGemm (Grouped Contiguous) kernel needs a valid B index + # for all rows of A. To that effect, simply compute with + # the 0th weight matrix. + # Note that this relies on the fact that corresponding topk + # weights would be 0 during weight multiplication. + expert_ids = torch.where(expert_ids == -1, 0, expert_ids) + # Note: M_sum is different than the pre-permuted shape of a1q. M_sum = a1q.size(0) workspace1 = _resize_cache(workspace13, (M_sum, N)) @@ -140,9 +144,9 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): self.activation(activation, workspace2, workspace1.view(-1, N)) a2q_scale: Optional[torch.Tensor] = None - - a2q, a2q_scale = _fp8_quantize(workspace2, a2_scale, False, - self.block_shape) + a2q, a2q_scale = per_token_group_quant_fp8(workspace2, + self.block_shape[1], + column_major_scales=True) dg.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous( (a2q, a2q_scale), (w2, w2_scale), workspace3, expert_ids) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py new file mode 100644 index 0000000000000..48cf01638ade4 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py @@ -0,0 +1,236 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import deep_ep +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) + + +class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + """ + Prepare/Finalize using DeepEP High-Throughput kernels. + """ + + def __init__(self, + buffer: deep_ep.Buffer, + world_size: int, + rank: int, + dp_size: int, + rank_expert_offset: int, + quant_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None): + super().__init__() + self.buffer = buffer + self.world_size = world_size + self.rank = rank + self.dp_size = dp_size + self.rank_expert_offset = rank_expert_offset + self.quant_dtype = quant_dtype + self.block_shape = block_shape + # The dispatch function returns a handle that the combine function + # requires. We store the handle here so it is available to the + # combine function. 
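# The handle records the dispatch layout (which tokens were routed to which
# ranks and experts); the combine step needs it to send expert outputs back
# to the ranks the tokens came from.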
+ self.handle = None + + # From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164 + self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160] + + def max_num_tokens_per_rank(self) -> Optional[int]: + return None + + def topk_indices_dtype(self) -> Optional[torch.dtype]: + return torch.int64 + + def _get_dispatch_config(self) -> Optional[deep_ep.Config]: + if self.dp_size not in self.available_rank_configs: + return None + return deep_ep.Buffer.get_dispatch_config(self.dp_size) + + def _get_combine_config(self) -> Optional[deep_ep.Config]: + if self.dp_size not in self.available_rank_configs: + return None + return deep_ep.Buffer.get_combine_config(self.dp_size) + + def _do_quant(self, tokens: torch.Tensor, + token_scales: Optional[torch.Tensor], per_act_token: bool): + tokens, token_scales = moe_kernel_quantize_input( + tokens, token_scales, self.quant_dtype, per_act_token, + self.block_shape) + return tokens, token_scales + + def _do_dispatch(self, tokens: torch.Tensor, + token_scales: Optional[torch.Tensor], + rank_topk_ids: torch.Tensor, + rank_topk_weights: torch.Tensor, num_experts: int): + + has_scales = token_scales is not None + + (num_tokens_per_rank, num_tokens_per_rdma_rank, expert_num_tokens, + is_token_in_rank, event) = self.buffer.get_dispatch_layout( + topk_idx=rank_topk_ids, + num_experts=num_experts, + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False) + + token_data = tokens + if has_scales: + token_data = (tokens, token_scales) + + ( + token_data, expert_topk_ids, expert_topk_weights, + expert_num_tokens_per_expert_list, self.handle, event + ) = self.buffer.dispatch( + x=token_data, + handle=None, + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=expert_num_tokens, + topk_idx=rank_topk_ids, + topk_weights=rank_topk_weights, + # expert_alignment rounds the number of tokens per expert + # to this value. + expert_alignment=1, + config=self._get_dispatch_config(), + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False) + + if has_scales: + expert_x, expert_x_scale = token_data + else: + expert_x, expert_x_scale = token_data, None + + # The existing MOE kernels assume that all entries of topk_ids are + # valid. To that effect, set the -1s in expert_topk_ids to some expert + # outside this rank so the expert_map can remap it to -1 when safe. + # With Expert Parallel, the experts are divided amongst the rank + # sequentially. For rank 0, set it to num_experts - 1 and for all other + # ranks set it to 0 as we know that expert_map will have a -1 in those + # regions for those ranks. + # + # DeepEP's topk_ids output refers to the local experts directly. Offset + # the topk_ids to move it back to the global experts space so it aligns + # with existing vLLM interfaces. 
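# Worked example (illustrative): with 8 experts over 2 ranks, rank 1 has
# rank_expert_offset=4 and receives local topk ids in [0, 3]; adding the
# offset restores global ids [4, 7]. The -1 padding entries are instead
# pointed at expert 0, which rank 1's expert_map maps back to -1 downstream.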
+ expert_topk_ids = torch.where( + expert_topk_ids == -1, + num_experts - 1 if self.rank_expert_offset == 0 else 0, + expert_topk_ids + self.rank_expert_offset) + + return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, + expert_topk_weights) + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + rank_topk_weights: torch.Tensor, + rank_topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor]]: + + if apply_router_weight_on_input: + topk = rank_topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1") + a1 = a1 * rank_topk_weights.to(a1.dtype) + + # Check if there is a block_shape / or if we can infer the quantization + # schemes from the scales. + per_token_quant = None + if all([x is None for x in [self.block_shape, a1_scale, a2_scale] + ]) and self.quant_dtype is not None: + # Quantization required despite none of the inputs suggesting + # quantization. Fallback to per_token_dynamic quant. + per_token_quant = True + else: + per_token_quant = ((self.block_shape is not None) or + (a1_scale is not None and a1_scale.numel() != 1) + or (a2_scale is not None + and a2_scale.numel() != 1)) + + if per_token_quant: + a1q, a1q_scale = self._do_quant(a1, a1_scale, per_act_token=True) + (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, + expert_topk_weights) = self._do_dispatch( + tokens=a1q, + token_scales=a1q_scale, + rank_topk_ids=rank_topk_ids, + rank_topk_weights=rank_topk_weights, + num_experts=num_experts) + else: + # DeepEP kernels only support dispatching per-token-quant + # quantization. dispatch in bfloat16. + (expert_x, _, expert_num_tokens, expert_topk_ids, + expert_topk_weights) = self._do_dispatch( + tokens=a1, + token_scales=None, + rank_topk_ids=rank_topk_ids, + rank_topk_weights=rank_topk_weights, + num_experts=num_experts) + # quantize now + expert_x_scale = None + if expert_x.numel() != 0: + expert_x, expert_x_scale = self._do_quant(expert_x, + a1_scale, + per_act_token=False) + + return (expert_x, expert_x_scale, expert_num_tokens, expert_topk_ids, + expert_topk_weights) + + def _apply_weights_and_reduce(self, num_tokens: int, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + apply_router_weight_on_input: bool, + output_dtype: torch.dtype): + + if fused_expert_output.ndim == 2: + hidden_dim = fused_expert_output.size(-1) + fused_expert_output = fused_expert_output.view( + num_tokens, -1, hidden_dim) + + if not apply_router_weight_on_input: + # The DeepEP combine kernels don't do the topk weight + # multiplication. We multiply the weights locally. + fused_expert_output = fused_expert_output.to(torch.float32) + fused_expert_output = fused_expert_output * topk_weights.view( + fused_expert_output.size(0), -1, 1) + fused_expert_output = fused_expert_output.to(output_dtype) + + return fused_expert_output.sum(dim=1).to(output_dtype) + + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> None: + + assert self.handle is not None + + # fused_expert_output can have 0 tokens - This happens when none of the + # tokens from the all2all reach this EP rank. 
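# In that case the weight-apply/reduce step is skipped and the (possibly
# empty) tensor is passed straight to combine, which still has to run on
# every rank so the dispatch/combine exchange stays matched across the EP
# group.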
+ if fused_expert_output.numel() != 0: + fused_expert_output = self._apply_weights_and_reduce( + num_tokens=topk_ids.size(0), + fused_expert_output=fused_expert_output, + topk_weights=topk_weights, + apply_router_weight_on_input=apply_router_weight_on_input, + output_dtype=output.dtype) + + combined_x, _, event = self.buffer.combine( + x=fused_expert_output, + handle=self.handle, + topk_weights=None, + config=self._get_combine_config(), + previous_event=None, + async_finish=False, + allocate_on_comm_stream=False) + # Respect inplace outputs. + output.copy_(combined_x, non_blocking=True) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py new file mode 100644 index 0000000000000..b9d817a14d57e --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -0,0 +1,152 @@ +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional + +import deep_ep +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.model_executor.layers.fused_moe.utils import ( + moe_kernel_quantize_input) + +# DeepEP kernels quantize dispatch inputs in 128 element chunks. +DEEPEP_QUANT_BLOCK_SIZE = 128 + + +def dequant_fp8(expert_x_fp8: torch.Tensor, + expert_x_scales: torch.Tensor) -> torch.Tensor: + """ + Return dequantized tensor in fp32 + """ + # TODO (varun) : Optimize leverage num_tokens_per_expert counts + assert expert_x_fp8.is_contiguous() + expert_x_scales = expert_x_scales.contiguous() + num_experts = expert_x_fp8.size(0) + + expert_x_fp32 = expert_x_fp8.to(torch.float32).view( + num_experts, -1, DEEPEP_QUANT_BLOCK_SIZE) + expert_x_scales = expert_x_scales.view(num_experts, -1, 1) + return (expert_x_fp32 * expert_x_scales).view(expert_x_fp8.shape) + + +class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + """ + Prepare/Finalize using DeepEP low-latency kernels. + """ + + # DeepEP low-latency kernels are compiled only for certain + # specific hidden sizes. + SUPPORTED_HIDDEN_SIZES = [2560, 4096, 5120, 7168] + + def __init__(self, + buffer: deep_ep.Buffer, + world_size: int, + dp_size: int, + max_tokens_per_rank: int, + quant_dtype: Optional[torch.dtype] = None, + block_shape: Optional[list[int]] = None, + use_fp8_dispatch: bool = False): + super().__init__() + + self.buffer = buffer + self.world_size = world_size + self.dp_size = dp_size + self.quant_dtype = quant_dtype + self.block_shape = block_shape + self.max_tokens_per_rank = max_tokens_per_rank + self.use_fp8_dispatch = use_fp8_dispatch + # The dispatch function returns a handle that the combine function + # requires. We store the handle here so it is available to the + # combine function. 
+ self.handle = None + + def max_num_tokens_per_rank(self) -> Optional[int]: + return self.max_tokens_per_rank + + def topk_indices_dtype(self) -> Optional[torch.dtype]: + return torch.int64 + + def prepare( + self, + a1: torch.Tensor, + a1_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], + rank_topk_weights: torch.Tensor, + rank_topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor]]: + + hidden_size = a1.size(1) + assert hidden_size in self.SUPPORTED_HIDDEN_SIZES, \ + (f"Hidden Size {hidden_size} not in supported list of hidden sizes" + f"{self.SUPPORTED_HIDDEN_SIZES}") + + if self.use_fp8_dispatch: + assert hidden_size % 128 == 0, \ + "DeepEP kernels quantize the inputs in blocks of shape 128" + + # Quantize + per_act_token = a1_scale.numel() != 1 if a1_scale is not None else ( + a2_scale.numel() != 1 if a2_scale is not None else False) + assert not per_act_token, ( + "low_latency kernels don't support per-act-token quant") + + if apply_router_weight_on_input: + topk = rank_topk_ids.size(1) + # TODO: this only works for topK=1, will need to update for topK>1 + assert topk == 1, ( + "apply_router_weight_on_input is only implemented for topk=1") + a1 = a1 * rank_topk_weights.to(a1.dtype) + + # Dispatch + expert_x, expert_num_tokens, self.handle, event, hook = \ + self.buffer.low_latency_dispatch(a1, + rank_topk_ids, + self.max_tokens_per_rank, + num_experts, + use_fp8=self.use_fp8_dispatch, + async_finish=False, + return_recv_hook=False) + + if self.use_fp8_dispatch: + # TODO (varun) : In the case of dynamic quantization, we could + # probably skip the quant below and use the results directly. + # Although note that the deepep quant is per token 128 elements. + expert_x_fp8, expert_x_scales = expert_x + expert_x = dequant_fp8(expert_x_fp8, + expert_x_scales).to(dtype=a1.dtype) + + num_experts = expert_x.size(0) + hidden_dim = expert_x.size(-1) + + expert_x = expert_x.view((-1, expert_x.size(-1))) + expert_x, expert_x_scale = moe_kernel_quantize_input( + expert_x, a1_scale, self.quant_dtype, per_act_token, + self.block_shape) + expert_x = expert_x.view((num_experts, -1, hidden_dim)) + + return (expert_x, expert_x_scale, expert_num_tokens, None, None) + + def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + apply_router_weight_on_input: bool) -> None: + + assert self.handle is not None + + combine_topk_weights = topk_weights + if apply_router_weight_on_input: + # weights have already been applied. 
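# The low-latency combine reduces expert outputs weighted by the topk
# weights it is handed, so pass all-ones here to avoid applying the router
# weights a second time.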
+ combine_topk_weights = torch.ones_like(topk_weights) + + # TODO (varun) : Enable zero copy mode + _, event, hook = self.buffer.low_latency_combine( + fused_expert_output, + topk_ids, + combine_topk_weights, + self.handle, + async_finish=False, + zero_copy=False, + return_recv_hook=False, + out=output) diff --git a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py index 205a95e7ff1e4..7490a192df945 100644 --- a/vllm/model_executor/layers/fused_moe/fused_batched_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_batched_moe.py @@ -10,7 +10,8 @@ import triton.language as tl import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.model_executor.layers.fused_moe.fused_moe import ( get_config_dtype_str, try_get_optimal_moe_config) -from vllm.model_executor.layers.fused_moe.utils import _resize_cache +from vllm.model_executor.layers.fused_moe.utils import ( + _resize_cache, moe_kernel_quantize_input) @triton.jit @@ -397,6 +398,12 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): self.rank = rank self.max_num_tokens = max_num_tokens + def max_num_tokens_per_rank(self) -> Optional[int]: + return self.max_num_tokens + + def topk_indices_dtype(self) -> Optional[torch.dtype]: + return None + def prepare( self, a1: torch.Tensor, @@ -407,7 +414,8 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor]]: assert a1.dim() == 2 assert topk_ids.dim() == 2 assert topk_ids.size(0) == a1.size(0) @@ -450,7 +458,7 @@ class BatchedPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): first_expert, :rows, :] = a1[:topks.numel()][topks] tokens_per_expert[expert_id - first_expert] = rows - return b_a1, a1_scale, tokens_per_expert + return b_a1, a1_scale, tokens_per_expert, None, None def finalize( self, @@ -601,6 +609,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): use_int8_w8a8: bool = False, use_int8_w8a16: bool = False, use_int4_w4a16: bool = False, + per_channel_quant: bool = False, block_shape: Optional[list[int]] = None, world_size: int = 1, dp_size: int = 1, @@ -611,12 +620,15 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): self.use_int4_w4a16 = use_int4_w4a16 self.use_int8_w8a16 = use_int8_w8a16 self.block_shape = block_shape + self.per_channel_quant = per_channel_quant self.max_num_tokens = max_num_tokens - assert not use_int8_w8a8, "NYI" - assert not use_int4_w4a16, "NYI" self.world_size = world_size self.dp_size = dp_size + assert not use_int8_w8a8, "NYI" + assert not use_int4_w4a16, "NYI" + assert self.block_shape is None, "NYI" + def workspace_shapes( self, a: torch.Tensor, @@ -670,8 +682,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): torch.float32, torch.float16, torch.bfloat16, torch.float8_e4m3fn ] - # TODO: num_tokens -> max_num_tokens? 
- E, num_tokens, N, K, top_k_num = mk._moe_problem_size( + E, max_num_tokens, N, K, top_k_num = mk._moe_problem_size( hidden_states, w1, w2, topk_ids) assert w1.size(0) == E @@ -687,7 +698,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): w2.size(), top_k_num, config_dtype, - num_tokens, + max_num_tokens, block_shape=self.block_shape, ) @@ -706,10 +717,12 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): #print(f"shape: E={E}, M={num_tokens}, N={N}, K={K}, top_k={top_k_num}") # We can reuse the memory between these because by the time we need # cache3, we're done with cache1 - intermediate_cache1 = _resize_cache(workspace13, (E, num_tokens, N)) + intermediate_cache1 = _resize_cache(workspace13, + (E, max_num_tokens, N)) intermediate_cache2 = _resize_cache(workspace2, - (E, num_tokens, N // 2)) - intermediate_cache3 = _resize_cache(workspace13, (E, num_tokens, K)) + (E, max_num_tokens, N // 2)) + intermediate_cache3 = _resize_cache(workspace13, + (E, max_num_tokens, K)) # MM1 invoke_moe_batched_triton_kernel(A=hidden_states, @@ -731,15 +744,20 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): self.activation(activation, intermediate_cache2.view(-1, N // 2), intermediate_cache1.view(-1, N)) - #qintermediate_cache2 = intermediate_cache2 - a2q_scale = a2_scale - # TODO (varun) : support w8a8 - assert not self.use_fp8_w8a8 - #if self.use_fp8_w8a8: - # qintermediate_cache2, a2q_scale = _fp8_quantize( - # intermediate_cache2, a2_scale, self.block_shape) + ic2_hidden_size = intermediate_cache2.size(-1) + intermediate_cache2 = intermediate_cache2.view(-1, ic2_hidden_size) - invoke_moe_batched_triton_kernel(A=intermediate_cache2, + qintermediate_cache2, a2q_scale = moe_kernel_quantize_input( + A=intermediate_cache2, + A_scale=a2_scale, + qtype=torch.float8_e4m3fn if self.use_fp8_w8a8 else None, + per_channel_quant=self.per_channel_quant, + block_shape=self.block_shape) + + qintermediate_cache2 = qintermediate_cache2.view( + (E, -1, ic2_hidden_size)) + + invoke_moe_batched_triton_kernel(A=qintermediate_cache2, B=w2, C=intermediate_cache3, expert_num_tokens=expert_num_tokens, @@ -752,5 +770,4 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute): use_int4_w4a16=self.use_int4_w4a16, config=config, block_shape=self.block_shape) - return intermediate_cache3 diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index 883a48c984f21..de7a9a8d0b3bc 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -1164,7 +1164,7 @@ def fused_experts(hidden_states: torch.Tensor, # permute/unpermute ops are available. 
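In the BatchedTritonExperts hunk above, the intermediate activations produced by MM1 + SiLU are now flattened, quantized, and reshaped back to the batched (E, max_num_tokens, N // 2) layout before the second grouped GEMM. The following is a minimal, self-contained sketch of that flatten-quantize-reshape step; it assumes simple per-tensor fp8 scaling and is illustrative only, not vLLM's moe_kernel_quantize_input helper.

import torch

def quantize_intermediate(ic2: torch.Tensor):
    # ic2: (E, max_num_tokens, N // 2) activations from MM1 + activation.
    E, max_tokens, hidden = ic2.shape
    flat = ic2.reshape(-1, hidden)
    # Per-tensor fp8 scale (requires a PyTorch build with float8 support).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = flat.abs().max().clamp(min=1e-12) / fp8_max
    q = (flat / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
    # Restore the batched layout expected by the second grouped GEMM.
    return q.view(E, max_tokens, hidden), scale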
N = w1.shape[1] if (allow_deep_gemm and use_fp8_w8a8 and N > 512 - and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): + and _valid_deep_gemm(hidden_states, w1, w2)): assert apply_router_weight_on_input is False return deep_gemm_moe_fp8( hidden_states=hidden_states, diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 3ce4cbc2838e9..1812f3b6759a4 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -5,7 +5,7 @@ import importlib from abc import abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Callable, Optional +from typing import Callable, Optional, Union import torch import torch.nn.functional as F @@ -30,16 +30,19 @@ from vllm.platforms.interface import CpuArchEnum from vllm.utils import direct_register_custom_op has_pplx = importlib.util.find_spec("pplx_kernels") is not None +has_deepep = importlib.util.find_spec("deep_ep") is not None if current_platform.is_cuda_alike(): - from .fused_batched_moe import (BatchedPrepareAndFinalize, - BatchedTritonExperts) + from .fused_batched_moe import BatchedTritonExperts from .fused_moe import TritonExperts, fused_experts from .modular_kernel import (FusedMoEModularKernel, FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize) if has_pplx: from .pplx_prepare_finalize import PplxPrepareAndFinalize + if has_deepep: + from .deepep_ht_prepare_finalize import DeepEPHTPrepareAndFinalize + from .deepep_ll_prepare_finalize import DeepEPLLPrepareAndFinalize else: fused_experts = None # type: ignore FusedMoEPermuteExpertsUnpermute = None # type: ignore @@ -71,10 +74,24 @@ class FusedMoEParallelConfig: use_ep: bool # whether to use EP or not + @property + def use_all2all_kernels(self): + return self.dp_size > 1 and self.use_ep + @property def use_pplx_kernels(self): - return self.dp_size > 1 and self.use_ep and \ - envs.VLLM_ALL2ALL_BACKEND == "pplx" + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "pplx") + + @property + def use_deepep_ht_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput") + + @property + def use_deepep_ll_kernels(self): + return (self.use_all2all_kernels + and envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency") @staticmethod def make(tp_size_: int, dp_size_: int, @@ -231,6 +248,14 @@ class MoEConfig: def use_pplx_kernels(self): return self.moe_parallel_config.use_pplx_kernels + @property + def use_deepep_ht_kernels(self): + return self.moe_parallel_config.use_deepep_ht_kernels + + @property + def use_deepep_ll_kernels(self): + return self.moe_parallel_config.use_deepep_ll_kernels + class FusedMoeWeightScaleSupported(Enum): TENSOR = "tensor" @@ -252,7 +277,16 @@ class FusedMoEMethodBase(QuantizeMethodBase): all2all_manager = get_ep_group().device_communicator.all2all_manager assert all2all_manager is not None - prepare_finalize = None + quant_dtype = None + act_quant_block_size = None + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + if isinstance(quant_config, Fp8Config): + act_quant_block_size = quant_config.weight_block_size + quant_dtype = torch.float8_e4m3fn + + prepare_finalize: Optional[Union[PplxPrepareAndFinalize, + DeepEPHTPrepareAndFinalize, + DeepEPLLPrepareAndFinalize]] = None if moe.use_pplx_kernels: all_to_all_args = dict( max_num_tokens=moe.max_num_tokens, @@ -288,8 +322,49 @@ class FusedMoEMethodBase(QuantizeMethodBase): 
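The properties and the prepare/finalize construction above map VLLM_ALL2ALL_BACKEND onto a kernel flavor, and only when expert parallelism is combined with DP > 1. A minimal sketch of that mapping follows; the env values ("pplx", "deepep_high_throughput", "deepep_low_latency") come from this patch, while the function name and return labels are illustrative.

from typing import Optional

def all2all_flavor(dp_size: int, use_ep: bool, backend: str) -> Optional[str]:
    # All2all kernels only engage when EP is combined with DP > 1.
    if dp_size <= 1 or not use_ep:
        return None
    return {
        "pplx": "pplx",
        "deepep_high_throughput": "deepep_ht",
        "deepep_low_latency": "deepep_ll",
    }.get(backend)  # anything else keeps the naive dispatch/combine path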
dp_size=all2all_manager.tp_group.world_size, quant_dtype=moe.in_dtype, ) + elif moe.use_deepep_ht_kernels: + assert moe.dp_size == all2all_manager.dp_world_size + all_to_all_args = dict() + handle = all2all_manager.get_handle(all_to_all_args) + prepare_finalize = DeepEPHTPrepareAndFinalize( + handle, + world_size=all2all_manager.world_size, + rank=all2all_manager.rank, + dp_size=all2all_manager.dp_world_size, + rank_expert_offset=all2all_manager.rank * + moe.num_local_experts, + quant_dtype=quant_dtype, + block_shape=act_quant_block_size, + ) + + elif moe.use_deepep_ll_kernels: + assert moe.dp_size == all2all_manager.dp_world_size + + all_to_all_args = dict( + max_num_tokens_per_dp_rank=moe.max_num_tokens, + token_hidden_size=moe.hidden_dim, + num_ep_ranks=all2all_manager.world_size, + num_global_experts=moe.num_experts, + num_local_experts=moe.num_experts // + all2all_manager.world_size) + handle = all2all_manager.get_handle(all_to_all_args) + + # Note (varun): Whether to use FP8 dispatch or not needs some + # profiling. Turning it off for now. + prepare_finalize = DeepEPLLPrepareAndFinalize( + handle, + world_size=all2all_manager.world_size, + dp_size=all2all_manager.dp_world_size, + max_tokens_per_rank=moe.max_num_tokens, + quant_dtype=quant_dtype, + block_shape=act_quant_block_size, + use_fp8_dispatch=False, + ) + + self.topk_indices_dtype = None if prepare_finalize is not None: + self.topk_indices_dtype = prepare_finalize.topk_indices_dtype() experts = self.select_gemm_impl(prepare_finalize) self.fused_experts = FusedMoEModularKernel( prepare_finalize, @@ -297,7 +372,7 @@ class FusedMoEMethodBase(QuantizeMethodBase): ) def select_gemm_impl( - self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize] + self, prepare_finalize: FusedMoEPrepareAndFinalize ) -> FusedMoEPermuteExpertsUnpermute: # based on the all2all implementation, select the appropriate # gemm implementation @@ -334,6 +409,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): def __init__(self, moe: MoEConfig): super().__init__() self.fused_experts = fused_experts # type: ignore + self.topk_indices_dtype = None self.moe = moe self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() @@ -343,8 +419,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): else: self.rocm_aiter_fused_experts = None # type: ignore - def select_gemm_impl( - self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]): + def select_gemm_impl(self, prepare_finalize: FusedMoEPrepareAndFinalize): assert self.fused_experts == fused_experts @@ -353,11 +428,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): experts: Optional[FusedMoEPermuteExpertsUnpermute] = None - if isinstance(prepare_finalize, - (BatchedPrepareAndFinalize, PplxPrepareAndFinalize)): + use_batched_experts = prepare_finalize.max_num_tokens_per_rank( + ) is not None + if use_batched_experts: logger.debug("BatchedTritonExperts %s", self.moe) + assert self.moe.dp_size == all2all_manager.dp_world_size experts = BatchedTritonExperts( - max_num_tokens=MOE_DP_CHUNK_SIZE, + max_num_tokens=self.moe.max_num_tokens, world_size=all2all_manager.world_size, # dp_size actually means tp_size, bug in pplx kernels dp_size=all2all_manager.tp_group.world_size, @@ -366,6 +443,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): use_int8_w8a16=False, use_int4_w4a16=False, block_shape=None, + per_channel_quant=False, ) else: logger.debug("TritonExperts %s", self.moe) @@ -494,6 +572,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, 
CustomOp): apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: + topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -505,7 +584,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, - indices_type=torch.uint32 if self.moe.use_pplx_kernels else None) + indices_type=self.topk_indices_dtype) if self.rocm_aiter_moe_enabled: assert expert_map is None @@ -806,11 +885,8 @@ class FusedMoE(torch.nn.Module): # Note: get_quant_method will look at the layer's local_num_experts # for heuristic purposes, so it must be initialized first. quant_method: Optional[QuantizeMethodBase] = None - - if quant_config is None: - quant_method = UnquantizedFusedMoEMethod(moe) - else: - quant_method = quant_config.get_quant_method(self, prefix) + quant_method = (UnquantizedFusedMoEMethod(moe) if quant_config is None + else quant_config.get_quant_method(self, prefix)) assert quant_method is not None assert isinstance(quant_method, FusedMoEMethodBase) @@ -836,7 +912,8 @@ class FusedMoE(torch.nn.Module): # Chunked all2all staging tensor self.batched_hidden_states: Optional[torch.Tensor] = None self.batched_router_logits: Optional[torch.Tensor] = None - if self.moe_parallel_config.use_pplx_kernels: + if (self.moe_parallel_config.use_pplx_kernels + or self.moe_parallel_config.use_deepep_ll_kernels): act_dtype = vllm_config.model_config.dtype self.batched_hidden_states = torch.zeros( (MOE_DP_CHUNK_SIZE, self.hidden_size), @@ -880,6 +957,14 @@ class FusedMoE(torch.nn.Module): def use_pplx_kernels(self): return self.moe_parallel_config.use_pplx_kernels + @property + def use_deepep_ht_kernels(self): + return self.moe_parallel_config.use_deepep_ht_kernels + + @property + def use_deepep_ll_kernels(self): + return self.moe_parallel_config.use_deepep_ll_kernels + def _load_per_tensor_weight_scale(self, shard_id: str, param: torch.nn.Parameter, loaded_weight: torch.Tensor, @@ -1210,19 +1295,21 @@ class FusedMoE(torch.nn.Module): When just tensor-parallel is used, it is not required to reduce the shared_experts results immediately. Instead we reduce at the once at the end of the MoE op. (Refer to DeepSeekV2MoE module) - With EP and the pplx kernels - this is no longer viable as all + With EP and all2all kernels - this is no longer viable as all GPU ranks in DP, produce the complete set of hidden_states. Therefore it is required that we reduce the shared_experts output early. """ - return self.use_pplx_kernels + return (self.use_pplx_kernels or self.use_deepep_ht_kernels + or self.use_deepep_ll_kernels) def maybe_all_reduce_tensor_model_parallel( self, final_hidden_states: torch.Tensor): """ The pplx combine kernel reduces across GPU ranks by default. 
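+ The DeepEP high-throughput and low-latency combine kernels behave the
+ same way, so no further tensor-parallel all-reduce is applied for those
+ backends either.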
""" - if self.use_pplx_kernels: + if (self.use_pplx_kernels or self.use_deepep_ht_kernels + or self.use_deepep_ll_kernels): return final_hidden_states else: return tensor_model_parallel_all_reduce(final_hidden_states) @@ -1289,7 +1376,7 @@ class FusedMoE(torch.nn.Module): ctx = get_forward_context() max_tokens_across_dp = ctx.dp_metadata.max_tokens_across_dp_cpu - moe_dp_chunk_size_per_rank = MOE_DP_CHUNK_SIZE + moe_dp_chunk_size_per_rank = self.moe_config.max_num_tokens num_tokens = full_hidden_states.size(0) for chunk_start_ in range(0, max_tokens_across_dp, @@ -1310,12 +1397,17 @@ class FusedMoE(torch.nn.Module): def forward_impl(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): assert self.quant_method is not None - if self.moe_parallel_config.use_pplx_kernels: + if (self.moe_parallel_config.use_pplx_kernels + or self.moe_parallel_config.use_deepep_ll_kernels): return self.forward_impl_chunked(hidden_states, router_logits) - if self.dp_size > 1: + do_naive_dispatch_combine: bool = ( + self.dp_size > 1 + and not self.moe_parallel_config.use_deepep_ht_kernels) + if do_naive_dispatch_combine: hidden_states, router_logits = get_ep_group().dispatch( hidden_states, router_logits) + # Matrix multiply. final_hidden_states = self.quant_method.apply( layer=self, @@ -1335,12 +1427,12 @@ class FusedMoE(torch.nn.Module): apply_router_weight_on_input=self.apply_router_weight_on_input, ) - if self.dp_size > 1: + if do_naive_dispatch_combine: final_hidden_states = get_ep_group().combine(final_hidden_states) if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): - # Default set to False. (May have to add shared expert outputs.) - final_hidden_states = tensor_model_parallel_all_reduce( + # Default set to False. (May have to add shared expert outputs. + final_hidden_states = self.maybe_all_reduce_tensor_model_parallel( final_hidden_states) return final_hidden_states diff --git a/vllm/model_executor/layers/fused_moe/modular_kernel.py b/vllm/model_executor/layers/fused_moe/modular_kernel.py index 5e321c9b43af7..2c27d31eb6eb9 100644 --- a/vllm/model_executor/layers/fused_moe/modular_kernel.py +++ b/vllm/model_executor/layers/fused_moe/modular_kernel.py @@ -94,7 +94,8 @@ class FusedMoEPrepareAndFinalize(ABC): num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor]]: """ Perform any quantization (and/or) dispatching needed for this kernel. @@ -113,6 +114,10 @@ class FusedMoEPrepareAndFinalize(ABC): Returns a tuple of: - quantized + dispatched a. - quantized + dispatched a1_scales. + - Optional tensor as big as number of local experts that contains the + number of tokens assigned to each local expert. + - Optional dispatched expert topk IDs + - Optional dispatched expert topk weight """ raise NotImplementedError @@ -138,6 +143,27 @@ class FusedMoEPrepareAndFinalize(ABC): """ raise NotImplementedError + @abstractmethod + def topk_indices_dtype(self) -> Optional[torch.dtype]: + """ + The PrepareFinalize All2All implementations generally constrain the + dtype of the topk_ids they support. This function returns the + required topk indices dtype so it can be respected. + Return None if there are no such restrictions. 
+ """ + raise NotImplementedError + + @abstractmethod + def max_num_tokens_per_rank(self) -> Optional[int]: + """ + Some PrepareFinalize All2All implementations are batched. Meaning, + they can processes only as set of tokens at a time. This + function returns the batch size i.e the maximum number of tokens + the implementation can process at a time. + Return None if there are no such restrictions. + """ + raise NotImplementedError + class FusedMoEPermuteExpertsUnpermute(ABC): """ @@ -261,6 +287,61 @@ class FusedMoEModularKernel(torch.nn.Module): self.prepare_finalize = prepare_finalize self.fused_experts = fused_experts + def _do_fused_experts( + self, + a1: torch.Tensor, # input to forward fn + a1q: torch.Tensor, # output of prepare fn + w1: torch.Tensor, + w2: torch.Tensor, + topk_ids: torch.Tensor, + expert_num_tokens: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + w1_scale: Optional[torch.Tensor], + w2_scale: Optional[torch.Tensor], + w1_zp: Optional[torch.Tensor], + w2_zp: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor]) -> torch.Tensor: + + _, M, N, K, top_k = _moe_problem_size(a1q, w1, w2, topk_ids) + + # Use a1 here to decipher the correct workspace datatype + workspace13_shape, workspace2_shape, workspace_dtype = ( + self.fused_experts.workspace_shapes(a1, M, N, K, top_k, + global_num_experts)) + + # We can reuse the memory between cache1 and cache3 because by the time + # we need cache3, we're done with cache1 + workspace13 = torch.zeros(workspace13_shape, + device=a1.device, + dtype=workspace_dtype) + workspace2 = torch.zeros(workspace2_shape, + device=a1.device, + dtype=workspace_dtype) + + fused_out = self.fused_experts.apply( + a1q, + w1, + w2, + topk_ids, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=a1q_scale, + a2_scale=a2_scale, + workspace13=workspace13, + workspace2=workspace2, + expert_num_tokens=expert_num_tokens, + ) + + return fused_out + def forward( self, hidden_states: torch.Tensor, @@ -315,49 +396,48 @@ class FusedMoEModularKernel(torch.nn.Module): Returns: - torch.Tensor: The output tensor after applying the MoE layer. """ + a1 = hidden_states - E, M, N, K, top_k = _moe_problem_size(a1, w1, w2, topk_ids) - - if global_num_experts == -1: - global_num_experts = E - output = a1 if inplace else torch.zeros_like(a1) - workspace13_shape, workspace2_shape, workspace_dtype = ( - self.fused_experts.workspace_shapes(a1, M, N, K, top_k, - global_num_experts)) + if global_num_experts == -1: + global_num_experts = w1.size(0) - # We can reuse the memory between cache1 and cache3 because by the time - # we need cache3, we're done with cache1 - workspace13 = torch.zeros(workspace13_shape, - device=a1.device, - dtype=workspace_dtype) - workspace2 = torch.zeros(workspace2_shape, - device=a1.device, - dtype=workspace_dtype) + (a1q, a1q_scale, expert_num_tokens, _expert_topk_ids, + _expert_topk_weights) = self.prepare_finalize.prepare( + a1, a1_scale, a2_scale, topk_weights, topk_ids, + global_num_experts, expert_map, apply_router_weight_on_input) + # Maybe prepare gathered topk_ids and topk_weights from other EP ranks. 
+ topk_ids = topk_ids if _expert_topk_ids is None else _expert_topk_ids + topk_weights = (topk_weights if _expert_topk_weights is None else + _expert_topk_weights) - a1q, a1q_scale, expert_num_tokens = self.prepare_finalize.prepare( - a1, a1_scale, a2_scale, topk_weights, topk_ids, global_num_experts, - expert_map, apply_router_weight_on_input) - - fused_out = self.fused_experts.apply( - a1q, - w1, - w2, - topk_ids, - activation=activation, - global_num_experts=global_num_experts, - expert_map=expert_map, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_zp=w1_zp, - w2_zp=w2_zp, - a1q_scale=a1q_scale, - a2_scale=a2_scale, - workspace13=workspace13, - workspace2=workspace2, - expert_num_tokens=expert_num_tokens, - ) + fused_out = None + if a1q.numel() == 0: + # This happens when none of the tokens from the all2all reach this + # EP rank. Also, note that this is only relevant for CUDAGraph + # incompatible all2all kernels like the DeepEP high-throughput + # kernels. CUDAGraph compatible all2all kernels like the pplx + # kernels and the DeepEP low-latency kernels are always batched + # and can never run into the tensor.numel() == 0 case. + fused_out = torch.empty_like(a1q).to(dtype=a1.dtype) + else: + fused_out = self._do_fused_experts( + a1=a1, + a1q=a1q, + w1=w1, + w2=w2, + topk_ids=topk_ids, + expert_num_tokens=expert_num_tokens, + activation=activation, + global_num_experts=global_num_experts, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_zp=w1_zp, + w2_zp=w2_zp, + a1q_scale=a1q_scale, + a2_scale=a2_scale) self.prepare_finalize.finalize(output, fused_out, topk_weights, topk_ids, apply_router_weight_on_input) diff --git a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py index da78714341513..89481e5bd6b0a 100644 --- a/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +++ b/vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py @@ -25,7 +25,7 @@ def _moe_permute( """ top_k_num = curr_topk_ids.size(1) - tokens_in_chunk = curr_hidden_states.sizze(0) + tokens_in_chunk = curr_hidden_states.size(0) sorted_token_ids, expert_ids, num_tokens_post_padded = ( moe_align_block_size(curr_topk_ids, @@ -37,11 +37,12 @@ def _moe_permute( inv_perm: Optional[torch.Tensor] = None num_tokens = top_k_num * tokens_in_chunk - sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) expert_ids = torch.repeat_interleave(expert_ids, block_m, dim=0) inv_perm = torch.argsort(sorted_token_ids)[:num_tokens] # Permute according to sorted token ids. 
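+ # sorted_token_ids is padded by moe_align_block_size with an out-of-range
+ # sentinel; clamp only after inv_perm has been computed from the unclamped
+ # order, so the gather below stays in bounds without disturbing the argsort.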
+ sorted_token_ids = sorted_token_ids.clamp(max=num_tokens - 1) + curr_hidden_states = _fp8_perm(curr_hidden_states, sorted_token_ids // top_k_num) diff --git a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py index 8405603cf28a0..1170a16f3de2f 100644 --- a/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py @@ -32,6 +32,12 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): self.dp_size = dp_size self.quant_dtype = quant_dtype + def max_num_tokens_per_rank(self) -> Optional[int]: + return self.max_num_tokens + + def topk_indices_dtype(self) -> Optional[torch.dtype]: + return torch.uint32 + def prepare( self, a1: torch.Tensor, @@ -42,7 +48,8 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor]]: num_tokens = a1.size(0) # M hidden_dim = a1.size(-1) # K @@ -115,7 +122,7 @@ class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): bound_m=bound_m, ) - return expert_x, expert_x_scale, expert_num_tokens + return expert_x, expert_x_scale, expert_num_tokens, None, None def finalize( self, diff --git a/vllm/model_executor/layers/fused_moe/prepare_finalize.py b/vllm/model_executor/layers/fused_moe/prepare_finalize.py index 77a9686c93a63..9ed95e1de9fed 100644 --- a/vllm/model_executor/layers/fused_moe/prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/prepare_finalize.py @@ -24,6 +24,12 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): self.block_shape = block_shape self.quant_dtype = quant_dtype + def max_num_tokens_per_rank(self) -> Optional[int]: + return None + + def topk_indices_dtype(self) -> Optional[torch.dtype]: + return None + def prepare( self, a1: torch.Tensor, @@ -34,7 +40,9 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): num_experts: int, expert_map: Optional[torch.Tensor], apply_router_weight_on_input: bool = False, - ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: + ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], + Optional[torch.Tensor], Optional[torch.Tensor]]: + if apply_router_weight_on_input: topk = topk_ids.size(1) # TODO: this only works for topK=1, will need to update for topK>1 @@ -47,7 +55,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): self.per_channel_quant, self.block_shape) - return a1q, a1q_scale, None + return a1q, a1q_scale, None, None, None def finalize( self, diff --git a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py index 373e8ab396bc3..920931a93d3e8 100644 --- a/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +++ b/vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py @@ -29,9 +29,10 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): per_channel_quant=per_channel_quant, block_shape=block_shape, block_m=block_m) - self.deep_gemm_expert = DeepGemmExperts() self.allow_deep_gemm = allow_deep_gemm self.use_fp8_w8a8 = use_fp8_w8a8 + self.deep_gemm_expert = DeepGemmExperts( + ) if self.allow_deep_gemm else None def workspace_shapes( self, @@ -46,6 +47,7 @@ class 
TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): # workspaces so we can be pessimistic here and allocate for DeepGemm # even if we fall back to triton later, e.g. if expert maps are set. if self.allow_deep_gemm and _valid_deep_gemm_shape(M, N, K): + assert self.deep_gemm_expert is not None return self.deep_gemm_expert.workspace_shapes( a, M, N, K, topk, num_experts) else: @@ -73,7 +75,8 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute): ) -> torch.Tensor: N = w1.size(1) if (self.allow_deep_gemm and self.use_fp8_w8a8 and N > 512 - and _valid_deep_gemm(hidden_states, w1, w2, expert_map)): + and _valid_deep_gemm(hidden_states, w1, w2)): + assert self.deep_gemm_expert is not None return self.deep_gemm_expert.apply( hidden_states, w1, diff --git a/vllm/model_executor/layers/fused_moe/utils.py b/vllm/model_executor/layers/fused_moe/utils.py index c3a58478247a7..692482c2ea692 100644 --- a/vllm/model_executor/layers/fused_moe/utils.py +++ b/vllm/model_executor/layers/fused_moe/utils.py @@ -18,8 +18,8 @@ def _resize_cache(x: torch.Tensor, v: tuple[int, ...]) -> torch.Tensor: Shrink the given tensor and apply the given view to it. This is used to resize the intermediate fused_moe caches. """ - assert prod( - v) <= x.numel(), f"{prod(v)} <= {x.numel()}" # CUDAGRAPH unfriendly? + assert prod(v) <= x.numel( + ), f"{v} ({prod(v)}) <= {x.shape} ({x.numel()})" # CUDAGRAPH unfriendly? return x.flatten()[:prod(v)].view(*v) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index cea4d26a4c48f..2438ec30bdd2b 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -3,7 +3,7 @@ import functools import importlib.util -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, Union import torch import torch.nn.functional as F @@ -452,6 +452,9 @@ class Fp8MoEMethod(FusedMoEMethodBase): if envs.VLLM_USE_DEEP_GEMM: if not has_deep_gemm: logger.warning_once("Failed to import DeepGemm kernels.") + elif not self.block_quant: + logger.warning_once("Model is not block quantized. 
Not using " + " DeepGemm kernels") elif (current_platform.is_cuda() and current_platform.has_device_capability(90)): logger.info_once("Using DeepGemm kernels for Fp8MoEMethod.") @@ -460,8 +463,10 @@ class Fp8MoEMethod(FusedMoEMethodBase): logger.warning_once( "DeepGemm not supported on the current platform.") + self.topk_indices_dtype = None self.fused_experts = functools.partial( # type: ignore fused_experts, + use_fp8_w8a8=True, block_shape=self.quant_config.weight_block_size, allow_deep_gemm=self.allow_deep_gemm) @@ -765,18 +770,39 @@ class Fp8MoEMethod(FusedMoEMethodBase): del layer.w2_input_scale def select_gemm_impl(self, prepare_finalize): + + from vllm.model_executor.layers.fused_moe.fused_batched_moe import ( + BatchedTritonExperts) from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import ( TritonOrDeepGemmExperts) assert not self.use_marlin and not self.rocm_aiter_moe_enabled, ( "Marlin and ROCm AITER are not supported with all2all yet.") - experts = TritonOrDeepGemmExperts( - use_fp8_w8a8=True, - block_shape=self.quant_config.weight_block_size, - allow_deep_gemm=self.allow_deep_gemm, - ) + experts: Optional[Union[BatchedTritonExperts, + TritonOrDeepGemmExperts]] = None + max_num_tokens_per_rank = prepare_finalize.max_num_tokens_per_rank() + use_batched_experts = max_num_tokens_per_rank is not None + if use_batched_experts: + experts = BatchedTritonExperts( + max_num_tokens=max_num_tokens_per_rank, + world_size=prepare_finalize.world_size, + dp_size=prepare_finalize.dp_size, + use_fp8_w8a8=True, + use_int8_w8a8=False, + use_int8_w8a16=False, + use_int4_w4a16=False, + block_shape=None, + ) + else: + experts = TritonOrDeepGemmExperts( + use_fp8_w8a8=True, + block_shape=self.quant_config.weight_block_size, + allow_deep_gemm=self.allow_deep_gemm, + ) + + assert experts is not None return experts def apply( @@ -797,6 +823,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): apply_router_weight_on_input: bool = False, activation: str = "silu", ) -> torch.Tensor: + topk_weights, topk_ids = FusedMoE.select_experts( hidden_states=x, router_logits=router_logits, @@ -808,6 +835,7 @@ class Fp8MoEMethod(FusedMoEMethodBase): custom_routing_function=custom_routing_function, scoring_func=scoring_func, e_score_correction_bias=e_score_correction_bias, + indices_type=self.topk_indices_dtype, ) if self.rocm_aiter_moe_enabled: @@ -855,7 +883,6 @@ class Fp8MoEMethod(FusedMoEMethodBase): topk_ids=topk_ids, inplace=True, activation=activation, - use_fp8_w8a8=True, global_num_experts=global_num_experts, apply_router_weight_on_input=apply_router_weight_on_input, expert_map=expert_map, diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index e2d9424dee280..07ae470fabfb8 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -154,6 +154,21 @@ class CudaPlatformBase(Platform): logger.info( "Forcing kv cache block size to 64 for FlashMLA backend.") + if (envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" + and parallel_config.data_parallel_size > 1 + and vllm_config.compilation_config.use_cudagraph): + logger.info( + "Data Parallel: Forcing enforce eager to be True since DP " + "with DeepEP high-throughput kernels are not CUDA Graph " + "compatible. The DeepEP low-latency kernels are CUDA Graph " + "compatible. 
Set the all_to_all backend to deepep_low_latency " + "to use those kernels instead.") + vllm_config.compilation_config.use_cudagraph = False + vllm_config.model_config.enforce_eager = True + # TODO (varun): Turning this ON gives incorrect results for the + # Deepseek-V2-lite model. + vllm_config.compilation_config.use_inductor = False + @classmethod def get_current_memory_usage(cls, device: Optional[torch.types.Device] = None
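To make the platform check above concrete, here is a hedged usage sketch of selecting the CUDA-graph-compatible DeepEP low-latency backend when serving an MoE model with data parallelism. The env var name and value come from this patch; the LLM(...) arguments (model name, data_parallel_size, enable_expert_parallel) are illustrative and may differ across vLLM versions.

import os

# Choose the CUDA-graph-compatible DeepEP backend before vLLM is imported.
os.environ["VLLM_ALL2ALL_BACKEND"] = "deepep_low_latency"

from vllm import LLM

# Illustrative only: argument names and availability depend on the vLLM version.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    enable_expert_parallel=True,
    data_parallel_size=2,  # all2all kernels only engage with DP > 1 and EP
)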