mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-22 00:31:19 +08:00
[Perf] Do FP4 quant before All gather on flashinfer trtllmgen MOE (#30014)
Signed-off-by: jiahanc <173873397+jiahanc@users.noreply.github.com>
This commit is contained in:
parent
f21f5ea38c
commit
254a7f8fd6
@ -64,7 +64,12 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
is_sequence_parallel: bool = False,
|
is_sequence_parallel: bool = False,
|
||||||
|
extra_tensors: list[torch.Tensor] | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
if extra_tensors is not None:
|
||||||
|
raise NotImplementedError(
|
||||||
|
"extra_tensors is not supported for NaiveAll2AllManager"
|
||||||
|
)
|
||||||
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
sp_size = self.tp_group.world_size if is_sequence_parallel else 1
|
||||||
dp_metadata = get_forward_context().dp_metadata
|
dp_metadata = get_forward_context().dp_metadata
|
||||||
assert dp_metadata is not None
|
assert dp_metadata is not None
|
||||||
@ -76,6 +81,7 @@ class NaiveAll2AllManager(All2AllManagerBase):
|
|||||||
router_logits = self.naive_multicast(
|
router_logits = self.naive_multicast(
|
||||||
router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
|
router_logits, cu_tokens_across_sp_cpu, is_sequence_parallel
|
||||||
)
|
)
|
||||||
|
|
||||||
return hidden_states, router_logits
|
return hidden_states, router_logits
|
||||||
|
|
||||||
def combine(
|
def combine(
|
||||||
@ -113,7 +119,11 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
is_sequence_parallel: bool = False,
|
is_sequence_parallel: bool = False,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
extra_tensors: list[torch.Tensor] | None = None,
|
||||||
|
) -> (
|
||||||
|
tuple[torch.Tensor, torch.Tensor]
|
||||||
|
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||||
|
):
|
||||||
"""
|
"""
|
||||||
Gather hidden_states and router_logits from all dp ranks.
|
Gather hidden_states and router_logits from all dp ranks.
|
||||||
"""
|
"""
|
||||||
@ -121,15 +131,22 @@ class AgRsAll2AllManager(All2AllManagerBase):
|
|||||||
assert dp_metadata is not None
|
assert dp_metadata is not None
|
||||||
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
|
sizes = dp_metadata.get_chunk_sizes_across_dp_rank()
|
||||||
assert sizes is not None
|
assert sizes is not None
|
||||||
|
|
||||||
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
dist_group = get_ep_group() if is_sequence_parallel else get_dp_group()
|
||||||
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
|
assert sizes[dist_group.rank_in_group] == hidden_states.shape[0]
|
||||||
hidden_states, router_logits = dist_group.all_gatherv(
|
|
||||||
[hidden_states, router_logits],
|
tensors_to_gather = [hidden_states, router_logits]
|
||||||
|
if extra_tensors is not None:
|
||||||
|
tensors_to_gather.extend(extra_tensors)
|
||||||
|
|
||||||
|
gathered_tensors = dist_group.all_gatherv(
|
||||||
|
tensors_to_gather,
|
||||||
dim=0,
|
dim=0,
|
||||||
sizes=sizes,
|
sizes=sizes,
|
||||||
)
|
)
|
||||||
return hidden_states, router_logits
|
|
||||||
|
if extra_tensors is not None:
|
||||||
|
return (gathered_tensors[0], gathered_tensors[1], gathered_tensors[2:])
|
||||||
|
return gathered_tensors[0], gathered_tensors[1]
|
||||||
|
|
||||||
def combine(
|
def combine(
|
||||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||||
@ -204,6 +221,7 @@ class PPLXAll2AllManager(All2AllManagerBase):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
is_sequence_parallel: bool = False,
|
is_sequence_parallel: bool = False,
|
||||||
|
extra_tensors: list[torch.Tensor] | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
@ -251,6 +269,7 @@ class DeepEPAll2AllManagerBase(All2AllManagerBase):
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
is_sequence_parallel: bool = False,
|
is_sequence_parallel: bool = False,
|
||||||
|
extra_tensors: list[torch.Tensor] | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import threading
|
import threading
|
||||||
|
from typing import Any
|
||||||
from weakref import WeakValueDictionary
|
from weakref import WeakValueDictionary
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -68,7 +69,11 @@ class All2AllManagerBase:
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
is_sequence_parallel: bool = False,
|
is_sequence_parallel: bool = False,
|
||||||
):
|
extra_tensors: list[torch.Tensor] | None = None,
|
||||||
|
) -> Any:
|
||||||
|
# Subclasses should either:
|
||||||
|
# - implement handling for extra_tensors, or
|
||||||
|
# - raise a clear error if extra_tensors is not supported.
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def set_num_sms(self, num_sms: int):
|
def set_num_sms(self, num_sms: int):
|
||||||
|
|||||||
@ -318,17 +318,23 @@ class CudaCommunicator(DeviceCommunicatorBase):
|
|||||||
|
|
||||||
return output_list
|
return output_list
|
||||||
|
|
||||||
def dispatch(
|
def dispatch( # type: ignore[override]
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
is_sequence_parallel: bool = False,
|
is_sequence_parallel: bool = False,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
extra_tensors: list[torch.Tensor] | None = None,
|
||||||
|
) -> (
|
||||||
|
tuple[torch.Tensor, torch.Tensor]
|
||||||
|
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||||
|
):
|
||||||
assert self.all2all_manager is not None
|
assert self.all2all_manager is not None
|
||||||
hidden_states, router_logits = self.all2all_manager.dispatch(
|
return self.all2all_manager.dispatch(
|
||||||
hidden_states, router_logits, is_sequence_parallel
|
hidden_states,
|
||||||
|
router_logits,
|
||||||
|
is_sequence_parallel,
|
||||||
|
extra_tensors, # type: ignore[call-arg]
|
||||||
)
|
)
|
||||||
return hidden_states, router_logits
|
|
||||||
|
|
||||||
def combine(
|
def combine(
|
||||||
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
self, hidden_states: torch.Tensor, is_sequence_parallel: bool = False
|
||||||
|
|||||||
@ -1007,10 +1007,17 @@ class GroupCoordinator:
|
|||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
is_sequence_parallel: bool = False,
|
is_sequence_parallel: bool = False,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
extra_tensors: list[torch.Tensor] | None = None,
|
||||||
|
) -> (
|
||||||
|
tuple[torch.Tensor, torch.Tensor]
|
||||||
|
| tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
|
||||||
|
):
|
||||||
if self.device_communicator is not None:
|
if self.device_communicator is not None:
|
||||||
return self.device_communicator.dispatch(
|
return self.device_communicator.dispatch( # type: ignore[call-arg]
|
||||||
hidden_states, router_logits, is_sequence_parallel
|
hidden_states,
|
||||||
|
router_logits,
|
||||||
|
is_sequence_parallel,
|
||||||
|
extra_tensors,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return hidden_states, router_logits
|
return hidden_states, router_logits
|
||||||
|
|||||||
@ -71,6 +71,18 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
|||||||
"implementation based on the prepare_finalize"
|
"implementation based on the prepare_finalize"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def prepare_dp_allgather_tensor(
    self,
    layer: "FusedMoE",  # type: ignore[name-defined] # noqa: F821
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """Hook to prepare tensors and extra tensors for DP allgather + EP dispatch.

    This default implementation provides no preparation logic: it always
    raises, so a quantization method that wants to take part in this path
    has to override it.

    Raises:
        NotImplementedError: always; the message names the concrete class
            the hook was called on.
    """
    owner = self.__class__.__name__
    message = f"Method 'prepare_dp_allgather_tensor' is not implemented in {owner}."
    raise NotImplementedError(message)
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_fused_moe_quant_config(
|
def get_fused_moe_quant_config(
|
||||||
self, layer: torch.nn.Module
|
self, layer: torch.nn.Module
|
||||||
|
|||||||
@ -44,6 +44,7 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import (
|
|||||||
is_flashinfer_supporting_global_sf,
|
is_flashinfer_supporting_global_sf,
|
||||||
)
|
)
|
||||||
from vllm.platforms import current_platform
|
from vllm.platforms import current_platform
|
||||||
|
from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe
|
||||||
from vllm.utils.math_utils import cdiv, round_up
|
from vllm.utils.math_utils import cdiv, round_up
|
||||||
from vllm.utils.torch_utils import (
|
from vllm.utils.torch_utils import (
|
||||||
aux_stream,
|
aux_stream,
|
||||||
@ -1933,10 +1934,46 @@ class FusedMoE(CustomOp):
|
|||||||
)
|
)
|
||||||
|
|
||||||
with sp_ctx:
|
with sp_ctx:
|
||||||
|
extra_tensors = None
|
||||||
if do_naive_dispatch_combine:
|
if do_naive_dispatch_combine:
|
||||||
hidden_states_combined, router_logits = get_ep_group().dispatch(
|
# Avoid circular import
|
||||||
hidden_states, router_logits, self.is_sequence_parallel
|
from vllm.model_executor.layers.quantization.modelopt import (
|
||||||
|
ModelOptNvFp4FusedMoE,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
post_quant_allgather = (
|
||||||
|
has_flashinfer_trtllm_fused_moe()
|
||||||
|
and self.quant_method is not None
|
||||||
|
and self.dp_size > 1
|
||||||
|
and self.use_ep
|
||||||
|
and isinstance(self.quant_method, ModelOptNvFp4FusedMoE)
|
||||||
|
)
|
||||||
|
if post_quant_allgather:
|
||||||
|
hidden_states_to_dispatch, extra_tensors = (
|
||||||
|
self.quant_method.prepare_dp_allgather_tensor(
|
||||||
|
self, hidden_states, router_logits
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states_to_dispatch = hidden_states
|
||||||
|
|
||||||
|
dispatch_res = get_ep_group().dispatch(
|
||||||
|
hidden_states_to_dispatch,
|
||||||
|
router_logits,
|
||||||
|
self.is_sequence_parallel,
|
||||||
|
extra_tensors=extra_tensors,
|
||||||
|
)
|
||||||
|
if extra_tensors is not None:
|
||||||
|
hidden_states_combined, router_logits, extra_tensors_combined = (
|
||||||
|
dispatch_res
|
||||||
|
)
|
||||||
|
hidden_states_combined = (
|
||||||
|
hidden_states_combined,
|
||||||
|
extra_tensors_combined[0],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
hidden_states_combined, router_logits = dispatch_res
|
||||||
|
|
||||||
# Run shared experts before matrix multiply.
|
# Run shared experts before matrix multiply.
|
||||||
# because the matrix multiply may modify hidden_states.
|
# because the matrix multiply may modify hidden_states.
|
||||||
if has_separate_shared_experts and not use_shared_experts_stream:
|
if has_separate_shared_experts and not use_shared_experts_stream:
|
||||||
|
|||||||
@ -1522,6 +1522,24 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
w2_blockscale_swizzled, requires_grad=False
|
w2_blockscale_swizzled, requires_grad=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def prepare_dp_allgather_tensor(
    self,
    layer: FusedMoE,
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
) -> tuple[torch.Tensor, list[torch.Tensor]]:
    """Quantize hidden_states to FP4 and return its scale factors as extras.

    The activations are quantized with ``flashinfer.fp4_quantize`` using
    ``layer.w13_input_scale_quant`` as the global scale, requesting the
    non-swizzled scale-factor layout.

    Args:
        layer: MoE layer that supplies ``w13_input_scale_quant``.
        hidden_states: Activations to quantize.
        router_logits: Accepted to match the hook signature; not used here.

    Returns:
        A pair of the FP4-quantized hidden states and a one-element list
        holding the matching scale-factor tensor, i.e. the extra tensors to
        carry through DP allgather/EP.
    """
    # Imported locally, so flashinfer is only required when this hook runs.
    import flashinfer

    input_global_scale = layer.w13_input_scale_quant
    quantized_states, scale_factors = flashinfer.fp4_quantize(
        hidden_states,
        input_global_scale,
        is_sf_swizzled_layout=False,
    )
    return quantized_states, [scale_factors]
|
||||||
|
|
||||||
def get_fused_moe_quant_config(
|
def get_fused_moe_quant_config(
|
||||||
self, layer: torch.nn.Module
|
self, layer: torch.nn.Module
|
||||||
) -> FusedMoEQuantConfig | None:
|
) -> FusedMoEQuantConfig | None:
|
||||||
@ -1576,8 +1594,13 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
|||||||
e_score_correction_bias=layer.e_score_correction_bias,
|
e_score_correction_bias=layer.e_score_correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Hidden_states in select_experts is only used to extract metadata
|
||||||
|
if isinstance(x, tuple):
|
||||||
|
x_routing, _ = x
|
||||||
|
else:
|
||||||
|
x_routing = x
|
||||||
topk_weights, topk_ids, _ = layer.select_experts(
|
topk_weights, topk_ids, _ = layer.select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x_routing,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -238,7 +238,7 @@ def prepare_static_weights_for_trtllm_fp4_moe(
|
|||||||
|
|
||||||
def flashinfer_trtllm_fp4_moe(
|
def flashinfer_trtllm_fp4_moe(
|
||||||
layer: torch.nn.Module,
|
layer: torch.nn.Module,
|
||||||
x: torch.Tensor,
|
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
top_k: int,
|
top_k: int,
|
||||||
global_num_experts: int,
|
global_num_experts: int,
|
||||||
@ -269,12 +269,16 @@ def flashinfer_trtllm_fp4_moe(
|
|||||||
from vllm.model_executor.models.llama4 import Llama4MoE
|
from vllm.model_executor.models.llama4 import Llama4MoE
|
||||||
|
|
||||||
# Quantize input to FP4
|
# Quantize input to FP4
|
||||||
a1_gscale = layer.w13_input_scale_quant
|
if isinstance(x, tuple):
|
||||||
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
|
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
|
||||||
x,
|
else:
|
||||||
a1_gscale,
|
# hidden_states is not quantized yet; quantize it to FP4 here
|
||||||
is_sf_swizzled_layout=False,
|
a1_gscale = layer.w13_input_scale_quant
|
||||||
)
|
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
|
||||||
|
x,
|
||||||
|
a1_gscale,
|
||||||
|
is_sf_swizzled_layout=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Determine routing method type
|
# Determine routing method type
|
||||||
use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
|
use_llama4_routing = custom_routing_function is Llama4MoE.custom_routing_function
|
||||||
@ -360,13 +364,17 @@ def flashinfer_trtllm_fp4_routed_moe(
|
|||||||
torch.bfloat16
|
torch.bfloat16
|
||||||
).view(torch.int16)
|
).view(torch.int16)
|
||||||
|
|
||||||
# Quantize input to FP4
|
if isinstance(x, tuple):
|
||||||
a1_gscale = layer.w13_input_scale_quant
|
# hidden_states is already quantized: x is (FP4 tensor, scale factors)
|
||||||
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
|
hidden_states_fp4, hidden_states_scale_linear_fp4 = x
|
||||||
x,
|
else:
|
||||||
a1_gscale,
|
# Quantize input to FP4
|
||||||
is_sf_swizzled_layout=False,
|
a1_gscale = layer.w13_input_scale_quant
|
||||||
)
|
(hidden_states_fp4, hidden_states_scale_linear_fp4) = flashinfer.fp4_quantize(
|
||||||
|
x,
|
||||||
|
a1_gscale,
|
||||||
|
is_sf_swizzled_layout=False,
|
||||||
|
)
|
||||||
|
|
||||||
# Call TRT-LLM FP4 block-scale MoE kernel
|
# Call TRT-LLM FP4 block-scale MoE kernel
|
||||||
out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
|
out = flashinfer.fused_moe.trtllm_fp4_block_scale_routed_moe(
|
||||||
|
|||||||
@ -184,6 +184,23 @@ def has_flashinfer_cutedsl() -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@functools.cache
def has_flashinfer_trtllm_fused_moe() -> bool:
    """Return `True` if FlashInfer TRTLLM fused MoE is available.

    Availability means that `has_flashinfer_moe()` is truthy and that the
    `flashinfer.fused_moe` submodule exposes every TRTLLM entry point listed
    below. The result is memoized by `functools.cache`.
    """
    if not has_flashinfer_moe():
        return False

    # (submodule, attribute) pairs that must all resolve.
    required_functions = (
        ("flashinfer.fused_moe", "trtllm_fp8_block_scale_moe"),
        ("flashinfer.fused_moe", "trtllm_fp8_per_tensor_scale_moe"),
        ("flashinfer.fused_moe", "trtllm_fp4_block_scale_moe"),
    )

    def _is_exported(module_name: str, attr_name: str) -> bool:
        # A missing submodule or a missing attribute both count as absent.
        submodule = _get_submodule(module_name)
        return bool(submodule) and hasattr(submodule, attr_name)

    # all() stops at the first missing entry, like the early return it replaces.
    return all(
        _is_exported(module_name, attr_name)
        for module_name, attr_name in required_functions
    )
|
||||||
|
|
||||||
|
|
||||||
@functools.cache
|
@functools.cache
|
||||||
def has_flashinfer_cutlass_fused_moe() -> bool:
|
def has_flashinfer_cutlass_fused_moe() -> bool:
|
||||||
"""Return `True` if FlashInfer CUTLASS fused MoE is available."""
|
"""Return `True` if FlashInfer CUTLASS fused MoE is available."""
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user