From 254a7f8fd613d6b6964abc277b73ca1f0b823cdb Mon Sep 17 00:00:00 2001 From: jiahanc <173873397+jiahanc@users.noreply.github.com> Date: Tue, 16 Dec 2025 13:01:48 -0800 Subject: [PATCH] [Perf] Do FP4 quant before All gather on flashinfer trtllmgen MOE (#30014) Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com> --- .../device_communicators/all2all.py | 29 ++++++++++--- .../base_device_communicator.py | 7 +++- .../device_communicators/cuda_communicator.py | 16 +++++--- vllm/distributed/parallel_state.py | 13 ++++-- .../layers/fused_moe/fused_moe_method_base.py | 12 ++++++ vllm/model_executor/layers/fused_moe/layer.py | 41 ++++++++++++++++++- .../layers/quantization/modelopt.py | 25 ++++++++++- .../quantization/utils/flashinfer_fp4_moe.py | 36 +++++++++------- vllm/utils/flashinfer.py | 17 ++++++++ 9 files changed, 165 insertions(+), 31 deletions(-) diff --git a/vllm/distributed/device_communicators/all2all.py b/vllm/distributed/device_communicators/all2all.py index c40dde26b741f..7a4e81cf967de 100644 --- a/vllm/distributed/device_communicators/all2all.py +++ b/vllm/distributed/device_communicators/all2all.py @@ -64,7 +64,12 @@ class NaiveAll2AllManager(All2AllManagerBase): hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: + if extra_tensors is not None: + raise NotImplementedError( + "extra_tensors is not supported for NaiveAll2AllManager" + ) sp_size = self.tp_group.world_size if is_sequence_parallel else 1 dp_metadata = get_forward_context().dp_metadata assert dp_metadata is not None @@ -76,6 +81,7 @@ class NaiveAll2AllManager(All2AllManagerBase): router_logits = self.naive_multicast( router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel ) + return hidden_states, router_logits def combine( @@ -113,7 +119,11 @@ class AgRsAll2AllManager(All2AllManagerBase): hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): """ Gather hidden_states and router_logits from all dp ranks. 
""" @@ -121,15 +131,22 @@ class AgRsAll2AllManager(All2AllManagerBase): assert dp_metadata is not None sizes = dp_metadata.get_chunk_sizes_across_dp_rank() assert sizes is not None - dist_group = get_ep_group() if is_sequence_parallel else get_dp_group() assert sizes[dist_group.rank_in_group] == hidden_states.shape[0] - hidden_states, router_logits = dist_group.all_gatherv( - [hidden_states, router_logits], + + tensors_to_gather = [hidden_states, router_logits] + if extra_tensors is not None: + tensors_to_gather.extend(extra_tensors) + + gathered_tensors = dist_group.all_gatherv( + tensors_to_gather, dim=0, sizes=sizes, ) - return hidden_states, router_logits + + if extra_tensors is not None: + return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:]) + return gathered_tensors[0], gathered_tensors[1] def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False @@ -204,6 +221,7 @@ class PPLXAll2AllManager(All2AllManagerBase): hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError @@ -251,6 +269,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase): hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, + extra_tensors: list[torch.Tensor] | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: raise NotImplementedError diff --git a/vllm/distributed/device_communicators/base_device_communicator.py b/vllm/distributed/device_communicators/base_device_communicator.py index 3a849da70e4cb..caeff54406b59 100644 --- a/vllm/distributed/device_communicators/base_device_communicator.py +++ b/vllm/distributed/device_communicators/base_device_communicator.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import threading +from typing import Any from weakref import WeakValueDictionary import torch @@ -68,7 +69,11 @@ class All2AllManagerBase: hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, - ): + extra_tensors: list[torch.Tensor] | None = None, + ) -> Any: + # Subclasses should either: + # - implement handling for extra_tensors, or + # - raise a clear error if extra_tensors is not supported. 
raise NotImplementedError def set_num_sms(self, num_sms: int): diff --git a/vllm/distributed/device_communicators/cuda_communicator.py b/vllm/distributed/device_communicators/cuda_communicator.py index cd9c267beb5b5..9542498c453ec 100644 --- a/vllm/distributed/device_communicators/cuda_communicator.py +++ b/vllm/distributed/device_communicators/cuda_communicator.py @@ -318,17 +318,23 @@ class CudaCommunicator(DeviceCommunicatorBase): return output_list - def dispatch( + def dispatch( # type: ignore[override] self, hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): assert self.all2all_manager is not None - hidden_states, router_logits = self.all2all_manager.dispatch( - hidden_states, router_logits, is_sequence_parallel + return self.all2all_manager.dispatch( + hidden_states, + router_logits, + is_sequence_parallel, + extra_tensors, # type: ignore[call-arg] ) - return hidden_states, router_logits def combine( self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index 338cb1f1814b5..f5ada5a009ec3 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -1007,10 +1007,17 @@ class GroupCoordinator: hidden_states: torch.Tensor, router_logits: torch.Tensor, is_sequence_parallel: bool = False, - ) -> tuple[torch.Tensor, torch.Tensor]: + extra_tensors: list[torch.Tensor] | None = None, + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]] + ): if self.device_communicator is not None: - return self.device_communicator.dispatch( - hidden_states, router_logits, is_sequence_parallel + return self.device_communicator.dispatch( # type: ignore[call-arg] + hidden_states, + router_logits, + is_sequence_parallel, + extra_tensors, ) else: return hidden_states, router_logits diff --git a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py index 8c9d8a2777d58..a46e3972ed8e3 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe_method_base.py @@ -71,6 +71,18 @@ class FusedMoEMethodBase(QuantizeMethodBase): "implementation based on the prepare_finalize" ) + def prepare_dp_allgather_tensor( + self, + layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821 + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + """Hook to prepare tensors and extra tensors for DP allgather + EP dispatch.""" + raise NotImplementedError( + "Method 'prepare_dp_allgather_tensor' is not implemented in " + f"{self.__class__.__name__}." 
+ ) + @abstractmethod def get_fused_moe_quant_config( self, layer: torch.nn.Module diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index cc3afade709d9..b39ce415a0f83 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -44,6 +44,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( is_flashinfer_supporting_global_sf, ) from vllm.platforms import current_platform +from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe from vllm.utils.math_utils import cdiv, round_up from vllm.utils.torch_utils import ( aux_stream, @@ -1933,10 +1934,46 @@ class FusedMoE(CustomOp): ) with sp_ctx: + extra_tensors = None if do_naive_dispatch_combine: - hidden_states_combined, router_logits = get_ep_group().dispatch( - hidden_states, router_logits, self.is_sequence_parallel + # Avoid circular import + from vllm.model_executor.layers.quantization.modelopt import ( + ModelOptNvFp4FusedMoE, ) + + post_quant_allgather = ( + has_flashinfer_trtllm_fused_moe() + and self.quant_method is not None + and self.dp_size > 1 + and self.use_ep + and isinstance(self.quant_method, ModelOptNvFp4FusedMoE) + ) + if post_quant_allgather: + hidden_states_to_dispatch, extra_tensors = ( + self.quant_method.prepare_dp_allgather_tensor( + self, hidden_states, router_logits + ) + ) + else: + hidden_states_to_dispatch = hidden_states + + dispatch_res = get_ep_group().dispatch( + hidden_states_to_dispatch, + router_logits, + self.is_sequence_parallel, + extra_tensors=extra_tensors, + ) + if extra_tensors is not None: + hidden_states_combined, router_logits, extra_tensors_combined = ( + dispatch_res + ) + hidden_states_combined = ( + hidden_states_combined, + extra_tensors_combined[0], + ) + else: + hidden_states_combined, router_logits = dispatch_res + # Run shared experts before matrix multiply. # because matrix multiply maybe modify the hidden_states. 
if has_separate_shared_experts and not use_shared_experts_stream: diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index f71854e6b63c5..d5d7e7bfaae73 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1522,6 +1522,24 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): w2_blockscale_swizzled, requires_grad=False ) + def prepare_dp_allgather_tensor( + self, + layer: FusedMoE, + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + ) -> tuple[torch.Tensor, list[torch.Tensor]]: + """Optionally prepare extra tensors to carry through DP allgather/EP.""" + import flashinfer + + a1_gscale = layer.w13_input_scale_quant + hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize( + hidden_states, + a1_gscale, + is_sf_swizzled_layout=False, + ) + extra_tensors: list[torch.Tensor] = [hidden_states_sf] + return hidden_states_fp4, extra_tensors + def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: @@ -1576,8 +1594,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase): e_score_correction_bias=layer.e_score_correction_bias, ) + # hidden_states is only used by select_experts to extract metadata + if isinstance(x, tuple): + x_routing, _ = x + else: + x_routing = x topk_weights, topk_ids, _ = layer.select_experts( - hidden_states=x, + hidden_states=x_routing, router_logits=router_logits, ) diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index 76bce8a8d98d6..1d410316d6299 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -238,7 +238,7 @@ def prepare_static_weights_for_trtllm_fp4_moe( def flashinfer_trtllm_fp4_moe( layer: torch.nn.Module, - x: torch.Tensor, + x: torch.Tensor | tuple[torch.Tensor, torch.Tensor], router_logits: torch.Tensor, top_k: int, global_num_experts: int, @@ -269,12 +269,16 @@ def flashinfer_trtllm_fp4_moe( from vllm.model_executor.models.llama4 import Llama4MoE # Quantize input to FP4 - a1_gscale = layer.w13_input_scale_quant - (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( - x, - a1_gscale, - is_sf_swizzled_layout=False, - ) + if isinstance(x, tuple): + hidden_states_fp4, hidden_states_scale_linear_fp4 = x + else: + # x is not quantized yet; quantize hidden_states to FP4 here + a1_gscale = layer.w13_input_scale_quant + (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( + x, + a1_gscale, + is_sf_swizzled_layout=False, + ) # Determine routing method type use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function @@ -360,13 +364,17 @@ def flashinfer_trtllm_fp4_routed_moe( torch.bfloat16 ).view(torch.int16) - # Quantize input to FP4 - a1_gscale = layer.w13_input_scale_quant - (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( - x, - a1_gscale, - is_sf_swizzled_layout=False, - ) + if isinstance(x, tuple): + # hidden_states is already quantized (FP4 data and scale factors) + hidden_states_fp4, hidden_states_scale_linear_fp4 = x + else: + # Quantize input to FP4 + a1_gscale = layer.w13_input_scale_quant + (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize( + x, + a1_gscale, + is_sf_swizzled_layout=False, + ) # Call TRT-LLM FP4 block-scale MoE kernel out =
flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe( diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 5019b771f4a14..1c2710be3173b 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -184,6 +184,23 @@ def has_flashinfer_cutedsl() -> bool: ) +@functools.cache +def has_flashinfer_trtllm_fused_moe() -> bool: + """Return `True` if FlashInfer TRTLLM fused MoE is available.""" + if not has_flashinfer_moe(): + return False + required_functions = [ + ("flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"), + ("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"), + ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"), + ] + for module_name, attr_name in required_functions: + mod = _get_submodule(module_name) + if not mod or not hasattr(mod, attr_name): + return False + return True + + @functools.cache def has_flashinfer_cutlass_fused_moe() -> bool: """Return `True` if FlashInfer CUTLASS fused MoE is available."""
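As a usage illustration of the change: the sketch below uses hypothetical helper names and assumes a FusedMoE layer exposing w13_input_scale_quant, an EP group with the new dispatch(..., extra_tensors=...) signature from this patch, and FlashInfer's fp4_quantize. It quantizes activations to FP4 before the DP all-gather and carries the scale factors through as extra_tensors, mirroring the layer.py, modelopt.py, and flashinfer_fp4_moe.py changes above; it is a minimal sketch of the flow, not vLLM API documentation.

import flashinfer
import torch


def dispatch_with_post_quant_allgather(
    layer: torch.nn.Module,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    ep_group,
    is_sequence_parallel: bool = False,
):
    # Quantize locally first (as ModelOptNvFp4FusedMoE.prepare_dp_allgather_tensor
    # does), so the all-gather moves FP4 payloads plus scale factors instead of
    # full-precision activations.
    hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
        hidden_states,
        layer.w13_input_scale_quant,
        is_sf_swizzled_layout=False,
    )

    # The scale factors ride through the same all-gather via extra_tensors.
    fp4_gathered, logits_gathered, extras_gathered = ep_group.dispatch(
        hidden_states_fp4,
        router_logits,
        is_sequence_parallel,
        extra_tensors=[hidden_states_sf],
    )

    # Downstream, flashinfer_trtllm_fp4_moe / flashinfer_trtllm_fp4_routed_moe
    # accept x as a (fp4, scale) tuple and skip their own fp4_quantize call.
    return (fp4_gathered, extras_gathered[0]), logits_gathered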