From 72a5101c7a67f59767653fc6c9722c628745d6d1 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Wed, 24 Sep 2025 13:38:16 -0500 Subject: [PATCH] Support mnnvl all2allv from Flashinfer (#21003) Signed-off-by: Shu Wang Signed-off-by: Shu Wang. Signed-off-by: Tyler Michael Smith Signed-off-by: Tyler Michael Smith Co-authored-by: Tyler Michael Smith Co-authored-by: Tyler Michael Smith Signed-off-by: yewentao256 --- .../moe/modular_kernel_tools/mk_objects.py | 5 +- .../device_communicators/all2all.py | 125 ++++++++-- .../device_communicators/cuda_communicator.py | 5 + .../device_communicators/mnnvl_compat.py | 28 +++ vllm/envs.py | 7 +- .../fused_moe/flashinfer_cutlass_moe.py | 7 +- .../flashinfer_cutlass_prepare_finalize.py | 233 +++++++++++++++++- .../quantization/utils/flashinfer_fp4_moe.py | 6 +- .../quantization/utils/flashinfer_utils.py | 4 +- vllm/utils/flashinfer.py | 30 +++ 10 files changed, 410 insertions(+), 40 deletions(-) create mode 100644 vllm/distributed/device_communicators/mnnvl_compat.py diff --git a/tests/kernels/moe/modular_kernel_tools/mk_objects.py b/tests/kernels/moe/modular_kernel_tools/mk_objects.py index 7947391d03483..57a1da7b4b1a0 100644 --- a/tests/kernels/moe/modular_kernel_tools/mk_objects.py +++ b/tests/kernels/moe/modular_kernel_tools/mk_objects.py @@ -222,7 +222,8 @@ if (has_flashinfer_cutlass_fused_moe() from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501 FlashInferExperts) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - FlashInferCutlassMoEPrepareAndFinalize) + FlashInferCutlassMoEPrepareAndFinalize, + create_flashinfer_prepare_finalize) register_prepare_and_finalize( FlashInferCutlassMoEPrepareAndFinalize, @@ -373,7 +374,7 @@ def make_prepare_finalize( assert prepare_finalize is not None return prepare_finalize elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize: - return FlashInferCutlassMoEPrepareAndFinalize( + return create_flashinfer_prepare_finalize( use_dp=moe.moe_parallel_config.dp_size > 1) else: return MoEPrepareAndFinalizeNoEP() diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index ae18429f62518..661ed939608a0 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -10,9 +10,15 @@ from vllm.distributed import get_dp_group from vllm.forward_context import get_forward_context from vllm.logger import init_logger from vllm.utils import has_deep_ep, has_pplx +from vllm.utils.flashinfer import has_flashinfer_all2all from .base_device_communicator import All2AllManagerBase, Cache +if has_flashinfer_all2all(): + from flashinfer.comm import Mapping + from flashinfer.comm.mnnvl import MnnvlConfig + from flashinfer.comm.trtllm_alltoall import MnnvlMoe + logger = init_logger(__name__) @@ -47,24 +53,22 @@ class NaiveAll2AllManager(All2AllManagerBase): def dispatch(self, hidden_states: torch.Tensor, router_logits: torch.Tensor): - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu + sizes = get_forward_context( + ).dp_metadata.get_chunk_sizes_across_dp_rank() + hidden_states, router_logits = get_dp_group().all_gatherv( + [hidden_states, router_logits], + dim=0, + sizes=sizes, + ) - hidden_states = self.naive_multicast(hidden_states, - cu_tokens_across_dp_cpu) - router_logits = self.naive_multicast(router_logits, - cu_tokens_across_dp_cpu) return hidden_states, router_logits def combine(self, 
hidden_states: torch.Tensor) -> torch.Tensor: - cu_tokens_across_dp_cpu = get_forward_context( - ).dp_metadata.cu_tokens_across_dp_cpu - start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[ - self.dp_rank - 1] - end = cu_tokens_across_dp_cpu[self.dp_rank] - - all_hidden_states = self.dp_group.all_reduce(hidden_states) - hidden_states = all_hidden_states[start:end, :] + sizes = get_forward_context( + ).dp_metadata.get_chunk_sizes_across_dp_rank() + hidden_states = get_dp_group().reduce_scatterv(hidden_states, + dim=0, + sizes=sizes) return hidden_states def destroy(self): @@ -300,4 +304,95 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase): # DeepEP LL uses RDMA so no SMs are used for communication def max_sms_used(self) -> Optional[int]: - return 0 \ No newline at end of file + return 0 + + +class FlashInferAllToAllManager(All2AllManagerBase): + """ + All2All communication based on flashinfer kernels. + """ + + def __init__(self, cpu_group): + assert has_flashinfer_all2all( + ), "flashinfer all2all module not found. Please install/check flashinfer" # noqa + super().__init__(cpu_group) + logger.debug( + "Initialize for flashinfer All2All " + "rank=%d, world size=%d", self.rank, self.world_size) + self.initialized = False + self.alltoall_info = None + + def initialize( + self, + world_size: int, + rank: int, + gpus_per_node: int, + ): + """Initialize workspace""" + if self.initialized: + return + + self.cleanup() + logger.debug("making map: " + "rank=%d, world size=%d", rank, world_size) + self.mapping = Mapping( + world_size, + rank, + gpus_per_node, + tp_size=world_size, + ) + + from vllm.distributed.device_communicators.mnnvl_compat import ( + CustomCommunicator) + dp_config = MnnvlConfig( + comm_backend=CustomCommunicator(get_dp_group().cpu_group), + fabric_page_size=1 << 29, # 512MB + allocation_granularity=0 # Auto-detect + ) + + self.workspace_tensor = MnnvlMoe.get_moe_workspaces( + self.mapping, dp_config) + self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace( + self.mapping, dp_config) + + self.world_size = world_size + self.rank = rank + self.gpus_per_node = gpus_per_node + self.initialized = True + + logger.info("FlashInfer All2All initialized for rank %s, size %s", + rank, world_size) + + def ensure_alltoall_workspace_initialized(self): + """Ensure workspace is initialized""" + if not has_flashinfer_all2all(): + return False + + if self.world_size <= 1: + return False + + if not self.initialized: + self.initialize( + world_size=self.world_size, + rank=self.rank, + gpus_per_node=torch.cuda.device_count, + ) + return self.initialized + + def get_handle(self, kwargs): + return self + + def cleanup(self): + """Clean up workspace""" + if self.initialized and self.workspace_tensor is not None \ + and self.prepare_workspace_tensor is not None: + try: + del self.workspace_tensor + del self.prepare_workspace_tensor + except Exception as e: + logger.warning("Failed to cleanup FlashInfer workspace: %s", e) + finally: + self.workspace_tensor = None + self.prepare_workspace_tensor = None + self.mapping = None + self.initialized = False \ No newline at end of file diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index b20e79f577c35..bab372b722dbb 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -114,6 +114,11 @@ class CudaCommunicator(DeviceCommunicatorBase): from .all2all import 
DeepEPLLAll2AllManager self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group) logger.info("Using DeepEP Low-Latency all2all manager.") + elif all2all_backend == "flashinfer_all2allv": + from .all2all import FlashInferAllToAllManager + self.all2all_manager = FlashInferAllToAllManager( + self.cpu_group) + logger.info("Using Flashinfer all2allv manager.") else: raise ValueError(f"Unknown all2all backend: {all2all_backend}") diff --git a/vllm/distributed/device_communicators/mnnvl_compat.py b/vllm/distributed/device_communicators/mnnvl_compat.py new file mode 100644 index 0000000000000..80072c4fa643f --- /dev/null +++ b/vllm/distributed/device_communicators/mnnvl_compat.py @@ -0,0 +1,28 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import torch.distributed as dist +from flashinfer.comm.mnnvl import CommBackend as CommBackend + +from vllm.utils.flashinfer import has_flashinfer_all2all + +assert has_flashinfer_all2all(), "Flashinfer alltoallv module cannot be found" + + +class CustomCommunicator(CommBackend): + + def __init__(self, group): + self._group = group + + def Get_rank(self) -> int: + return self._group.rank() + + def Get_size(self) -> int: + return self._group.size() + + def allgather(self, data: int): + gathered = [None] * self.Get_size() + dist.all_gather_object(gathered, data, group=self._group) + return gathered + + def Split(self, color: int, key: int) -> 'CustomCommunicator': + return self diff --git a/vllm/envs.py b/vllm/envs.py index 0833949b527f5..4797d96bb899a 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -156,7 +156,8 @@ if TYPE_CHECKING: VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx", "deepep_high_throughput", "deepep_low_latency", - "allgather_reducescatter"] = \ + "allgather_reducescatter", + "flashinfer_all2allv"] = \ "allgather_reducescatter" VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840 VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1 @@ -1209,12 +1210,14 @@ environment_variables: dict[str, Callable[[], Any]] = { # - "pplx": use pplx kernels # - "deepep_high_throughput", use deepep high-throughput kernels # - "deepep_low_latency", use deepep low-latency kernels + # - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl "VLLM_ALL2ALL_BACKEND": env_with_choices("VLLM_ALL2ALL_BACKEND", "allgather_reducescatter", ["naive", "pplx", "deepep_high_throughput", "deepep_low_latency", - "allgather_reducescatter"]), + "allgather_reducescatter", + "flashinfer_all2allv"]), # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support. # Both require compute capability 10.0 or above. 
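# Note (not part of the patch): a minimal usage sketch for the backend value added
# above. It assumes a vLLM build containing this change and a FlashInfer install that
# ships flashinfer.comm / flashinfer.comm.mnnvl; the CLI flag names are illustrative,
# check `vllm serve --help` on your build. The new backend is selected purely through
# VLLM_ALL2ALL_BACKEND, after which CudaCommunicator constructs
# FlashInferAllToAllManager for expert-parallel MoE runs.
#
#   VLLM_ALL2ALL_BACKEND=flashinfer_all2allv \
#       vllm serve <model> --data-parallel-size 2 --enable-expert-parallel
#
# Programmatically, the same availability gate the patch relies on can be checked
# up front:
import os

from vllm.utils.flashinfer import has_flashinfer_all2all

if has_flashinfer_all2all():
    os.environ["VLLM_ALL2ALL_BACKEND"] = "flashinfer_all2allv"
else:
    # Fall back to the default; FlashInferAllToAllManager asserts on this
    # check at construction time.
    os.environ.setdefault("VLLM_ALL2ALL_BACKEND", "allgather_reducescatter")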
diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index a074da883088e..8700181d18feb 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -8,7 +8,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - FlashInferCutlassMoEPrepareAndFinalize) + create_flashinfer_prepare_finalize) from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceNoOP) from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe, @@ -108,7 +108,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute): of each tuple must be the number of tokens. """ aq_m, aq_n = aq.shape - workspace2 = () + workspace2 = (0, ) output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \ torch.float8_e4m3fn else (aq_m, aq_n) workspace_dtype = a.dtype @@ -192,9 +192,8 @@ def flashinfer_cutlass_moe_fp4( expert_map: Optional[torch.Tensor] = None, apply_router_weight_on_input: bool = False, ) -> torch.Tensor: - fused_experts = mk.FusedMoEModularKernel( - FlashInferCutlassMoEPrepareAndFinalize(use_dp=False), + create_flashinfer_prepare_finalize(use_dp=False), FlashInferExperts( out_dtype=hidden_states.dtype, quant_config=quant_config, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py index 8c7eff59f3cd1..6e127064d32d6 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py @@ -5,7 +5,9 @@ from typing import Optional import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk -from vllm.distributed import get_dp_group +from vllm.distributed import get_dp_group, get_ep_group +from vllm.distributed.device_communicators.base_device_communicator import ( + All2AllManagerBase) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.utils import ( @@ -18,6 +20,7 @@ def get_local_sizes(): class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): + """Base class for FlashInfer MoE prepare and finalize operations.""" def __init__( self, @@ -42,6 +45,39 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): def num_dispatchers(self) -> int: return self.num_dispatchers_ + def _apply_router_weight_on_input( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + ) -> None: + """Apply router weight on input if needed.""" + if apply_router_weight_on_input: + topk = topk_ids.size(1) + assert topk == 1, \ + "apply_router_weight_on_input is only implemented for topk=1" + a1.mul_(topk_weights.to(a1.dtype)) + + +class FlashInferAllToAllMoEPrepareAndFinalize( + FlashInferCutlassMoEPrepareAndFinalize): + """FlashInfer implementation using AllToAll communication.""" + + def __init__( + self, + use_dp: bool, + num_dispatchers: int = 1, + ): + super().__init__(use_dp, num_dispatchers) + self.alltoall_info = None + + # Initialize all2all_manager only 
for DP case + self.all2all_manager = None + if self.use_dp: + self.all2all_manager = get_ep_group( + ).device_communicator.all2all_manager + def prepare( self, a1: torch.Tensor, @@ -53,12 +89,84 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): quant_config: FusedMoEQuantConfig, ) -> mk.PrepareResultType: - if apply_router_weight_on_input: - topk = topk_ids.size(1) - # TODO: this only works for topK=1, will need to update for topK>1 - assert topk == 1, \ - "apply_router_weight_on_input is only implemented for topk=1" - a1.mul_(topk_weights.to(a1.dtype)) + self._apply_router_weight_on_input(a1, topk_weights, topk_ids, + apply_router_weight_on_input) + + if not self.use_dp: + # Non-DP case: standard quantization + a1q, a1q_scale = moe_kernel_quantize_input( + a1, + quant_config.a1_gscale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + is_fp4_scale_swizzled=not self.use_dp, + ) + else: + # DP case: use FlashInfer AllToAll + global_num_tokens_cpu = get_local_sizes() + top_k = topk_ids.size(1) + + (self.alltoall_info, topk_ids, topk_weights, a1q, + a1q_scale) = flashinfer_alltoall_dispatch( + self.all2all_manager, + global_num_tokens_cpu, + a1, + quant_config.a1_gscale, + topk_ids, + topk_weights, + top_k, + num_experts, + quant_config, + ) + + return a1q, a1q_scale, None, topk_ids, topk_weights + + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> None: + if self.use_dp: + top_k = topk_ids.size(1) + token_count = output.shape[0] + fused_expert_output = flashinfer_alltoall_combine( + self.all2all_manager, + fused_expert_output, + top_k=top_k, + token_count=token_count, + alltoall_info=self.alltoall_info, + ) + output.copy_(fused_expert_output) + + +class FlashInferAllGatherMoEPrepareAndFinalize( + FlashInferCutlassMoEPrepareAndFinalize): + + def __init__( + self, + use_dp: bool, + num_dispatchers: int = 1, + ): + super().__init__(use_dp, num_dispatchers) + + def prepare( + self, + a1: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + num_experts: int, + expert_map: Optional[torch.Tensor], + apply_router_weight_on_input: bool, + quant_config: FusedMoEQuantConfig, + ) -> mk.PrepareResultType: + + self._apply_router_weight_on_input(a1, topk_weights, topk_ids, + apply_router_weight_on_input) a1q, a1q_scale = moe_kernel_quantize_input( a1, @@ -66,7 +174,6 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): quant_config.quant_dtype, quant_config.per_act_token_quant, quant_config.block_shape, - # Swizzling after communication is_fp4_scale_swizzled=not self.use_dp, ) if self.use_dp: @@ -76,17 +183,117 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): dim=0, sizes=get_local_sizes(), ) - a1_m, a1_n = a1q.shape a1q_scale = nvfp4_block_scale_interleave(a1q_scale) return a1q, a1q_scale, None, topk_ids, topk_weights - def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor, - topk_weights: torch.Tensor, topk_ids: torch.Tensor, - apply_router_weight_on_input: bool, - weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None: + def finalize( + self, + output: torch.Tensor, + fused_expert_output: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + apply_router_weight_on_input: bool, + weight_and_reduce_impl: mk.TopKWeightAndReduce, + ) -> 
None: if self.use_dp: fused_expert_output = get_dp_group().reduce_scatterv( fused_expert_output, dim=0, sizes=get_local_sizes()) output.copy_(fused_expert_output) + + +def flashinfer_alltoall_dispatch( + all2all_manager: All2AllManagerBase, + global_num_tokens_cpu: list[int], + x: torch.Tensor, + gs: torch.Tensor, + topk_ids: torch.Tensor, + topk_weights: torch.Tensor, + top_k: int, + num_experts: int, + quant_config: FusedMoEQuantConfig, +): + from flashinfer.comm.trtllm_alltoall import MnnvlMoe + assert (all2all_manager.ensure_alltoall_workspace_initialized() + ), "FlashInfer AllToAll workspace not available" + + ep_rank = all2all_manager.rank + ep_size = all2all_manager.world_size + max_num_token = max(global_num_tokens_cpu + ) if global_num_tokens_cpu is not None else x.shape[0] + alltoall_info, topk_ids, topk_weights, _ = ( + MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather( + topk_ids, + topk_weights, + None, + all2all_manager.prepare_workspace, + max_num_token, + ep_rank, + ep_size, + num_experts, + num_experts, + top_k, + )) + + x, x_sf = moe_kernel_quantize_input( + x, + gs, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + is_fp4_scale_swizzled=False, # delay swizzle to after comm + ) + x = MnnvlMoe.mnnvl_moe_alltoallv( + x, + alltoall_info, + all2all_manager.workspace_tensor, + ep_rank, + ep_size, + ) + + x_sf = MnnvlMoe.mnnvl_moe_alltoallv( + x_sf, + alltoall_info, + all2all_manager.workspace_tensor, + ep_rank, + ep_size, + ) + x_sf = nvfp4_block_scale_interleave(x_sf) + return alltoall_info, topk_ids, topk_weights, x, x_sf + + +def flashinfer_alltoall_combine( + all2all_manager: All2AllManagerBase, + output: torch.Tensor, + top_k: int, + token_count: int, + alltoall_info, +): + from flashinfer.comm.trtllm_alltoall import MnnvlMoe + assert (all2all_manager.ensure_alltoall_workspace_initialized() + ), "FlashInfer AllToAll workspace not available" + return MnnvlMoe.mnnvl_moe_alltoallv_combine( + output, + alltoall_info, + all2all_manager.workspace_tensor, + ep_rank=all2all_manager.rank, + ep_size=all2all_manager.world_size, + top_k=top_k, + token_count=token_count, + ) + + +def create_flashinfer_prepare_finalize( + use_dp: bool, + use_nvfp4: bool = False, + enable_alltoallv: bool = False, +) -> FlashInferCutlassMoEPrepareAndFinalize: + """Factory function to create the appropriate FlashInfer implementation.""" + if use_nvfp4: + if enable_alltoallv: + return FlashInferAllToAllMoEPrepareAndFinalize(use_dp) + else: + return FlashInferAllGatherMoEPrepareAndFinalize(use_dp) + # Fp8 only supports AllGather + return FlashInferAllGatherMoEPrepareAndFinalize(use_dp) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index fabf855b36e68..a520302c62d9f 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - FlashInferCutlassMoEPrepareAndFinalize) + create_flashinfer_prepare_finalize) from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe @@ -51,7 +51,9 @@ def build_flashinfer_fp4_cutlass_moe_prepare_finalize( 
moe: FusedMoEConfig) -> mk.FusedMoEPrepareAndFinalize: """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" use_dp = moe.moe_parallel_config.dp_size > 1 - return FlashInferCutlassMoEPrepareAndFinalize(use_dp) + enable_alltoallv = envs.VLLM_ALL2ALL_BACKEND == "flashinfer_all2allv" + return create_flashinfer_prepare_finalize( + use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv) def select_nvfp4_gemm_impl( diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index aa66a42c588a7..b779a5355b679 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -13,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig, from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501 - FlashInferCutlassMoEPrepareAndFinalize) + create_flashinfer_prepare_finalize) logger = init_logger(__name__) @@ -173,7 +173,7 @@ def build_flashinfer_fp8_cutlass_moe_prepare_finalize( moe: Optional[FusedMoEConfig], ) -> mk.FusedMoEPrepareAndFinalize: """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel""" use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False - return FlashInferCutlassMoEPrepareAndFinalize(use_dp) + return create_flashinfer_prepare_finalize(use_dp) def select_cutlass_fp8_gemm_impl( diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index ebc7a56ff906a..734cd938792a9 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -97,6 +97,34 @@ autotune = _lazy_import_wrapper( fallback_fn=lambda *args, **kwargs: contextlib.nullcontext()) +@functools.cache +def has_flashinfer_comm() -> bool: + """Return ``True`` if FlashInfer comm module is available.""" + return has_flashinfer() and importlib.util.find_spec( + "flashinfer.comm") is not None + + +@functools.cache +def has_flashinfer_all2all() -> bool: + """Return ``True`` if FlashInfer mnnvl all2all is available.""" + if not has_flashinfer_comm(): + return False + + # Check if all required functions are available + required_functions = [ + ("flashinfer.comm", "Mapping"), + ("flashinfer.comm.mnnvl", "MnnvlMemory"), + ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"), + ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"), + ] + + for module_name, attr_name in required_functions: + mod = _get_submodule(module_name) + if not mod or not hasattr(mod, attr_name): + return False + return True + + @functools.cache def has_flashinfer_moe() -> bool: """Return ``True`` if FlashInfer MoE module is available.""" @@ -402,6 +430,8 @@ __all__ = [ "trtllm_fp4_block_scale_moe", "autotune", "has_flashinfer_moe", + "has_flashinfer_comm", + "has_flashinfer_all2all", "has_flashinfer_cutlass_fused_moe", "has_nvidia_artifactory", "supports_trtllm_attention",
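# Note (not part of the patch): a standalone rendering of the probe that
# has_flashinfer_all2all() performs, handy for checking an installation without
# importing the rest of vllm.utils.flashinfer. importlib.import_module stands in
# for the module-private helper used in the patch; the probed names are exactly
# the ones listed above.
import importlib


def probe_flashinfer_all2all() -> bool:
    required = [
        ("flashinfer.comm", "Mapping"),
        ("flashinfer.comm.mnnvl", "MnnvlMemory"),
        ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
        ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
    ]
    for module_name, attr_name in required:
        try:
            module = importlib.import_module(module_name)
        except ImportError:
            # FlashInfer missing entirely, or built without the comm extensions.
            return False
        if not hasattr(module, attr_name):
            return False
    return True


if __name__ == "__main__":
    print("flashinfer mnnvl all2all available:", probe_flashinfer_all2all())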
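# Note (not part of the patch): a self-contained sketch of the adapter pattern that
# mnnvl_compat.CustomCommunicator implements. The real class subclasses
# flashinfer.comm.mnnvl.CommBackend; a plain class is used here so the snippet runs
# with only PyTorch installed (single-rank gloo group), purely to illustrate the
# Get_rank / Get_size / allgather / Split contract that MnnvlConfig expects.
import torch.distributed as dist


class GroupBackendSketch:

    def __init__(self, group):
        self._group = group

    def Get_rank(self) -> int:
        return self._group.rank()

    def Get_size(self) -> int:
        return self._group.size()

    def allgather(self, data: int):
        # One entry per rank; this is how workspace metadata is exchanged.
        gathered = [None] * self.Get_size()
        dist.all_gather_object(gathered, data, group=self._group)
        return gathered

    def Split(self, color: int, key: int) -> "GroupBackendSketch":
        # Mirrors the patch: no sub-communicators are created.
        return self


if __name__ == "__main__":
    dist.init_process_group("gloo",
                            init_method="tcp://127.0.0.1:29555",
                            rank=0,
                            world_size=1)
    backend = GroupBackendSketch(dist.group.WORLD)
    print(backend.allgather(123))  # [123] with a single rank
    dist.destroy_process_group()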
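# Note (not part of the patch): how the build_* helpers above are expected to drive
# the new create_flashinfer_prepare_finalize factory. Guarded so the snippet is a
# no-op where vLLM/FlashInfer are not importable; use_dp=False is used because the
# DP/EP paths additionally require an initialized expert-parallel process group.
try:
    from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
        create_flashinfer_prepare_finalize)
except ImportError:
    create_flashinfer_prepare_finalize = None

if create_flashinfer_prepare_finalize is not None:
    # FP8 path (build_flashinfer_fp8_cutlass_moe_prepare_finalize): AllGather only.
    pf_fp8 = create_flashinfer_prepare_finalize(use_dp=False)
    # NVFP4 path with VLLM_ALL2ALL_BACKEND=flashinfer_all2allv: AllToAll variant.
    pf_fp4 = create_flashinfer_prepare_finalize(use_dp=False,
                                                use_nvfp4=True,
                                                enable_alltoallv=True)
    print(type(pf_fp8).__name__)  # FlashInferAllGatherMoEPrepareAndFinalize
    print(type(pf_fp4).__name__)  # FlashInferAllToAllMoEPrepareAndFinalize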