[Feat][Perf] Enable deepep-low-latency with round-robin expert placement. (#28449)

Signed-off-by: bruceszchen <bruceszchen@tencent.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Chen Bruce authored on 2025-11-19 20:46:24 +08:00; committed by GitHub
parent ba558c029a
commit da2f6800e0
8 changed files with 208 additions and 37 deletions

View File

@@ -67,6 +67,7 @@ def maybe_roundup_layer_hidden_size(
def maybe_make_prepare_finalize(
moe: FusedMoEConfig,
quant_config: FusedMoEQuantConfig | None,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
if not moe.moe_parallel_config.use_all2all_kernels:
return None
@@ -134,6 +135,13 @@ def maybe_make_prepare_finalize(
elif moe.use_deepep_ll_kernels:
assert quant_config is not None
global_to_physical = physical_to_global = local_expert_global_ids = None
if routing_tables is not None:
(
global_to_physical,
physical_to_global,
local_expert_global_ids,
) = routing_tables
all_to_all_args = dict(
max_num_tokens_per_dp_rank=moe.max_num_tokens,
token_hidden_size=moe.hidden_dim,
@@ -155,6 +163,9 @@ def maybe_make_prepare_finalize(
max_tokens_per_rank=moe.max_num_tokens,
num_dispatchers=all2all_manager.world_size,
use_fp8_dispatch=use_fp8_dispatch,
global_to_physical=global_to_physical,
physical_to_global=physical_to_global,
local_expert_global_ids=local_expert_global_ids,
)
return prepare_finalize
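
The new routing_tables argument is a fixed-order triple of index tensors: global expert id -> physical slot, physical slot -> global expert id, and local slot on this EP rank -> global expert id. Passing None (the default) preserves the old behaviour, with no remapping installed. A minimal self-contained sketch of the invariants the three tensors are expected to satisfy, using toy values for 8 experts round-robined over 2 EP ranks (illustrative, not taken from the PR):

import torch

global_to_physical = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7])
physical_to_global = torch.tensor([0, 2, 4, 6, 1, 3, 5, 7])
local_expert_global_ids = torch.tensor([0, 2, 4, 6])  # experts hosted on EP rank 0

# The first two tables are inverse permutations of each other.
assert torch.equal(
    physical_to_global[global_to_physical],
    torch.arange(global_to_physical.numel()),
)
# The per-rank table is this rank's contiguous slice of the physical layout.
assert torch.equal(local_expert_global_ids, physical_to_global[:4])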

View File

@@ -85,6 +85,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
max_tokens_per_rank: int,
num_dispatchers: int,
use_fp8_dispatch: bool = False,
global_to_physical: torch.Tensor | None = None,
physical_to_global: torch.Tensor | None = None,
local_expert_global_ids: torch.Tensor | None = None,
):
super().__init__()
@@ -97,6 +100,17 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
self.handles: list[tuple | None] = [None, None]
self.num_dispatchers_ = num_dispatchers
topk_indices_dtype = self.topk_indices_dtype()
def _maybe_cast(tensor: torch.Tensor | None) -> torch.Tensor | None:
if tensor is None or topk_indices_dtype is None:
return tensor
return tensor.to(dtype=topk_indices_dtype)
self.global_to_physical = _maybe_cast(global_to_physical)
self.physical_to_global = _maybe_cast(physical_to_global)
self.local_expert_global_ids = _maybe_cast(local_expert_global_ids)
# We don't have enough information to determine if we should dispatch
# activation scales in a packed ue8m0 format during object construction
# time. This setting is handled by post_init_setup.
@@ -136,6 +150,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
def topk_indices_dtype(self) -> torch.dtype | None:
return torch.int64
def _map_global_to_physical_ids(self, topk_ids: torch.Tensor) -> torch.Tensor:
if self.global_to_physical is None:
return topk_ids
return self.global_to_physical[topk_ids]
def _map_local_to_global_ids(self, expert_topk_ids: torch.Tensor) -> torch.Tensor:
if self.local_expert_global_ids is None:
return expert_topk_ids
return self.local_expert_global_ids[expert_topk_ids]
def _do_quant(
self,
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
@@ -226,9 +250,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
a1 = a1 * topk_weights.to(a1.dtype)
# Dispatch
dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
a1,
topk_ids,
dispatch_topk_ids,
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
@@ -313,11 +338,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
# weights have already been applied.
combine_topk_weights = torch.ones_like(topk_weights)
combine_topk_ids = self._map_global_to_physical_ids(topk_ids)
# TODO (varun) : Enable zero copy mode
dbo_maybe_run_recv_hook()
_, _, recv_hook = self.buffer.low_latency_combine(
fused_expert_output,
topk_ids,
combine_topk_ids,
combine_topk_weights,
handle,
async_finish=False,
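
Both remapping helpers above are single index gathers: _map_global_to_physical_ids is applied to topk_ids before low_latency_dispatch and low_latency_combine (as shown in these hunks), while _map_local_to_global_ids translates the per-rank local expert indices handed back by dispatch into global expert ids (its call site is outside the hunks shown here). A small sketch with toy tensors (illustrative values, not from the PR):

import torch

# Tables for 8 global experts spread round-robin over 2 EP ranks.
global_to_physical = torch.tensor([0, 4, 1, 5, 2, 6, 3, 7])
local_expert_global_ids = torch.tensor([0, 2, 4, 6])  # experts hosted on EP rank 0

# Router output uses global expert ids; DeepEP expects physical slot ids.
topk_ids = torch.tensor([[1, 6], [0, 3]])
dispatch_topk_ids = global_to_physical[topk_ids]              # [[4, 3], [0, 5]]

# Local expert indices received on this rank map back to global expert ids.
expert_topk_ids = torch.tensor([[0, 2], [3, 1]])
expert_global_ids = local_expert_global_ids[expert_topk_ids]  # [[0, 4], [6, 2]]

Because each remap is a lookup into an already-materialised table, round-robin support adds no extra communication to the dispatch/combine path.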

View File

@@ -50,10 +50,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
"""
return False
def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
from .all2all_utils import maybe_make_prepare_finalize
return maybe_make_prepare_finalize(self.moe, self.moe_quant_config)
return maybe_make_prepare_finalize(
self.moe, self.moe_quant_config, routing_tables
)
def select_gemm_impl(
self,

View File

@@ -5,7 +5,7 @@ from collections.abc import Callable, Iterable
from contextlib import nullcontext
from enum import Enum
from functools import partial
from typing import Literal, get_args, overload
from typing import Literal, cast, get_args, overload
import torch
import torch.nn.functional as F
@@ -192,6 +192,42 @@ def determine_expert_map(
return (local_num_experts, expert_map, expert_mask)
def determine_expert_placement_strategy(
expert_placement_strategy: ExpertPlacementStrategy,
moe_parallel_config: FusedMoEParallelConfig,
num_expert_group: int | None,
num_redundant_experts: int,
enable_eplb: bool,
) -> ExpertPlacementStrategy:
if expert_placement_strategy == "round_robin":
round_robin_supported = (
(num_expert_group is not None and num_expert_group > 1)
and num_redundant_experts == 0
and not enable_eplb
)
if not round_robin_supported:
logger.warning(
"Round-robin expert placement is only supported for "
"models with multiple expert groups and no redundant "
"experts. Falling back to linear expert placement."
)
return "linear"
if (
moe_parallel_config.use_all2all_kernels
and not moe_parallel_config.use_deepep_ll_kernels
):
logger.warning(
"Round-robin expert placement currently only supports "
"the DeepEP low-latency backend, but '%s' was configured. "
"Falling back to linear expert placement.",
moe_parallel_config.all2all_backend,
)
return "linear"
return expert_placement_strategy
def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
"""
Compresses the expert map by removing any -1 entries.
@@ -400,6 +436,9 @@ class FusedMoE(CustomOp):
self.expert_load_view: torch.Tensor | None = None
self.logical_to_physical_map: torch.Tensor | None = None
self.logical_replica_count: torch.Tensor | None = None
self.expert_placement_strategy: ExpertPlacementStrategy = (
vllm_config.parallel_config.expert_placement_strategy
)
# ROCm aiter shared experts fusion
self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
@@ -433,38 +472,27 @@
"Redundant experts are only supported with EPLB."
)
expert_placement_strategy = (
vllm_config.parallel_config.expert_placement_strategy
self.expert_placement_strategy = determine_expert_placement_strategy(
expert_placement_strategy=self.expert_placement_strategy,
moe_parallel_config=self.moe_parallel_config,
num_expert_group=num_expert_group,
num_redundant_experts=num_redundant_experts,
enable_eplb=self.enable_eplb,
)
if expert_placement_strategy == "round_robin":
# TODO(Bruce): will support round robin expert placement with
# EPLB enabled in the future.
round_robin_supported = (
(num_expert_group is not None and num_expert_group > 1)
and num_redundant_experts == 0
and not self.enable_eplb
)
if not round_robin_supported:
logger.warning(
"Round-robin expert placement is only supported for "
"models with multiple expert groups and no redundant "
"experts. Falling back to linear expert placement."
)
expert_placement_strategy = "linear"
self.expert_map: torch.Tensor | None
local_num_experts, expert_map, expert_mask = determine_expert_map(
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts,
expert_placement_strategy=expert_placement_strategy,
expert_placement_strategy=self.expert_placement_strategy,
num_fused_shared_experts=self.num_fused_shared_experts,
return_expert_mask=self.rocm_aiter_fmoe_enabled,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask)
self._maybe_init_expert_routing_tables()
logger.info_once(
"[EP Rank %s/%s] Expert parallelism is enabled. Expert "
"placement strategy: %s. Local/global"
@@ -472,7 +500,7 @@
" %s.",
self.ep_rank,
self.ep_size,
expert_placement_strategy,
self.expert_placement_strategy,
self.local_num_experts,
self.global_num_experts,
get_compressed_expert_map(self.expert_map),
@@ -621,7 +649,12 @@
# should be safe to swap out the quant_method.
def maybe_init_modular_kernel(self) -> None:
self.ensure_moe_quant_config_init()
prepare_finalize = self.quant_method.maybe_make_prepare_finalize()
# routing_tables are only needed for round-robin expert placement with
# the DeepEP low-latency all2all backend.
routing_tables = self._maybe_init_expert_routing_tables()
prepare_finalize = self.quant_method.maybe_make_prepare_finalize(
routing_tables=routing_tables
)
if prepare_finalize is not None:
logger.debug(
"%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
@@ -703,6 +736,84 @@
# By default, router/gate is called before FusedMoE forward pass
return False
def _maybe_init_expert_routing_tables(
self,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
# Currently, routing tables are only needed for round-robin expert placement
# with the DeepEP low-latency all2all backend.
if (
self.expert_placement_strategy != "round_robin"
or not self.use_deepep_ll_kernels
):
return None
if hasattr(self, "expert_global_to_physical"):
return cast(
tuple[torch.Tensor, torch.Tensor, torch.Tensor],
(
self.expert_global_to_physical,
self.expert_physical_to_global,
self.expert_local_to_global,
),
)
if self.expert_map is None:
return None
routing_tables = self.ensure_round_robin_expert_routing_tables(
global_num_experts=self.global_num_experts,
ep_size=self.ep_size,
ep_rank=self.ep_rank,
local_num_experts=self.local_num_experts,
device=self.expert_map.device,
)
global_to_physical, physical_to_global, local_global = routing_tables
self.register_buffer("expert_global_to_physical", global_to_physical)
self.register_buffer("expert_physical_to_global", physical_to_global)
self.register_buffer("expert_local_to_global", local_global)
return routing_tables
@staticmethod
def ensure_round_robin_expert_routing_tables(
global_num_experts: int,
ep_size: int,
ep_rank: int,
local_num_experts: int,
device: torch.device | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
device_kwargs = {"device": device} if device is not None else {}
global_indices = torch.arange(
global_num_experts, dtype=torch.long, **device_kwargs
)
owner = torch.remainder(global_indices, ep_size)
local_index = torch.div(global_indices, ep_size, rounding_mode="floor")
base = global_num_experts // ep_size
remainder = global_num_experts % ep_size
physical_offset = owner * base
if remainder > 0:
remainder_tensor = torch.tensor(
remainder, dtype=torch.long, **device_kwargs
)
physical_offset = physical_offset + torch.minimum(owner, remainder_tensor)
global_to_physical = physical_offset + local_index
physical_to_global = torch.empty_like(global_to_physical)
physical_to_global[global_to_physical] = global_indices
local_global = torch.arange(
ep_rank,
global_num_experts,
ep_size,
dtype=torch.long,
**device_kwargs,
)
if local_global.numel() != local_num_experts:
local_global = local_global[:local_num_experts]
return (global_to_physical, physical_to_global, local_global)
def update_expert_map(self):
# ep_size and ep_rank should already be updated
assert self.expert_map is not None
@@ -711,12 +822,14 @@
ep_size=self.ep_size,
ep_rank=self.ep_rank,
global_num_experts=self.global_num_experts,
expert_placement_strategy=self.expert_placement_strategy,
num_fused_shared_experts=self.num_fused_shared_experts,
return_expert_mask=self.rocm_aiter_fmoe_enabled,
)
self.local_num_experts = local_num_experts
self.register_buffer("expert_map", expert_map)
self.register_buffer("expert_mask", expert_mask)
self._maybe_init_expert_routing_tables()
if self.aiter_fmoe_shared_expert_enabled:
self._init_aiter_shared_experts_topK_buffer(
vllm_config=get_current_vllm_config(),
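
For intuition, the snippet below re-derives the tables that ensure_round_robin_expert_routing_tables produces for a small configuration (toy sizes chosen for illustration; unlike the helper, the remainder correction is applied unconditionally here, which is equivalent when it is zero):

import torch

global_num_experts, ep_size, ep_rank = 8, 2, 0  # rank r hosts global experts r, r+2, r+4, r+6

g = torch.arange(global_num_experts, dtype=torch.long)
owner = g % ep_size                       # EP rank hosting each global expert
local = g // ep_size                      # its local slot index on that rank
base, rem = divmod(global_num_experts, ep_size)
offset = owner * base + torch.minimum(owner, torch.tensor(rem))
global_to_physical = offset + local       # tensor([0, 4, 1, 5, 2, 6, 3, 7])

physical_to_global = torch.empty_like(global_to_physical)
physical_to_global[global_to_physical] = g            # tensor([0, 2, 4, 6, 1, 3, 5, 7])

local_expert_global_ids = torch.arange(ep_rank, global_num_experts, ep_size)
# rank 0 -> tensor([0, 2, 4, 6]); rank 1 would get tensor([1, 3, 5, 7])

DeepEP's low-latency dispatch assumes each rank owns one contiguous block of physical expert slots, so these tables bridge that contiguous physical layout and the strided global ids produced by round-robin placement; that is also why only the DeepEP low-latency path needs them.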

View File

@@ -108,11 +108,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def allow_inplace(self) -> bool:
return True
def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> FusedMoEPrepareAndFinalize | None:
if self.rocm_aiter_moe_enabled:
return None
else:
return super().maybe_make_prepare_finalize()
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,

View File

@@ -380,11 +380,14 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
(layer.w2_input_global_scale), requires_grad=False
)
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if self.use_marlin:
return None
elif not self.allow_flashinfer:
return super().maybe_make_prepare_finalize()
return super().maybe_make_prepare_finalize(routing_tables)
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
logger.debug_once("%s", prepare_finalize.__class__.__name__)
@@ -890,11 +893,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w2_weight_scale
)
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if self.use_marlin or self.rocm_aiter_moe_enabled:
return None
else:
return super().maybe_make_prepare_finalize()
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,

View File

@@ -1018,7 +1018,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
del layer.w13_input_scale
del layer.w2_input_scale
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if (
self.rocm_aiter_moe_enabled
or self.use_marlin
@@ -1039,7 +1042,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize()
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,

View File

@@ -373,6 +373,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
# TRT LLM not supported with all2all yet.
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
@@ -384,7 +385,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize()
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,
@@ -1179,7 +1180,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
" for ModelOptNvFp4FusedMoE."
)
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
def maybe_make_prepare_finalize(
self,
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
) -> mk.FusedMoEPrepareAndFinalize | None:
if self.use_marlin or (
self.allow_flashinfer
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
@@ -1196,7 +1200,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
logger.debug_once("%s", prepare_finalize.__class__.__name__)
return prepare_finalize
else:
return super().maybe_make_prepare_finalize()
return super().maybe_make_prepare_finalize(routing_tables)
def select_gemm_impl(
self,