[Perf] Do FP4 quant before All gather on flashinfer trtllmgen MOE (#30014)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
parent: f21f5ea38c
commit: 254a7f8fd6
@@ -64,7 +64,12 @@ class NaiveAll2AllManager(All2AllManagerBase):
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         is_sequence_parallel: bool = False,
+        extra_tensors: list[torch.Tensor] | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
+        if extra_tensors is not None:
+            raise NotImplementedError(
+                "extra_tensors is not supported for NaiveAll2AllManager"
+            )
         sp_size = self.tp_group.world_size if is_sequence_parallel else 1
         dp_metadata = get_forward_context().dp_metadata
         assert dp_metadata is not None
@@ -76,6 +81,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
         router_logits = self.naive_multicast(
             router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
         )
+
         return hidden_states, router_logits

     def combine(
@@ -113,7 +119,11 @@ class AgRsAll2AllManager(All2AllManagerBase):
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         is_sequence_parallel: bool = False,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+        extra_tensors: list[torch.Tensor] | None = None,
+    ) -> (
+        tuple[torch.Tensor, torch.Tensor]
+        | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
+    ):
         """
         Gather hidden_states and router_logits from all dp ranks.
         """
@@ -121,15 +131,22 @@ class AgRsAll2AllManager(All2AllManagerBase):
         assert dp_metadata is not None
         sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
         assert sizes is not None

         dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
         assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
-        hidden_states, router_logits = dist_group.all_gatherv(
-            [hidden_states, router_logits],
+        tensors_to_gather = [hidden_states, router_logits]
+        if extra_tensors is not None:
+            tensors_to_gather.extend(extra_tensors)
+
+        gathered_tensors = dist_group.all_gatherv(
+            tensors_to_gather,
             dim=0,
             sizes=sizes,
         )
-        return hidden_states, router_logits
+
+        if extra_tensors is not None:
+            return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
+        return gathered_tensors[0], gathered_tensors[1]

     def combine(
         self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
@@ -204,6 +221,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         is_sequence_parallel: bool = False,
+        extra_tensors: list[torch.Tensor] | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError
@@ -251,6 +269,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         is_sequence_parallel: bool = False,
+        extra_tensors: list[torch.Tensor] | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         raise NotImplementedError
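Note on the gather contract above: all_gatherv concatenates variable-sized per-rank chunks along dim 0 using a single sizes vector, so anything passed through extra_tensors must contribute the same number of rows per rank as hidden_states (for the FP4 path, one row of packed data and one row of scale factors per token). A minimal single-process sketch of that bookkeeping; emulate_all_gatherv and the shapes below are illustrative stand-ins, not vLLM APIs:

    import torch

    def emulate_all_gatherv(per_rank_chunks: list[list[torch.Tensor]]) -> list[torch.Tensor]:
        # per_rank_chunks[r] holds one chunk per gathered tensor from "rank" r;
        # the real all_gatherv concatenates each tensor's chunks along dim 0.
        num_tensors = len(per_rank_chunks[0])
        return [
            torch.cat([chunks[i] for chunks in per_rank_chunks], dim=0)
            for i in range(num_tensors)
        ]

    sizes = [3, 5]  # tokens per DP rank, as in dp_metadata.get_chunk_sizes_across_dp_rank()
    hidden_dim = 32
    per_rank = []
    for n in sizes:
        hidden_fp4 = torch.zeros(n, hidden_dim // 2, dtype=torch.uint8)  # two FP4 values per byte
        router_logits = torch.zeros(n, 4)
        scales = torch.zeros(n, hidden_dim // 16, dtype=torch.uint8)  # one block scale per 16 elements
        per_rank.append([hidden_fp4, router_logits, scales])

    gathered = emulate_all_gatherv(per_rank)
    assert all(t.shape[0] == sum(sizes) for t in gathered)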
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import threading
+from typing import Any
 from weakref import WeakValueDictionary

 import torch
@@ -68,7 +69,11 @@ class All2AllManagerBase:
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         is_sequence_parallel: bool = False,
-    ):
+        extra_tensors: list[torch.Tensor] | None = None,
+    ) -> Any:
+        # Subclasses should either:
+        # - implement handling for extra_tensors, or
+        # - raise a clear error if extra_tensors is not supported.
         raise NotImplementedError

     def set_num_sms(self, num_sms: int):
@@ -318,17 +318,23 @@ class CudaCommunicator(DeviceCommunicatorBase):

         return output_list

-    def dispatch(
+    def dispatch(  # type: ignore[override]
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         is_sequence_parallel: bool = False,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+        extra_tensors: list[torch.Tensor] | None = None,
+    ) -> (
+        tuple[torch.Tensor, torch.Tensor]
+        | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
+    ):
         assert self.all2all_manager is not None
-        hidden_states, router_logits = self.all2all_manager.dispatch(
-            hidden_states, router_logits, is_sequence_parallel
+        return self.all2all_manager.dispatch(
+            hidden_states,
+            router_logits,
+            is_sequence_parallel,
+            extra_tensors,  # type: ignore[call-arg]
         )
-        return hidden_states, router_logits

     def combine(
         self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
@@ -1007,10 +1007,17 @@ class GroupCoordinator:
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         is_sequence_parallel: bool = False,
-    ) -> tuple[torch.Tensor, torch.Tensor]:
+        extra_tensors: list[torch.Tensor] | None = None,
+    ) -> (
+        tuple[torch.Tensor, torch.Tensor]
+        | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
+    ):
         if self.device_communicator is not None:
-            return self.device_communicator.dispatch(
-                hidden_states, router_logits, is_sequence_parallel
+            return self.device_communicator.dispatch(  # type: ignore[call-arg]
+                hidden_states,
+                router_logits,
+                is_sequence_parallel,
+                extra_tensors,
             )
         else:
             return hidden_states, router_logits
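Because dispatch now returns either a two-tuple or a three-tuple depending on whether extra_tensors was passed, call sites have to branch on the result (the FusedMoE hunk below does exactly that). A hedged helper that normalizes the union, shown only as a sketch and not part of this patch:

    import torch

    def unpack_dispatch_result(
        res: tuple,
    ) -> tuple[torch.Tensor, torch.Tensor, list[torch.Tensor] | None]:
        # dispatch() yields (hidden_states, router_logits) when extra_tensors was None,
        # and (hidden_states, router_logits, gathered_extras) otherwise.
        if len(res) == 3:
            hidden_states, router_logits, extras = res
            return hidden_states, router_logits, list(extras)
        hidden_states, router_logits = res
        return hidden_states, router_logits, None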
@@ -71,6 +71,18 @@ class FusedMoEMethodBase(QuantizeMethodBase):
             "implementation based on the prepare_finalize"
         )

+    def prepare_dp_allgather_tensor(
+        self,
+        layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+        """Hook to prepare tensors and extra tensors for DP allgather + EP dispatch."""
+        raise NotImplementedError(
+            "Method 'prepare_dp_allgather_tensor' is not implemented in "
+            f"{self.__class__.__name__}."
+        )
+
     @abstractmethod
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
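The motivation for this hook: quantizing before the DP all-gather shrinks the traffic, since bf16 activations cost 2 bytes per element while NVFP4 costs half a byte per element plus one FP8 (E4M3) scale per 16-element block. A back-of-the-envelope sketch with illustrative sizes:

    hidden_size = 7168          # illustrative hidden size
    tokens_per_rank = 256       # illustrative DP chunk

    bf16_bytes = tokens_per_rank * hidden_size * 2
    fp4_data_bytes = tokens_per_rank * hidden_size // 2          # 2 FP4 values per byte
    fp4_scale_bytes = tokens_per_rank * (hidden_size // 16) * 1  # one FP8 (E4M3) scale per 16 elements
    fp4_total = fp4_data_bytes + fp4_scale_bytes

    print(bf16_bytes / fp4_total)  # ~3.56x fewer bytes all-gathered per rank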
@@ -44,6 +44,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
     is_flashinfer_supporting_global_sf,
 )
 from vllm.platforms import current_platform
+from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
 from vllm.utils.math_utils import cdiv, round_up
 from vllm.utils.torch_utils import (
     aux_stream,
@@ -1933,10 +1934,46 @@ class FusedMoE(CustomOp):
         )

         with sp_ctx:
+            extra_tensors = None
             if do_naive_dispatch_combine:
-                hidden_states_combined, router_logits = get_ep_group().dispatch(
-                    hidden_states, router_logits, self.is_sequence_parallel
+                # Avoid circular import
+                from vllm.model_executor.layers.quantization.modelopt import (
+                    ModelOptNvFp4FusedMoE,
+                )
+
+                post_quant_allgather = (
+                    has_flashinfer_trtllm_fused_moe()
+                    and self.quant_method is not None
+                    and self.dp_size > 1
+                    and self.use_ep
+                    and isinstance(self.quant_method, ModelOptNvFp4FusedMoE)
+                )
+                if post_quant_allgather:
+                    hidden_states_to_dispatch, extra_tensors = (
+                        self.quant_method.prepare_dp_allgather_tensor(
+                            self, hidden_states, router_logits
+                        )
+                    )
+                else:
+                    hidden_states_to_dispatch = hidden_states
+
+                dispatch_res = get_ep_group().dispatch(
+                    hidden_states_to_dispatch,
+                    router_logits,
+                    self.is_sequence_parallel,
+                    extra_tensors=extra_tensors,
                 )
+                if extra_tensors is not None:
+                    hidden_states_combined, router_logits, extra_tensors_combined = (
+                        dispatch_res
+                    )
+                    hidden_states_combined = (
+                        hidden_states_combined,
+                        extra_tensors_combined[0],
+                    )
+                else:
+                    hidden_states_combined, router_logits = dispatch_res

             # Run shared experts before matrix multiply.
             # because matrix multiply maybe modify the hidden_states.
             if has_separate_shared_experts and not use_shared_experts_stream:
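The gating condition above can be read as a standalone predicate; the version below restates the same checks with comments on why each term is present (a sketch for exposition, not a refactor shipped in this commit):

    def should_quant_before_allgather(layer) -> bool:
        from vllm.model_executor.layers.quantization.modelopt import ModelOptNvFp4FusedMoE
        from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe

        return (
            # the TRT-LLM fused MoE kernels that accept pre-quantized input must be present
            has_flashinfer_trtllm_fused_moe()
            # only the ModelOpt NVFP4 method implements prepare_dp_allgather_tensor
            and layer.quant_method is not None
            and isinstance(layer.quant_method, ModelOptNvFp4FusedMoE)
            # without data parallelism there is no all-gather to shrink
            and layer.dp_size > 1
            # the optimization targets the EP dispatch path
            and layer.use_ep
        )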
@@ -1522,6 +1522,24 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             w2_blockscale_swizzled, requires_grad=False
         )

+    def prepare_dp_allgather_tensor(
+        self,
+        layer: FusedMoE,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+    ) -> tuple[torch.Tensor, list[torch.Tensor]]:
+        """Optionally prepare extra tensors to carry through DP allgather/EP."""
+        import flashinfer
+
+        a1_gscale = layer.w13_input_scale_quant
+        hidden_states_fp4, hidden_states_sf = flashinfer.fp4_quantize(
+            hidden_states,
+            a1_gscale,
+            is_sf_swizzled_layout=False,
+        )
+        extra_tensors: list[torch.Tensor] = [hidden_states_sf]
+        return hidden_states_fp4, extra_tensors
+
     def get_fused_moe_quant_config(
         self, layer: torch.nn.Module
     ) -> FusedMoEQuantConfig | None:
@@ -1576,8 +1594,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
             e_score_correction_bias=layer.e_score_correction_bias,
         )

+        # Hidden_states in select_experts is only used to extract metadata
+        if isinstance(x, tuple):
+            x_routing, _ = x
+        else:
+            x_routing = x
         topk_weights, topk_ids, _ = layer.select_experts(
-            hidden_states=x,
+            hidden_states=x_routing,
             router_logits=router_logits,
         )
@@ -238,7 +238,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(

 def flashinfer_trtllm_fp4_moe(
     layer: torch.nn.Module,
-    x: torch.Tensor,
+    x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
     router_logits: torch.Tensor,
     top_k: int,
     global_num_experts: int,
@@ -269,12 +269,16 @@ def flashinfer_trtllm_fp4_moe(
     from vllm.model_executor.models.llama4 import Llama4MoE

-    # Quantize input to FP4
-    a1_gscale = layer.w13_input_scale_quant
-    (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
-        x,
-        a1_gscale,
-        is_sf_swizzled_layout=False,
-    )
+    if isinstance(x, tuple):
+        # hidden_states is already quantized
+        hidden_states_fp4, hidden_states_scale_linear_fp4 = x
+    else:
+        # Quantize input to FP4
+        a1_gscale = layer.w13_input_scale_quant
+        (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
+            x,
+            a1_gscale,
+            is_sf_swizzled_layout=False,
+        )

     # Determine routing method type
     use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
@@ -360,13 +364,17 @@ def flashinfer_trtllm_fp4_routed_moe(
         torch.bfloat16
     ).view(torch.int16)

-    # Quantize input to FP4
-    a1_gscale = layer.w13_input_scale_quant
-    (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
-        x,
-        a1_gscale,
-        is_sf_swizzled_layout=False,
-    )
+    if isinstance(x, tuple):
+        # Hidden_states is already quantized
+        hidden_states_fp4, hidden_states_scale_linear_fp4 = x
+    else:
+        # Quantize input to FP4
+        a1_gscale = layer.w13_input_scale_quant
+        (hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
+            x,
+            a1_gscale,
+            is_sf_swizzled_layout=False,
+        )

     # Call TRT-LLM FP4 block-scale MoE kernel
     out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
@@ -184,6 +184,23 @@ def has_flashinfer_cutedsl() -> bool:
     )


+@functools.cache
+def has_flashinfer_trtllm_fused_moe() -> bool:
+    """Return `True` if FlashInfer TRTLLM fused MoE is available."""
+    if not has_flashinfer_moe():
+        return False
+    required_functions = [
+        ("flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"),
+        ("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"),
+        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
+    ]
+    for module_name, attr_name in required_functions:
+        mod = _get_submodule(module_name)
+        if not mod or not hasattr(mod, attr_name):
+            return False
+    return True
+
+
 @functools.cache
 def has_flashinfer_cutlass_fused_moe() -> bool:
     """Return `True` if FlashInfer CUTLASS fused MoE is available."""
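Because the probe is wrapped in functools.cache, repeated calls are cheap, so it can sit directly in per-layer gating logic such as the post_quant_allgather condition above. A minimal usage sketch; pick_moe_dispatch_path and its return labels are illustrative:

    from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe

    def pick_moe_dispatch_path(dp_size: int, use_ep: bool) -> str:
        # The probe is cached, so calling it per layer or per forward pass is cheap.
        if has_flashinfer_trtllm_fused_moe() and dp_size > 1 and use_ep:
            return "fp4-quant-before-allgather"
        return "allgather-then-quant"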