[Perf] Do FP4 quant before All gather on flashinfer trtllmgen MOE (#30014)

Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
jiahanc 2025-12-16 13:01:48 -08:00 committed by GitHub
parent f21f5ea38c
commit 254a7f8fd6
9 changed files with 165 additions and 31 deletions
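
The gist of the change: in the data-parallel MoE path, each rank all-gathers its activations before expert dispatch. Quantizing to NVFP4 first means the gather moves packed 4-bit values plus their block scale factors (carried as an extra tensor) instead of bf16 activations. A rough payload sketch, assuming bf16 inputs and one 8-bit scale per 16-element block (the block size is an assumption, not taken from the diff):

import torch  # noqa: F401  (only to match the project's usual imports)

# Rough per-rank payload estimate for the all-gather.
# Assumptions (not from the diff): bf16 input, NVFP4 = 4 bits/element packed,
# plus one 8-bit scale per 16-element block.
def allgather_bytes(num_tokens: int, hidden_size: int) -> dict[str, int]:
    bf16 = num_tokens * hidden_size * 2           # 16 bits per element
    fp4 = num_tokens * hidden_size // 2           # 4 bits per element, packed
    scales = num_tokens * (hidden_size // 16)     # 1 FP8 scale per 16 elements
    return {"bf16": bf16, "fp4_plus_scales": fp4 + scales}

print(allgather_bytes(num_tokens=4096, hidden_size=7168))
# {'bf16': 58720256, 'fp4_plus_scales': 16515072}

Roughly 16 bits/element shrinks to about 4.5 bits/element, so the gathered volume drops by ~3.5x.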

View File

@@ -64,7 +64,12 @@ class NaiveAll2AllManager(All2AllManagerBase):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
if extra_tensors is not None:
raise NotImplementedError(
"extra_tensors is not supported for NaiveAll2AllManager"
)
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
dp_metadata = get_forward_context().dp_metadata
assert dp_metadata is not None
@@ -76,6 +81,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
router_logits = self.naive_multicast(
router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
)
return hidden_states, router_logits
def combine(
@@ -113,7 +119,11 @@ class AgRsAll2AllManager(All2AllManagerBase):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
"""
Gather hidden_states and router_logits from all dp ranks.
"""
@@ -121,15 +131,22 @@ class AgRsAll2AllManager(All2AllManagerBase):
assert dp_metadata is not None
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
assert sizes is not None
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
hidden_states, router_logits = dist_group.all_gatherv(
[hidden_states, router_logits],
tensors_to_gather = [hidden_states, router_logits]
if extra_tensors is not None:
tensors_to_gather.extend(extra_tensors)
gathered_tensors = dist_group.all_gatherv(
tensors_to_gather,
dim=0,
sizes=sizes,
)
return hidden_states, router_logits
if extra_tensors is not None:
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
return gathered_tensors[0], gathered_tensors[1]
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
@@ -204,6 +221,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
@@ -251,6 +269,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
extra_tensors: list[torch.Tensor] | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError
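
The extended dispatch contract in the allgather-based manager: extra tensors ride along in the same all_gatherv call, and the return value grows from a 2-tuple to a 3-tuple. A single-process stand-in of that contract (torch.cat plays the role of all_gatherv; shapes are illustrative):

import torch

# Stand-in for the extended dispatch: concatenation over per-"rank" chunks
# replaces the real all_gatherv across the DP/EP group.
def fake_dispatch(chunks_hidden, chunks_logits, chunks_extra=None):
    gathered = [torch.cat(chunks_hidden, dim=0), torch.cat(chunks_logits, dim=0)]
    if chunks_extra is not None:
        gathered.extend(torch.cat(c, dim=0) for c in chunks_extra)
        return gathered[0], gathered[1], gathered[2:]
    return gathered[0], gathered[1]

h = [torch.randn(3, 8), torch.randn(5, 8)]      # per-rank hidden states
r = [torch.randn(3, 4), torch.randn(5, 4)]      # per-rank router logits
s = [[torch.randn(3, 2), torch.randn(5, 2)]]    # one extra tensor (e.g. FP4 scales)

hidden, logits, extras = fake_dispatch(h, r, s)
assert hidden.shape[0] == 8 and len(extras) == 1 and extras[0].shape == (8, 2)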

View File

@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import threading
from typing import Any
from weakref import WeakValueDictionary
import torch
@@ -68,7 +69,11 @@ class All2AllManagerBase:
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
):
extra_tensors: list[torch.Tensor] | None = None,
) -> Any:
# Subclasses should either:
# - implement handling for extra_tensors, or
# - raise a clear error if extra_tensors is not supported.
raise NotImplementedError
def set_num_sms(self, num_sms: int):

View File

@@ -318,17 +318,23 @@ class CudaCommunicator(DeviceCommunicatorBase):
return output_list
def dispatch(
def dispatch( # type: ignore[override]
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
assert self.all2all_manager is not None
hidden_states, router_logits = self.all2all_manager.dispatch(
hidden_states, router_logits, is_sequence_parallel
return self.all2all_manager.dispatch(
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors, # type: ignore[call-arg]
)
return hidden_states, router_logits
def combine(
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False

View File

@@ -1007,10 +1007,17 @@ class GroupCoordinator:
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_sequence_parallel: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
extra_tensors: list[torch.Tensor] | None = None,
) -> (
tuple[torch.Tensor, torch.Tensor]
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
):
if self.device_communicator is not None:
return self.device_communicator.dispatch(
hidden_states, router_logits, is_sequence_parallel
return self.device_communicator.dispatch( # type: ignore[call-arg]
hidden_states,
router_logits,
is_sequence_parallel,
extra_tensors,
)
else:
return hidden_states, router_logits

View File

@@ -71,6 +71,18 @@ class FusedMoEMethodBase(QuantizeMethodBase):
"implementation based on the prepare_finalize"
)
def prepare_dp_allgather_tensor(
self,
layer: "FusedMoE", # type: ignore[name-defined] # noqa: F821
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
"""Hook to prepare tensors and extra tensors for DP allgather + EP dispatch."""
raise NotImplementedError(
"Method 'prepare_dp_allgather_tensor' is not implemented in "
f"{self.__class__.__name__}."
)
@abstractmethod
def get_fused_moe_quant_config(
self, layer: torch.nn.Module

View File

@@ -44,6 +44,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
is_flashinfer_supporting_global_sf,
)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
from vllm.utils.math_utils import cdiv, round_up
from vllm.utils.torch_utils import (
aux_stream,
@@ -1933,10 +1934,46 @@ class FusedMoE(CustomOp):
)
with sp_ctx:
extra_tensors = None
if do_naive_dispatch_combine:
hidden_states_combined, router_logits = get_ep_group().dispatch(
hidden_states, router_logits, self.is_sequence_parallel
# Avoid circular import
from vllm.model_executor.layers.quantization.modelopt import (
ModelOptNvFp4FusedMoE,
)
post_quant_allgather = (
has_flashinfer_trtllm_fused_moe()
and self.quant_method is not None
and self.dp_size > 1
and self.use_ep
and isinstance(self.quant_method, ModelOptNvFp4FusedMoE)
)
if post_quant_allgather:
hidden_states_to_dispatch, extra_tensors = (
self.quant_method.prepare_dp_allgather_tensor(
self, hidden_states, router_logits
)
)
else:
hidden_states_to_dispatch = hidden_states
dispatch_res = get_ep_group().dispatch(
hidden_states_to_dispatch,
router_logits,
self.is_sequence_parallel,
extra_tensors=extra_tensors,
)
if extra_tensors is not None:
hidden_states_combined, router_logits, extra_tensors_combined = (
dispatch_res
)
hidden_states_combined = (
hidden_states_combined,
extra_tensors_combined[0],
)
else:
hidden_states_combined, router_logits = dispatch_res
# Run shared experts before the matrix multiply,
# because the matrix multiply may modify the hidden_states.
if has_separate_shared_experts and not use_shared_experts_stream:
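
Condensed, the new forward path quantizes before dispatch, sends the scale factors as extra tensors, and repacks the gathered result as a (packed_fp4, block_scales) tuple so downstream code can recognize pre-quantized input with an isinstance check. A toy sketch of that control flow (quantize() and dispatch() below are stand-ins, not the vLLM/flashinfer calls):

import torch

# Stand-ins only: quantize() and dispatch() mimic the shapes of
# flashinfer.fp4_quantize and get_ep_group().dispatch, nothing more.
def quantize(h):
    packed = torch.empty(h.shape[0], h.shape[1] // 2, dtype=torch.uint8)
    scales = torch.empty(h.shape[0], h.shape[1] // 16, dtype=torch.uint8)
    return packed, scales

def dispatch(h, logits, extra_tensors=None):
    # A real dispatch all-gathers across DP ranks; here it is the identity.
    if extra_tensors is not None:
        return h, logits, extra_tensors
    return h, logits

def moe_forward_dispatch(hidden, logits, post_quant_allgather):
    extra = None
    to_dispatch = hidden
    if post_quant_allgather:
        to_dispatch, scales = quantize(hidden)
        extra = [scales]
    res = dispatch(to_dispatch, logits, extra_tensors=extra)
    if extra is not None:
        h, logits, extras = res
        return (h, extras[0]), logits  # tuple marks "already FP4-quantized"
    return res

hidden = torch.randn(8, 128, dtype=torch.bfloat16)
logits = torch.randn(8, 16)
x, _ = moe_forward_dispatch(hidden, logits, post_quant_allgather=True)
assert isinstance(x, tuple)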

View File

@@ -1522,6 +1522,24 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
w2_blockscale_swizzled, requires_grad=False
)
def prepare_dp_allgather_tensor(
self,
layer: FusedMoE,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
"""Optionally prepare extra tensors to carry through DP allgather/EP."""
import flashinfer
a1_gscale = layer.w13_input_scale_quant
hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
hidden_states,
a1_gscale,
is_sf_swizzled_layout=False,
)
extra_tensors: list[torch.Tensor] = [hidden_states_sf]
return hidden_states_fp4, extra_tensors
def get_fused_moe_quant_config(
self, layer: torch.nn.Module
) -> FusedMoEQuantConfig | None:
@@ -1576,8 +1594,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
e_score_correction_bias=layer.e_score_correction_bias,
)
# hidden_states in select_experts is only used to extract metadata
if isinstance(x, tuple):
x_routing, _ = x
else:
x_routing = x
topk_weights, topk_ids, _ = layer.select_experts(
hidden_states=x,
hidden_states=x_routing,
router_logits=router_logits,
)

View File

@@ -238,7 +238,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
def flashinfer_trtllm_fp4_moe(
layer: torch.nn.Module,
x: torch.Tensor,
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
router_logits: torch.Tensor,
top_k: int,
global_num_experts: int,
@@ -269,12 +269,16 @@ def flashinfer_trtllm_fp4_moe(
from vllm.model_executor.models.llama4 import Llama4MoE
# Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
if isinstance(x, tuple):
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# hidden_states is not yet quantized; quantize input to FP4 here
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
# Determine routing method type
use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
@@ -360,13 +364,17 @@ def flashinfer_trtllm_fp4_routed_moe(
torch.bfloat16
).view(torch.int16)
# Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
if isinstance(x, tuple):
# hidden_states is already quantized
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
else:
# Quantize input to FP4
a1_gscale = layer.w13_input_scale_quant
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
x,
a1_gscale,
is_sf_swizzled_layout=False,
)
# Call TRT-LLM FP4 block-scale MoE kernel
out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
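
Both wrappers now open with the same unwrap-or-quantize prologue; a small helper along these lines could factor it out (a hypothetical refactor, not part of this commit; it reuses the flashinfer.fp4_quantize call exactly as written in the diff):

import torch

def as_fp4_input(layer: torch.nn.Module, x):
    """Return (packed_fp4, linear_scale_factors), quantizing only if needed.

    Hypothetical helper: a tuple input means the hidden states were already
    quantized upstream, before the DP all-gather.
    """
    import flashinfer

    if isinstance(x, tuple):
        return x
    return flashinfer.fp4_quantize(
        x,
        layer.w13_input_scale_quant,  # a1_gscale
        is_sf_swizzled_layout=False,
    )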

View File

@@ -184,6 +184,23 @@ def has_flashinfer_cutedsl() -> bool:
)
@functools.cache
def has_flashinfer_trtllm_fused_moe() -> bool:
"""Return `True` if FlashInfer TRTLLM fused MoE is available."""
if not has_flashinfer_moe():
return False
required_functions = [
("flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"),
("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"),
("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
]
for module_name, attr_name in required_functions:
mod = _get_submodule(module_name)
if not mod or not hasattr(mod, attr_name):
return False
return True
@functools.cache
def has_flashinfer_cutlass_fused_moe() -> bool:
"""Return `True` if FlashInfer CUTLASS fused MoE is available."""