Mirror of https://git.datalinker.icu/vllm-project/vllm.git (synced 2026-01-23 23:14:30 +08:00)
Support mnnvl all2allv from Flashinfer (#21003)
Signed-off-by: Shu Wang <shuw@nvidia.com>
Signed-off-by: Shu Wang. <shuw@nvidia.com>
Signed-off-by: Tyler Michael Smith <tyler@neuralmagic.com>
Signed-off-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
Co-authored-by: Tyler Michael Smith <tyler@neuralmagic.com>
Co-authored-by: Tyler Michael Smith <tlrmchlsmth@gmail.com>
This commit is contained in:
parent 2dda3e35d0
commit 54e42b72db
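
The new backend is opt-in through the VLLM_ALL2ALL_BACKEND environment variable extended below. A minimal selection sketch, assuming a build of vLLM that contains this commit and has the FlashInfer MNNVL modules installed:

import os

# Select the FlashInfer MNNVL all2allv backend before vLLM reads its settings;
# the other listed choices ("naive", "pplx", "deepep_high_throughput",
# "deepep_low_latency", "allgather_reducescatter") are picked the same way.
os.environ["VLLM_ALL2ALL_BACKEND"] = "flashinfer_all2allv"

import vllm.envs as envs
print(envs.VLLM_ALL2ALL_BACKEND)  # -> flashinfer_all2allv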
@@ -222,7 +222,8 @@ if (has_flashinfer_cutlass_fused_moe()
     from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (  # noqa: E501
         FlashInferExperts)
     from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
-        FlashInferCutlassMoEPrepareAndFinalize)
+        FlashInferCutlassMoEPrepareAndFinalize,
+        create_flashinfer_prepare_finalize)

     register_prepare_and_finalize(
         FlashInferCutlassMoEPrepareAndFinalize,
@@ -373,7 +374,7 @@ def make_prepare_finalize(
         assert prepare_finalize is not None
         return prepare_finalize
     elif prepare_finalize_type == FlashInferCutlassMoEPrepareAndFinalize:
-        return FlashInferCutlassMoEPrepareAndFinalize(
+        return create_flashinfer_prepare_finalize(
             use_dp=moe.moe_parallel_config.dp_size > 1)
     else:
         return MoEPrepareAndFinalizeNoEP()
@@ -10,9 +10,15 @@ from vllm.distributed import get_dp_group
 from vllm.forward_context import get_forward_context
 from vllm.logger import init_logger
 from vllm.utils import has_deep_ep, has_pplx
+from vllm.utils.flashinfer import has_flashinfer_all2all

 from .base_device_communicator import All2AllManagerBase, Cache

+if has_flashinfer_all2all():
+    from flashinfer.comm import Mapping
+    from flashinfer.comm.mnnvl import MnnvlConfig
+    from flashinfer.comm.trtllm_alltoall import MnnvlMoe
+
 logger = init_logger(__name__)

@@ -47,24 +53,22 @@ class NaiveAll2AllManager(All2AllManagerBase):

     def dispatch(self, hidden_states: torch.Tensor,
                  router_logits: torch.Tensor):
-        cu_tokens_across_dp_cpu = get_forward_context(
-        ).dp_metadata.cu_tokens_across_dp_cpu
+        sizes = get_forward_context(
+        ).dp_metadata.get_chunk_sizes_across_dp_rank()
+        hidden_states, router_logits = get_dp_group().all_gatherv(
+            [hidden_states, router_logits],
+            dim=0,
+            sizes=sizes,
+        )

-        hidden_states = self.naive_multicast(hidden_states,
-                                             cu_tokens_across_dp_cpu)
-        router_logits = self.naive_multicast(router_logits,
-                                             cu_tokens_across_dp_cpu)
         return hidden_states, router_logits

     def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        cu_tokens_across_dp_cpu = get_forward_context(
-        ).dp_metadata.cu_tokens_across_dp_cpu
-        start = 0 if self.dp_rank == 0 else cu_tokens_across_dp_cpu[
-            self.dp_rank - 1]
-        end = cu_tokens_across_dp_cpu[self.dp_rank]
-
-        all_hidden_states = self.dp_group.all_reduce(hidden_states)
-        hidden_states = all_hidden_states[start:end, :]
+        sizes = get_forward_context(
+        ).dp_metadata.get_chunk_sizes_across_dp_rank()
+        hidden_states = get_dp_group().reduce_scatterv(hidden_states,
+                                                       dim=0,
+                                                       sizes=sizes)
         return hidden_states

     def destroy(self):
@@ -300,4 +304,95 @@ class DeepEPLLAll2AllManager(DeepEPAll2AllManagerBase):

     # DeepEP LL uses RDMA so no SMs are used for communication
     def max_sms_used(self) -> Optional[int]:
-        return 0
+        return 0
+
+
+class FlashInferAllToAllManager(All2AllManagerBase):
+    """
+    All2All communication based on flashinfer kernels.
+    """
+
+    def __init__(self, cpu_group):
+        assert has_flashinfer_all2all(
+        ), "flashinfer all2all module not found. Please install/check flashinfer"  # noqa
+        super().__init__(cpu_group)
+        logger.debug(
+            "Initialize for flashinfer All2All "
+            "rank=%d, world size=%d", self.rank, self.world_size)
+        self.initialized = False
+        self.alltoall_info = None
+
+    def initialize(
+        self,
+        world_size: int,
+        rank: int,
+        gpus_per_node: int,
+    ):
+        """Initialize workspace"""
+        if self.initialized:
+            return
+
+        self.cleanup()
+        logger.debug("making map: "
+                     "rank=%d, world size=%d", rank, world_size)
+        self.mapping = Mapping(
+            world_size,
+            rank,
+            gpus_per_node,
+            tp_size=world_size,
+        )
+
+        from vllm.distributed.device_communicators.mnnvl_compat import (
+            CustomCommunicator)
+        dp_config = MnnvlConfig(
+            comm_backend=CustomCommunicator(get_dp_group().cpu_group),
+            fabric_page_size=1 << 29,  # 512MB
+            allocation_granularity=0  # Auto-detect
+        )
+
+        self.workspace_tensor = MnnvlMoe.get_moe_workspaces(
+            self.mapping, dp_config)
+        self.prepare_workspace_tensor = MnnvlMoe.get_moe_prepare_workspace(
+            self.mapping, dp_config)
+
+        self.world_size = world_size
+        self.rank = rank
+        self.gpus_per_node = gpus_per_node
+        self.initialized = True
+
+        logger.info("FlashInfer All2All initialized for rank %s, size %s",
+                    rank, world_size)
+
+    def ensure_alltoall_workspace_initialized(self):
+        """Ensure workspace is initialized"""
+        if not has_flashinfer_all2all():
+            return False
+
+        if self.world_size <= 1:
+            return False
+
+        if not self.initialized:
+            self.initialize(
+                world_size=self.world_size,
+                rank=self.rank,
+                gpus_per_node=torch.cuda.device_count(),
+            )
+        return self.initialized
+
+    def get_handle(self, kwargs):
+        return self
+
+    def cleanup(self):
+        """Clean up workspace"""
+        if self.initialized and self.workspace_tensor is not None \
+                and self.prepare_workspace_tensor is not None:
+            try:
+                del self.workspace_tensor
+                del self.prepare_workspace_tensor
+            except Exception as e:
+                logger.warning("Failed to cleanup FlashInfer workspace: %s", e)
+            finally:
+                self.workspace_tensor = None
+                self.prepare_workspace_tensor = None
+                self.mapping = None
+                self.initialized = False
@@ -114,6 +114,11 @@ class CudaCommunicator(DeviceCommunicatorBase):
                 from .all2all import DeepEPLLAll2AllManager
                 self.all2all_manager = DeepEPLLAll2AllManager(self.cpu_group)
                 logger.info("Using DeepEP Low-Latency all2all manager.")
+            elif all2all_backend == "flashinfer_all2allv":
+                from .all2all import FlashInferAllToAllManager
+                self.all2all_manager = FlashInferAllToAllManager(
+                    self.cpu_group)
+                logger.info("Using Flashinfer all2allv manager.")
             else:
                 raise ValueError(f"Unknown all2all backend: {all2all_backend}")

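Once CudaCommunicator has wired up the manager as above, the MoE prepare/finalize code added later in this diff reaches it through the expert-parallel group. A minimal lifecycle sketch, assuming a running vLLM worker launched with VLLM_ALL2ALL_BACKEND=flashinfer_all2allv:

from vllm.distributed import get_ep_group

manager = get_ep_group().device_communicator.all2all_manager
# Workspaces are allocated lazily: the first call builds the FlashInfer Mapping,
# the MnnvlConfig and the MNNVL workspaces; subsequent calls are no-ops.
if manager.ensure_alltoall_workspace_initialized():
    handle = manager.get_handle({})  # the manager itself is returned as the handle
manager.cleanup()  # drops the workspaces and resets the mapping
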
vllm/distributed/device_communicators/mnnvl_compat.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import torch.distributed as dist
+from flashinfer.comm.mnnvl import CommBackend as CommBackend
+
+from vllm.utils.flashinfer import has_flashinfer_all2all
+
+assert has_flashinfer_all2all(), "Flashinfer alltoallv module cannot be found"
+
+
+class CustomCommunicator(CommBackend):
+
+    def __init__(self, group):
+        self._group = group
+
+    def Get_rank(self) -> int:
+        return self._group.rank()
+
+    def Get_size(self) -> int:
+        return self._group.size()
+
+    def allgather(self, data: int):
+        gathered = [None] * self.Get_size()
+        dist.all_gather_object(gathered, data, group=self._group)
+        return gathered
+
+    def Split(self, color: int, key: int) -> 'CustomCommunicator':
+        return self
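The adapter above only has to cover the small CommBackend surface that FlashInfer's MNNVL bootstrap uses: rank/size queries, an object allgather, and a trivial Split. A minimal sketch of handing a torch.distributed CPU group to FlashInfer through it, mirroring FlashInferAllToAllManager.initialize() earlier in this diff; it assumes torch.distributed is already initialized with a Gloo-capable backend:

import torch.distributed as dist
from flashinfer.comm.mnnvl import MnnvlConfig

from vllm.distributed.device_communicators.mnnvl_compat import CustomCommunicator

cpu_group = dist.new_group(backend="gloo")  # stand-in for get_dp_group().cpu_group
dp_config = MnnvlConfig(
    comm_backend=CustomCommunicator(cpu_group),
    fabric_page_size=1 << 29,  # 512MB, same value the manager uses
    allocation_granularity=0,  # auto-detect
)
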
@@ -156,7 +156,8 @@ if TYPE_CHECKING:
     VLLM_ALL2ALL_BACKEND: Literal["naive", "pplx",
                                   "deepep_high_throughput",
                                   "deepep_low_latency",
-                                  "allgather_reducescatter"] = \
+                                  "allgather_reducescatter",
+                                  "flashinfer_all2allv"] = \
         "allgather_reducescatter"
     VLLM_MAX_TOKENS_PER_EXPERT_FP4_MOE: int = 163840
     VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
@@ -1209,12 +1210,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
     # - "pplx": use pplx kernels
     # - "deepep_high_throughput", use deepep high-throughput kernels
     # - "deepep_low_latency", use deepep low-latency kernels
+    # - "flashinfer_all2allv", use flashinfer alltoallv kernels for mnnvl
     "VLLM_ALL2ALL_BACKEND":
     env_with_choices("VLLM_ALL2ALL_BACKEND", "allgather_reducescatter",
                      ["naive", "pplx",
                       "deepep_high_throughput",
                       "deepep_low_latency",
-                      "allgather_reducescatter"]),
+                      "allgather_reducescatter",
+                      "flashinfer_all2allv"]),

     # Flashinfer MoE backend for vLLM's fused Mixture-of-Experts support.
     # Both require compute capability 10.0 or above.
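Elsewhere in this diff the setting is consumed through vllm.envs rather than os.environ; a minimal sketch of that read path, assuming the variable was exported before vLLM started:

import vllm.envs as envs

# Mirrors the check added to build_flashinfer_fp4_cutlass_moe_prepare_finalize below.
enable_alltoallv = envs.VLLM_ALL2ALL_BACKEND == "flashinfer_all2allv"
print(enable_alltoallv)
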
@@ -8,7 +8,7 @@ import vllm.model_executor.layers.fused_moe.modular_kernel as mk
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
-    FlashInferCutlassMoEPrepareAndFinalize)
+    create_flashinfer_prepare_finalize)
 from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
     TopKWeightAndReduceNoOP)
 from vllm.utils.flashinfer import (flashinfer_cutlass_fused_moe,
@@ -108,7 +108,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
         of each tuple must be the number of tokens.
         """
         aq_m, aq_n = aq.shape
-        workspace2 = ()
+        workspace2 = (0, )
         output_shape = (aq_m, aq_n * 2) if self.quant_dtype != \
             torch.float8_e4m3fn else (aq_m, aq_n)
         workspace_dtype = a.dtype
@@ -192,9 +192,8 @@ def flashinfer_cutlass_moe_fp4(
     expert_map: Optional[torch.Tensor] = None,
     apply_router_weight_on_input: bool = False,
 ) -> torch.Tensor:
-
     fused_experts = mk.FusedMoEModularKernel(
-        FlashInferCutlassMoEPrepareAndFinalize(use_dp=False),
+        create_flashinfer_prepare_finalize(use_dp=False),
         FlashInferExperts(
             out_dtype=hidden_states.dtype,
             quant_config=quant_config,
@@ -5,7 +5,9 @@ from typing import Optional
 import torch

 import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.distributed import get_dp_group
+from vllm.distributed import get_dp_group, get_ep_group
+from vllm.distributed.device_communicators.base_device_communicator import (
+    All2AllManagerBase)
 from vllm.forward_context import get_forward_context
 from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
 from vllm.model_executor.layers.fused_moe.utils import (
@@ -18,6 +20,7 @@ def get_local_sizes():


 class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
+    """Base class for FlashInfer MoE prepare and finalize operations."""

     def __init__(
         self,
@@ -42,6 +45,39 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
     def num_dispatchers(self) -> int:
         return self.num_dispatchers_

+    def _apply_router_weight_on_input(
+        self,
+        a1: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+    ) -> None:
+        """Apply router weight on input if needed."""
+        if apply_router_weight_on_input:
+            topk = topk_ids.size(1)
+            assert topk == 1, \
+                "apply_router_weight_on_input is only implemented for topk=1"
+            a1.mul_(topk_weights.to(a1.dtype))
+
+
+class FlashInferAllToAllMoEPrepareAndFinalize(
+        FlashInferCutlassMoEPrepareAndFinalize):
+    """FlashInfer implementation using AllToAll communication."""
+
+    def __init__(
+        self,
+        use_dp: bool,
+        num_dispatchers: int = 1,
+    ):
+        super().__init__(use_dp, num_dispatchers)
+        self.alltoall_info = None
+
+        # Initialize all2all_manager only for DP case
+        self.all2all_manager = None
+        if self.use_dp:
+            self.all2all_manager = get_ep_group(
+            ).device_communicator.all2all_manager
+
     def prepare(
         self,
         a1: torch.Tensor,
@@ -53,12 +89,84 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
         quant_config: FusedMoEQuantConfig,
     ) -> mk.PrepareResultType:

-        if apply_router_weight_on_input:
-            topk = topk_ids.size(1)
-            # TODO: this only works for topK=1, will need to update for topK>1
-            assert topk == 1, \
-                "apply_router_weight_on_input is only implemented for topk=1"
-            a1.mul_(topk_weights.to(a1.dtype))
+        self._apply_router_weight_on_input(a1, topk_weights, topk_ids,
+                                           apply_router_weight_on_input)
+
+        if not self.use_dp:
+            # Non-DP case: standard quantization
+            a1q, a1q_scale = moe_kernel_quantize_input(
+                a1,
+                quant_config.a1_gscale,
+                quant_config.quant_dtype,
+                quant_config.per_act_token_quant,
+                quant_config.block_shape,
+                is_fp4_scale_swizzled=not self.use_dp,
+            )
+        else:
+            # DP case: use FlashInfer AllToAll
+            global_num_tokens_cpu = get_local_sizes()
+            top_k = topk_ids.size(1)
+
+            (self.alltoall_info, topk_ids, topk_weights, a1q,
+             a1q_scale) = flashinfer_alltoall_dispatch(
+                 self.all2all_manager,
+                 global_num_tokens_cpu,
+                 a1,
+                 quant_config.a1_gscale,
+                 topk_ids,
+                 topk_weights,
+                 top_k,
+                 num_experts,
+                 quant_config,
+             )
+
+        return a1q, a1q_scale, None, topk_ids, topk_weights
+
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:
+        if self.use_dp:
+            top_k = topk_ids.size(1)
+            token_count = output.shape[0]
+            fused_expert_output = flashinfer_alltoall_combine(
+                self.all2all_manager,
+                fused_expert_output,
+                top_k=top_k,
+                token_count=token_count,
+                alltoall_info=self.alltoall_info,
+            )
+        output.copy_(fused_expert_output)
+
+
+class FlashInferAllGatherMoEPrepareAndFinalize(
+        FlashInferCutlassMoEPrepareAndFinalize):
+
+    def __init__(
+        self,
+        use_dp: bool,
+        num_dispatchers: int = 1,
+    ):
+        super().__init__(use_dp, num_dispatchers)
+
+    def prepare(
+        self,
+        a1: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        num_experts: int,
+        expert_map: Optional[torch.Tensor],
+        apply_router_weight_on_input: bool,
+        quant_config: FusedMoEQuantConfig,
+    ) -> mk.PrepareResultType:
+
+        self._apply_router_weight_on_input(a1, topk_weights, topk_ids,
+                                           apply_router_weight_on_input)

         a1q, a1q_scale = moe_kernel_quantize_input(
             a1,
@@ -66,7 +174,6 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
             quant_config.quant_dtype,
             quant_config.per_act_token_quant,
             quant_config.block_shape,
-            # Swizzling after communication
             is_fp4_scale_swizzled=not self.use_dp,
         )
         if self.use_dp:
@@ -76,17 +183,117 @@ class FlashInferCutlassMoEPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
                 dim=0,
                 sizes=get_local_sizes(),
             )
-            a1_m, a1_n = a1q.shape
             a1q_scale = nvfp4_block_scale_interleave(a1q_scale)

         return a1q, a1q_scale, None, topk_ids, topk_weights

-    def finalize(self, output: torch.Tensor, fused_expert_output: torch.Tensor,
-                 topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                 apply_router_weight_on_input: bool,
-                 weight_and_reduce_impl: mk.TopKWeightAndReduce) -> None:
+    def finalize(
+        self,
+        output: torch.Tensor,
+        fused_expert_output: torch.Tensor,
+        topk_weights: torch.Tensor,
+        topk_ids: torch.Tensor,
+        apply_router_weight_on_input: bool,
+        weight_and_reduce_impl: mk.TopKWeightAndReduce,
+    ) -> None:

         if self.use_dp:
             fused_expert_output = get_dp_group().reduce_scatterv(
                 fused_expert_output, dim=0, sizes=get_local_sizes())
         output.copy_(fused_expert_output)
+
+
+def flashinfer_alltoall_dispatch(
+    all2all_manager: All2AllManagerBase,
+    global_num_tokens_cpu: list[int],
+    x: torch.Tensor,
+    gs: torch.Tensor,
+    topk_ids: torch.Tensor,
+    topk_weights: torch.Tensor,
+    top_k: int,
+    num_experts: int,
+    quant_config: FusedMoEQuantConfig,
+):
+    from flashinfer.comm.trtllm_alltoall import MnnvlMoe
+    assert (all2all_manager.ensure_alltoall_workspace_initialized()
+            ), "FlashInfer AllToAll workspace not available"
+
+    ep_rank = all2all_manager.rank
+    ep_size = all2all_manager.world_size
+    max_num_token = max(global_num_tokens_cpu
+                        ) if global_num_tokens_cpu is not None else x.shape[0]
+    alltoall_info, topk_ids, topk_weights, _ = (
+        MnnvlMoe.mnnvl_moe_alltoallv_prepare_without_allgather(
+            topk_ids,
+            topk_weights,
+            None,
+            all2all_manager.prepare_workspace_tensor,
+            max_num_token,
+            ep_rank,
+            ep_size,
+            num_experts,
+            num_experts,
+            top_k,
+        ))
+
+    x, x_sf = moe_kernel_quantize_input(
+        x,
+        gs,
+        quant_config.quant_dtype,
+        quant_config.per_act_token_quant,
+        quant_config.block_shape,
+        is_fp4_scale_swizzled=False,  # delay swizzle to after comm
+    )
+    x = MnnvlMoe.mnnvl_moe_alltoallv(
+        x,
+        alltoall_info,
+        all2all_manager.workspace_tensor,
+        ep_rank,
+        ep_size,
+    )
+
+    x_sf = MnnvlMoe.mnnvl_moe_alltoallv(
+        x_sf,
+        alltoall_info,
+        all2all_manager.workspace_tensor,
+        ep_rank,
+        ep_size,
+    )
+    x_sf = nvfp4_block_scale_interleave(x_sf)
+    return alltoall_info, topk_ids, topk_weights, x, x_sf
+
+
+def flashinfer_alltoall_combine(
+    all2all_manager: All2AllManagerBase,
+    output: torch.Tensor,
+    top_k: int,
+    token_count: int,
+    alltoall_info,
+):
+    from flashinfer.comm.trtllm_alltoall import MnnvlMoe
+    assert (all2all_manager.ensure_alltoall_workspace_initialized()
+            ), "FlashInfer AllToAll workspace not available"
+    return MnnvlMoe.mnnvl_moe_alltoallv_combine(
+        output,
+        alltoall_info,
+        all2all_manager.workspace_tensor,
+        ep_rank=all2all_manager.rank,
+        ep_size=all2all_manager.world_size,
+        top_k=top_k,
+        token_count=token_count,
+    )
+
+
+def create_flashinfer_prepare_finalize(
+    use_dp: bool,
+    use_nvfp4: bool = False,
+    enable_alltoallv: bool = False,
+) -> FlashInferCutlassMoEPrepareAndFinalize:
+    """Factory function to create the appropriate FlashInfer implementation."""
+    if use_nvfp4:
+        if enable_alltoallv:
+            return FlashInferAllToAllMoEPrepareAndFinalize(use_dp)
+        else:
+            return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
+    # Fp8 only supports AllGather
+    return FlashInferAllGatherMoEPrepareAndFinalize(use_dp)
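A small sketch of how the factory above dispatches, assuming the FlashInfer modules are importable; use_dp=False is used so the sketch does not need a live expert-parallel group (the AllToAll flavor only looks up the EP group's all2all manager when use_dp=True):

from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (
    FlashInferAllGatherMoEPrepareAndFinalize,
    FlashInferAllToAllMoEPrepareAndFinalize,
    create_flashinfer_prepare_finalize)

pf = create_flashinfer_prepare_finalize(
    use_dp=False, use_nvfp4=True, enable_alltoallv=True)
assert isinstance(pf, FlashInferAllToAllMoEPrepareAndFinalize)

# fp8 (use_nvfp4=False), and nvfp4 without enable_alltoallv, both get the AllGather flavor.
pf = create_flashinfer_prepare_finalize(use_dp=False)
assert isinstance(pf, FlashInferAllGatherMoEPrepareAndFinalize)
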
@@ -12,7 +12,7 @@ from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     FlashInferExperts)
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
-    FlashInferCutlassMoEPrepareAndFinalize)
+    create_flashinfer_prepare_finalize)
 from vllm.platforms import current_platform
 from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe

@@ -51,7 +51,9 @@ def build_flashinfer_fp4_cutlass_moe_prepare_finalize(
         moe: FusedMoEConfig) -> mk.FusedMoEPrepareAndFinalize:
     """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
     use_dp = moe.moe_parallel_config.dp_size > 1
-    return FlashInferCutlassMoEPrepareAndFinalize(use_dp)
+    enable_alltoallv = envs.VLLM_ALL2ALL_BACKEND == "flashinfer_all2allv"
+    return create_flashinfer_prepare_finalize(
+        use_dp=use_dp, use_nvfp4=True, enable_alltoallv=enable_alltoallv)


 def select_nvfp4_gemm_impl(
@@ -13,7 +13,7 @@ from vllm.model_executor.layers.fused_moe.config import (FusedMoEConfig,
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import (
     FlashInferExperts)
 from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_prepare_finalize import (  # noqa: E501
-    FlashInferCutlassMoEPrepareAndFinalize)
+    create_flashinfer_prepare_finalize)

 logger = init_logger(__name__)

@@ -173,7 +173,7 @@ def build_flashinfer_fp8_cutlass_moe_prepare_finalize(
         moe: Optional[FusedMoEConfig], ) -> mk.FusedMoEPrepareAndFinalize:
     """Create a FlashInfer CUTLASS fused-MoE prepare finalize kernel"""
     use_dp = moe.moe_parallel_config.dp_size > 1 if moe is not None else False
-    return FlashInferCutlassMoEPrepareAndFinalize(use_dp)
+    return create_flashinfer_prepare_finalize(use_dp)


 def select_cutlass_fp8_gemm_impl(
@@ -97,6 +97,34 @@ autotune = _lazy_import_wrapper(
     fallback_fn=lambda *args, **kwargs: contextlib.nullcontext())


+@functools.cache
+def has_flashinfer_comm() -> bool:
+    """Return ``True`` if FlashInfer comm module is available."""
+    return has_flashinfer() and importlib.util.find_spec(
+        "flashinfer.comm") is not None
+
+
+@functools.cache
+def has_flashinfer_all2all() -> bool:
+    """Return ``True`` if FlashInfer mnnvl all2all is available."""
+    if not has_flashinfer_comm():
+        return False
+
+    # Check if all required functions are available
+    required_functions = [
+        ("flashinfer.comm", "Mapping"),
+        ("flashinfer.comm.mnnvl", "MnnvlMemory"),
+        ("flashinfer.comm.trtllm_alltoall", "MnnvlMoe"),
+        ("flashinfer.comm.trtllm_alltoall", "MoEAlltoallInfo"),
+    ]
+
+    for module_name, attr_name in required_functions:
+        mod = _get_submodule(module_name)
+        if not mod or not hasattr(mod, attr_name):
+            return False
+    return True
+
+
 @functools.cache
 def has_flashinfer_moe() -> bool:
     """Return ``True`` if FlashInfer MoE module is available."""
@@ -402,6 +430,8 @@ __all__ = [
     "trtllm_fp4_block_scale_moe",
     "autotune",
     "has_flashinfer_moe",
+    "has_flashinfer_comm",
+    "has_flashinfer_all2all",
     "has_flashinfer_cutlass_fused_moe",
     "has_nvidia_artifactory",
     "supports_trtllm_attention",
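A minimal gating sketch using the new capability check, in the same spirit as the import guard added to all2all.py earlier in this diff:

from vllm.utils.flashinfer import has_flashinfer_all2all

if has_flashinfer_all2all():
    # Safe to import the MNNVL pieces the check looked for above.
    from flashinfer.comm import Mapping
    from flashinfer.comm.trtllm_alltoall import MnnvlMoe
else:
    print("FlashInfer MNNVL all2all not available; "
          "VLLM_ALL2ALL_BACKEND=flashinfer_all2allv cannot be used")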