Support mnnvl all2allv from Flashinfer (#21003)

Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
Shu Wang 2025-09-24 13:38:16 -05:00 committed by GitHub
parent 2dda3e35d0
commit 54e42b72db
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 410 additions and 40 deletions

View File

@ -222,7 +222,8 @@ if (has_flashinfer_cutlass_fused_moe()
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( # noqa: E501
FlashInferExperts)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize)
FlashInferCutlassMoEPrepareAndFinalize,
create_flashinfer_prepare_finalize)
register_prepare_and_finalize(
FlashInferCutlassMoEPrepareAndFinalize,
@ -373,7 +374,7 @@ def make_prepare_finalize(
assert prepare_finalize is not None
return prepare_finalize
elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
return FlashInferCutlassMoEPrepareAndFinalize(
return create_flashinfer_prepare_finalize(
use_dp=moe.moe_parallel_config.dp_size > 1)
else:
return MoEPrepareAndFinalizeNoEP()

View File

@ -10,9 +10,15 @@ from vllm.distributed import get_dp_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.utils import has_deep_ep, has_pplx
from vllm.utils.flashinfer import has_flashinfer_all2all
from .base_device_communicator import All2AllManagerBase, Cache
if has_flashinfer_all2all():
from flashinfer.comm import Mapping
from flashinfer.comm.mnnvl import MnnvlConfig
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
logger = init_logger(__name__)
@ -47,24 +53,22 @@ class NaiveAll2AllManager(All2AllManagerBase):
def dispatch(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states, router_logits = get_dp_group().all_gatherv(
[hidden_states, router_logits],
dim=0,
sizes=sizes,
)
hidden_states = self.naive_multicast(hidden_states,
cu_tokens_across_dp_cpu)
router_logits = self.naive_multicast(router_logits,
cu_tokens_across_dp_cpu)
return hidden_states, router_logits
def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
cu_tokens_across_dp_cpu = get_forward_context(
).dp_metadata.cu_tokens_across_dp_cpu
start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
self.dp_rank - 1]
end = cu_tokens_across_dp_cpu[self.dp_rank]
all_hidden_states = self.dp_group.all_reduce(hidden_states)
hidden_states = all_hidden_states[start:end, :]
sizes = get_forward_context(
).dp_metadata.get_chunk_sizes_across_dp_rank()
hidden_states = get_dp_group().reduce_scatterv(hidden_states,
dim=0,
sizes=sizes)
return hidden_states
def destroy(self):
@ -300,4 +304,95 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):
# DeepEP LL uses RDMA so no SMs are used for communication
def max_sms_used(self) -> Optional[int]:
return 0
return 0
class FlashInferAllToAllManager(All2AllManagerBase):
"""
All2All communication based on flashinfer kernels.
"""
def __init__(self, cpu_group):
assert has_flashinfer_all2all(
), "flashinfer all2all module not found. Please install/check flashinfer" # noqa
super().__init__(cpu_group)
logger.debug(
"Initialize for flashinfer All2All "
"rank=%d, world size=%d", self.rank, self.world_size)
self.initialized = False
self.alltoall_info = None
def initialize(
self,
world_size: int,
rank: int,
gpus_per_node: int,
):
"""Initialize workspace"""
if self.initialized:
return
self.cleanup()
logger.debug("making map: "
"rank=%d, world size=%d", rank, world_size)
self.mapping = Mapping(
world_size,
rank,
gpus_per_node,
tp_size=world_size,
)
from vllm.distributed.device_communicators.mnnvl_compat import (
CustomCommunicator)
dp_config = MnnvlConfig(
comm_backend=CustomCommunicator(get_dp_group().cpu_group),
fabric_page_size=1 << 29, # 512MB
allocation_granularity=0 # Auto-detect
)
self.workspace_tensor = MnnvlMoe.get_moe_workspaces(
self.mapping, dp_config)
self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
self.mapping, dp_config)
self.world_size = world_size
self.rank = rank
self.gpus_per_node = gpus_per_node
self.initialized = True
logger.info("FlashInfer All2All initialized for rank %s, size %s",
rank, world_size)
def ensure_alltoall_workspace_initialized(self):
"""Ensure workspace is initialized"""
if not has_flashinfer_all2all():
return False
if self.world_size <= 1:
return False
if not self.initialized:
self.initialize(
world_size=self.world_size,
rank=self.rank,
gpus_per_node=torch.cuda.device_count,
)
return self.initialized
def get_handle(self, kwargs):
return self
def cleanup(self):
"""Clean up workspace"""
if self.initialized and self.workspace_tensor is not None \
and self.prepare_workspace_tensor is not None:
try:
del self.workspace_tensor
del self.prepare_workspace_tensor
except Exception as e:
logger.warning("Failed to cleanup FlashInfer workspace: %s", e)
finally:
self.workspace_tensor = None
self.prepare_workspace_tensor = None
self.mapping = None
self.initialized = False

View File

@ -114,6 +114,11 @@ class CudaCommunicator(DeviceCommunicatorBase):
from .all2all import DeepEPLLAll2AllManager
self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
logger.info("Using DeepEP Low-Latency all2all manager.")
elif all2all_backend == "flashinfer_all2allv":
from .all2all import FlashInferAllToAllManager
self.all2all_manager = FlashInferAllToAllManager(
self.cpu_group)
logger.info("Using Flashinfer all2allv manager.")
else:
raise ValueError(f"Unknown all2all backend: {all2all_backend}")

View File

@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch.distributed as dist
from flashinfer.comm.mnnvl import CommBackend as CommBackend
from vllm.utils.flashinfer import has_flashinfer_all2all
assert has_flashinfer_all2all(), "Flashinfer alltoallv module cannot be found"
class CustomCommunicator(CommBackend):
def __init__(self, group):
self._group = group
def Get_rank(self) -> int:
return self._group.rank()
def Get_size(self) -> int:
return self._group.size()
def allgather(self, data: int):
gathered = [None] * self.Get_size()
dist.all_gather_object(gathered, data, group=self._group)
return gathered
def Split(self, color: int, key: int) -> 'CustomCommunicator':
return self

View File

@ -156,7 +156,8 @@ if TYPE_CHECKING:
VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx",
"deepep_high_throughput",
"deepep_low_latency",
"allgather_reducescatter"] = \
"allgather_reducescatter",
"flashinfer_all2allv"] = \
"allgather_reducescatter"
VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
@ -1209,12 +1210,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
# - "pplx": use pplx kernels
# - "deepep_high_throughput", use deepep high-throughput kernels
# - "deepep_low_latency", use deepep low-latency kernels
# - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl
"VLLM_ALL2ALL_BACKEND":
env_with_choices("VLLM_ALL2ALL_BACKEND", "allgather_reducescatter",
["naive", "pplx",
"deepep_high_throughput",
"deepep_low_latency",
"allgather_reducescatter"]),
"allgather_reducescatter",
"flashinfer_all2allv"]),
# Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support.
# Both require compute capability 10.0 or above.

View File

@ -8,7 +8,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize)
create_flashinfer_prepare_finalize)
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceNoOP)
from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe,
@ -108,7 +108,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
of each tuple must be the number of tokens.
"""
aq_m, aq_n = aq.shape
workspace2 = ()
workspace2 = (0, )
output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \
torch.float8_e4m3fn else (aq_m, aq_n)
workspace_dtype = a.dtype
@ -192,9 +192,8 @@ def flashinfer_cutlass_moe_fp4(
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
fused_experts = mk.FusedMoEModularKernel(
FlashInferCutlassMoEPrepareAndFinalize(use_dp=False),
create_flashinfer_prepare_finalize(use_dp=False),
FlashInferExperts(
out_dtype=hidden_states.dtype,
quant_config=quant_config,

View File

@ -5,7 +5,9 @@ from typing import Optional
import torch
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm.distributed import get_dp_group
from vllm.distributed import get_dp_group, get_ep_group
from vllm.distributed.device_communicators.base_device_communicator import (
All2AllManagerBase)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.utils import (
@ -18,6 +20,7 @@ def get_local_sizes():
class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""Base class for FlashInfer MoE prepare and finalize operations."""
def __init__(
self,
@ -42,6 +45,39 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def num_dispatchers(self) -> int:
return self.num_dispatchers_
def _apply_router_weight_on_input(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
) -> None:
"""Apply router weight on input if needed."""
if apply_router_weight_on_input:
topk = topk_ids.size(1)
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype))
class FlashInferAllToAllMoEPrepareAndFinalize(
FlashInferCutlassMoEPrepareAndFinalize):
"""FlashInfer implementation using AllToAll communication."""
def __init__(
self,
use_dp: bool,
num_dispatchers: int = 1,
):
super().__init__(use_dp, num_dispatchers)
self.alltoall_info = None
# Initialize all2all_manager only for DP case
self.all2all_manager = None
if self.use_dp:
self.all2all_manager = get_ep_group(
).device_communicator.all2all_manager
def prepare(
self,
a1: torch.Tensor,
@ -53,12 +89,84 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
# TODO: this only works for topK=1, will need to update for topK>1
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
a1.mul_(topk_weights.to(a1.dtype))
self._apply_router_weight_on_input(a1, topk_weights, topk_ids,
apply_router_weight_on_input)
if not self.use_dp:
# Non-DP case: standard quantization
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_gscale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
is_fp4_scale_swizzled=not self.use_dp,
)
else:
# DP case: use FlashInfer AllToAll
global_num_tokens_cpu = get_local_sizes()
top_k = topk_ids.size(1)
(self.alltoall_info, topk_ids, topk_weights, a1q,
a1q_scale) = flashinfer_alltoall_dispatch(
self.all2all_manager,
global_num_tokens_cpu,
a1,
quant_config.a1_gscale,
topk_ids,
topk_weights,
top_k,
num_experts,
quant_config,
)
return a1q, a1q_scale, None, topk_ids, topk_weights
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
if self.use_dp:
top_k = topk_ids.size(1)
token_count = output.shape[0]
fused_expert_output = flashinfer_alltoall_combine(
self.all2all_manager,
fused_expert_output,
top_k=top_k,
token_count=token_count,
alltoall_info=self.alltoall_info,
)
output.copy_(fused_expert_output)
class FlashInferAllGatherMoEPrepareAndFinalize(
FlashInferCutlassMoEPrepareAndFinalize):
def __init__(
self,
use_dp: bool,
num_dispatchers: int = 1,
):
super().__init__(use_dp, num_dispatchers)
def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
) -> mk.PrepareResultType:
self._apply_router_weight_on_input(a1, topk_weights, topk_ids,
apply_router_weight_on_input)
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
@ -66,7 +174,6 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
# Swizzling after communication
is_fp4_scale_swizzled=not self.use_dp,
)
if self.use_dp:
@ -76,17 +183,117 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
dim=0,
sizes=get_local_sizes(),
)
a1_m, a1_n = a1q.shape
a1q_scale = nvfp4_block_scale_interleave(a1q_scale)
return a1q, a1q_scale, None, topk_ids, topk_weights
def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
if self.use_dp:
fused_expert_output = get_dp_group().reduce_scatterv(
fused_expert_output, dim=0, sizes=get_local_sizes())
output.copy_(fused_expert_output)
def flashinfer_alltoall_dispatch(
all2all_manager: All2AllManagerBase,
global_num_tokens_cpu: list[int],
x: torch.Tensor,
gs: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
top_k: int,
num_experts: int,
quant_config: FusedMoEQuantConfig,
):
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
assert (all2all_manager.ensure_alltoall_workspace_initialized()
), "FlashInfer AllToAll workspace not available"
ep_rank = all2all_manager.rank
ep_size = all2all_manager.world_size
max_num_token = max(global_num_tokens_cpu
) if global_num_tokens_cpu is not None else x.shape[0]
alltoall_info, topk_ids, topk_weights, _ = (
MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
topk_ids,
topk_weights,
None,
all2all_manager.prepare_workspace,
max_num_token,
ep_rank,
ep_size,
num_experts,
num_experts,
top_k,
))
x, x_sf = moe_kernel_quantize_input(
x,
gs,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
is_fp4_scale_swizzled=False, # delay swizzle to after comm
)
x = MnnvlMoe.mnnvl_moe_alltoallv(
x,
alltoall_info,
all2all_manager.workspace_tensor,
ep_rank,
ep_size,
)
x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
x_sf,
alltoall_info,
all2all_manager.workspace_tensor,
ep_rank,
ep_size,
)
x_sf = nvfp4_block_scale_interleave(x_sf)
return alltoall_info, topk_ids, topk_weights, x, x_sf
def flashinfer_alltoall_combine(
all2all_manager: All2AllManagerBase,
output: torch.Tensor,
top_k: int,
token_count: int,
alltoall_info,
):
from flashinfer.comm.trtllm_alltoall import MnnvlMoe
assert (all2all_manager.ensure_alltoall_workspace_initialized()
), "FlashInfer AllToAll workspace not available"
return MnnvlMoe.mnnvl_moe_alltoallv_combine(
output,
alltoall_info,
all2all_manager.workspace_tensor,
ep_rank=all2all_manager.rank,
ep_size=all2all_manager.world_size,
top_k=top_k,
token_count=token_count,
)
def create_flashinfer_prepare_finalize(
use_dp: bool,
use_nvfp4: bool = False,
enable_alltoallv: bool = False,
) -> FlashInferCutlassMoEPrepareAndFinalize:
"""Factory function to create the appropriate FlashInfer implementation."""
if use_nvfp4:
if enable_alltoallv:
return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
else:
return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
# Fp8 only supports AllGather
return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)

View File

@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize)
create_flashinfer_prepare_finalize)
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
@ -51,7 +51,9 @@ def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
moe: FusedMoEConfig) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1
return FlashInferCutlassMoEPrepareAndFinalize(use_dp)
enable_alltoallv = envs.VLLM_ALL2ALL_BACKEND == "flashinfer_all2allv"
return create_flashinfer_prepare_finalize(
use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv)
def select_nvfp4_gemm_impl(

View File

@ -13,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
FlashInferExperts)
from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import ( # noqa: E501
FlashInferCutlassMoEPrepareAndFinalize)
create_flashinfer_prepare_finalize)
logger = init_logger(__name__)
@ -173,7 +173,7 @@ def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
moe: Optional[FusedMoEConfig], ) -> mk.FusedMoEPrepareAndFinalize:
"""Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
return FlashInferCutlassMoEPrepareAndFinalize(use_dp)
return create_flashinfer_prepare_finalize(use_dp)
def select_cutlass_fp8_gemm_impl(

View File

@ -97,6 +97,34 @@ autotune = _lazy_import_wrapper(
fallback_fn=lambda *args, **kwargs: contextlib.nullcontext())
@functools.cache
def has_flashinfer_comm() -> bool:
"""Return ``True`` if FlashInfer comm module is available."""
return has_flashinfer() and importlib.util.find_spec(
"flashinfer.comm") is not None
@functools.cache
def has_flashinfer_all2all() -> bool:
"""Return ``True`` if FlashInfer mnnvl all2all is available."""
if not has_flashinfer_comm():
return False
# Check if all required functions are available
required_functions = [
("flashinfer.comm", "Mapping"),
("flashinfer.comm.mnnvl", "MnnvlMemory"),
("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
]
for module_name, attr_name in required_functions:
mod = _get_submodule(module_name)
if not mod or not hasattr(mod, attr_name):
return False
return True
@functools.cache
def has_flashinfer_moe() -> bool:
"""Return ``True`` if FlashInfer MoE module is available."""
@ -402,6 +430,8 @@ __all__ = [
"trtllm_fp4_block_scale_moe",
"autotune",
"has_flashinfer_moe",
"has_flashinfer_comm",
"has_flashinfer_all2all",
"has_flashinfer_cutlass_fused_moe",
"has_nvidia_artifactory",
"supports_trtllm_attention",