mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-01-23 18:44:30 +08:00
[Feat][Perf] Enable deepep-low-latency with round-robin expert placement. (#28449)
Signed-off-by: bruceszchen <bruceszchen@tencent.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
ba558c029a
commit
da2f6800e0
@ -67,6 +67,7 @@ def maybe_roundup_layer_hidden_size(
|
||||
def maybe_make_prepare_finalize(
|
||||
moe: FusedMoEConfig,
|
||||
quant_config: FusedMoEQuantConfig | None,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> FusedMoEPrepareAndFinalize | None:
|
||||
if not moe.moe_parallel_config.use_all2all_kernels:
|
||||
return None
|
||||
@ -134,6 +135,13 @@ def maybe_make_prepare_finalize(
|
||||
|
||||
elif moe.use_deepep_ll_kernels:
|
||||
assert quant_config is not None
|
||||
global_to_physical = physical_to_global = local_expert_global_ids = None
|
||||
if routing_tables is not None:
|
||||
(
|
||||
global_to_physical,
|
||||
physical_to_global,
|
||||
local_expert_global_ids,
|
||||
) = routing_tables
|
||||
all_to_all_args = dict(
|
||||
max_num_tokens_per_dp_rank=moe.max_num_tokens,
|
||||
token_hidden_size=moe.hidden_dim,
|
||||
@ -155,6 +163,9 @@ def maybe_make_prepare_finalize(
|
||||
max_tokens_per_rank=moe.max_num_tokens,
|
||||
num_dispatchers=all2all_manager.world_size,
|
||||
use_fp8_dispatch=use_fp8_dispatch,
|
||||
global_to_physical=global_to_physical,
|
||||
physical_to_global=physical_to_global,
|
||||
local_expert_global_ids=local_expert_global_ids,
|
||||
)
|
||||
|
||||
return prepare_finalize
|
||||
|
||||
@ -85,6 +85,9 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
max_tokens_per_rank: int,
|
||||
num_dispatchers: int,
|
||||
use_fp8_dispatch: bool = False,
|
||||
global_to_physical: torch.Tensor | None = None,
|
||||
physical_to_global: torch.Tensor | None = None,
|
||||
local_expert_global_ids: torch.Tensor | None = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -97,6 +100,17 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
self.handles: list[tuple | None] = [None, None]
|
||||
self.num_dispatchers_ = num_dispatchers
|
||||
|
||||
topk_indices_dtype = self.topk_indices_dtype()
|
||||
|
||||
def _maybe_cast(tensor: torch.Tensor | None) -> torch.Tensor | None:
|
||||
if tensor is None or topk_indices_dtype is None:
|
||||
return tensor
|
||||
return tensor.to(dtype=topk_indices_dtype)
|
||||
|
||||
self.global_to_physical = _maybe_cast(global_to_physical)
|
||||
self.physical_to_global = _maybe_cast(physical_to_global)
|
||||
self.local_expert_global_ids = _maybe_cast(local_expert_global_ids)
|
||||
|
||||
# We don't have enough information to determine if we should dispatch
|
||||
# activation scales in a packed ue8m0 format during object construction
|
||||
# time. This setting is handled by post_init_setup.
|
||||
@ -136,6 +150,16 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
def topk_indices_dtype(self) -> torch.dtype | None:
|
||||
return torch.int64
|
||||
|
||||
def _map_global_to_physical_ids(self, topk_ids: torch.Tensor) -> torch.Tensor:
|
||||
if self.global_to_physical is None:
|
||||
return topk_ids
|
||||
return self.global_to_physical[topk_ids]
|
||||
|
||||
def _map_local_to_global_ids(self, expert_topk_ids: torch.Tensor) -> torch.Tensor:
|
||||
if self.local_expert_global_ids is None:
|
||||
return expert_topk_ids
|
||||
return self.local_expert_global_ids[expert_topk_ids]
|
||||
|
||||
def _do_quant(
|
||||
self,
|
||||
x: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
|
||||
@ -226,9 +250,10 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
a1 = a1 * topk_weights.to(a1.dtype)
|
||||
|
||||
# Dispatch
|
||||
dispatch_topk_ids = self._map_global_to_physical_ids(topk_ids)
|
||||
expert_x, expert_num_tokens, handle, _, hook = self.buffer.low_latency_dispatch(
|
||||
a1,
|
||||
topk_ids,
|
||||
dispatch_topk_ids,
|
||||
self.max_tokens_per_rank,
|
||||
num_experts,
|
||||
use_fp8=self.use_fp8_dispatch,
|
||||
@ -313,11 +338,12 @@ class DeepEPLLPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
|
||||
# weights have already been applied.
|
||||
combine_topk_weights = torch.ones_like(topk_weights)
|
||||
|
||||
combine_topk_ids = self._map_global_to_physical_ids(topk_ids)
|
||||
# TODO (varun) : Enable zero copy mode
|
||||
dbo_maybe_run_recv_hook()
|
||||
_, _, recv_hook = self.buffer.low_latency_combine(
|
||||
fused_expert_output,
|
||||
topk_ids,
|
||||
combine_topk_ids,
|
||||
combine_topk_weights,
|
||||
handle,
|
||||
async_finish=False,
|
||||
|
||||
@ -50,10 +50,15 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
"""
|
||||
return False
|
||||
|
||||
def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> FusedMoEPrepareAndFinalize | None:
|
||||
from .all2all_utils import maybe_make_prepare_finalize
|
||||
|
||||
return maybe_make_prepare_finalize(self.moe, self.moe_quant_config)
|
||||
return maybe_make_prepare_finalize(
|
||||
self.moe, self.moe_quant_config, routing_tables
|
||||
)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
|
||||
@ -5,7 +5,7 @@ from collections.abc import Callable, Iterable
|
||||
from contextlib import nullcontext
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Literal, get_args, overload
|
||||
from typing import Literal, cast, get_args, overload
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@ -192,6 +192,42 @@ def determine_expert_map(
|
||||
return (local_num_experts, expert_map, expert_mask)
|
||||
|
||||
|
||||
def determine_expert_placement_strategy(
|
||||
expert_placement_strategy: ExpertPlacementStrategy,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
num_expert_group: int | None,
|
||||
num_redundant_experts: int,
|
||||
enable_eplb: bool,
|
||||
) -> ExpertPlacementStrategy:
|
||||
if expert_placement_strategy == "round_robin":
|
||||
round_robin_supported = (
|
||||
(num_expert_group is not None and num_expert_group > 1)
|
||||
and num_redundant_experts == 0
|
||||
and not enable_eplb
|
||||
)
|
||||
|
||||
if not round_robin_supported:
|
||||
logger.warning(
|
||||
"Round-robin expert placement is only supported for "
|
||||
"models with multiple expert groups and no redundant "
|
||||
"experts. Falling back to linear expert placement."
|
||||
)
|
||||
return "linear"
|
||||
if (
|
||||
moe_parallel_config.use_all2all_kernels
|
||||
and not moe_parallel_config.use_deepep_ll_kernels
|
||||
):
|
||||
logger.warning(
|
||||
"Round-robin expert placement currently only supports "
|
||||
"the DeepEP low-latency backend, but '%s' was configured. "
|
||||
"Falling back to linear expert placement.",
|
||||
moe_parallel_config.all2all_backend,
|
||||
)
|
||||
return "linear"
|
||||
|
||||
return expert_placement_strategy
|
||||
|
||||
|
||||
def get_compressed_expert_map(expert_map: torch.Tensor) -> str:
|
||||
"""
|
||||
Compresses the expert map by removing any -1 entries.
|
||||
@ -400,6 +436,9 @@ class FusedMoE(CustomOp):
|
||||
self.expert_load_view: torch.Tensor | None = None
|
||||
self.logical_to_physical_map: torch.Tensor | None = None
|
||||
self.logical_replica_count: torch.Tensor | None = None
|
||||
self.expert_placement_strategy: ExpertPlacementStrategy = (
|
||||
vllm_config.parallel_config.expert_placement_strategy
|
||||
)
|
||||
|
||||
# ROCm aiter shared experts fusion
|
||||
self.rocm_aiter_fmoe_enabled = rocm_aiter_ops.is_fused_moe_enabled()
|
||||
@ -433,38 +472,27 @@ class FusedMoE(CustomOp):
|
||||
"Redundant experts are only supported with EPLB."
|
||||
)
|
||||
|
||||
expert_placement_strategy = (
|
||||
vllm_config.parallel_config.expert_placement_strategy
|
||||
self.expert_placement_strategy = determine_expert_placement_strategy(
|
||||
expert_placement_strategy=self.expert_placement_strategy,
|
||||
moe_parallel_config=self.moe_parallel_config,
|
||||
num_expert_group=num_expert_group,
|
||||
num_redundant_experts=num_redundant_experts,
|
||||
enable_eplb=self.enable_eplb,
|
||||
)
|
||||
if expert_placement_strategy == "round_robin":
|
||||
# TODO(Bruce): will support round robin expert placement with
|
||||
# EPLB enabled in the future.
|
||||
round_robin_supported = (
|
||||
(num_expert_group is not None and num_expert_group > 1)
|
||||
and num_redundant_experts == 0
|
||||
and not self.enable_eplb
|
||||
)
|
||||
|
||||
if not round_robin_supported:
|
||||
logger.warning(
|
||||
"Round-robin expert placement is only supported for "
|
||||
"models with multiple expert groups and no redundant "
|
||||
"experts. Falling back to linear expert placement."
|
||||
)
|
||||
expert_placement_strategy = "linear"
|
||||
|
||||
self.expert_map: torch.Tensor | None
|
||||
local_num_experts, expert_map, expert_mask = determine_expert_map(
|
||||
ep_size=self.ep_size,
|
||||
ep_rank=self.ep_rank,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_placement_strategy=expert_placement_strategy,
|
||||
expert_placement_strategy=self.expert_placement_strategy,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
return_expert_mask=self.rocm_aiter_fmoe_enabled,
|
||||
)
|
||||
self.local_num_experts = local_num_experts
|
||||
self.register_buffer("expert_map", expert_map)
|
||||
self.register_buffer("expert_mask", expert_mask)
|
||||
self._maybe_init_expert_routing_tables()
|
||||
logger.info_once(
|
||||
"[EP Rank %s/%s] Expert parallelism is enabled. Expert "
|
||||
"placement strategy: %s. Local/global"
|
||||
@ -472,7 +500,7 @@ class FusedMoE(CustomOp):
|
||||
" %s.",
|
||||
self.ep_rank,
|
||||
self.ep_size,
|
||||
expert_placement_strategy,
|
||||
self.expert_placement_strategy,
|
||||
self.local_num_experts,
|
||||
self.global_num_experts,
|
||||
get_compressed_expert_map(self.expert_map),
|
||||
@ -621,7 +649,12 @@ class FusedMoE(CustomOp):
|
||||
# should be safe to swap out the quant_method.
|
||||
def maybe_init_modular_kernel(self) -> None:
|
||||
self.ensure_moe_quant_config_init()
|
||||
prepare_finalize = self.quant_method.maybe_make_prepare_finalize()
|
||||
# routing_tables only needed for round-robin expert placement with
|
||||
# DeepEP all2all backend.
|
||||
routing_tables = self._maybe_init_expert_routing_tables()
|
||||
prepare_finalize = self.quant_method.maybe_make_prepare_finalize(
|
||||
routing_tables=routing_tables
|
||||
)
|
||||
if prepare_finalize is not None:
|
||||
logger.debug(
|
||||
"%s for %s(%s)", prepare_finalize.__class__.__name__, self, id(self)
|
||||
@ -703,6 +736,84 @@ class FusedMoE(CustomOp):
|
||||
# By default, router/gate is called before FusedMoE forward pass
|
||||
return False
|
||||
|
||||
def _maybe_init_expert_routing_tables(
|
||||
self,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None:
|
||||
# Currently routing_tables only needed for round-robin expert placement
|
||||
# with DeepEP-ll all2all backend.
|
||||
if (
|
||||
self.expert_placement_strategy != "round_robin"
|
||||
or not self.use_deepep_ll_kernels
|
||||
):
|
||||
return None
|
||||
|
||||
if hasattr(self, "expert_global_to_physical"):
|
||||
return cast(
|
||||
tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
(
|
||||
self.expert_global_to_physical,
|
||||
self.expert_physical_to_global,
|
||||
self.expert_local_to_global,
|
||||
),
|
||||
)
|
||||
|
||||
if self.expert_map is None:
|
||||
return None
|
||||
|
||||
routing_tables = self.ensure_round_robin_expert_routing_tables(
|
||||
global_num_experts=self.global_num_experts,
|
||||
ep_size=self.ep_size,
|
||||
ep_rank=self.ep_rank,
|
||||
local_num_experts=self.local_num_experts,
|
||||
device=self.expert_map.device,
|
||||
)
|
||||
|
||||
global_to_physical, physical_to_global, local_global = routing_tables
|
||||
self.register_buffer("expert_global_to_physical", global_to_physical)
|
||||
self.register_buffer("expert_physical_to_global", physical_to_global)
|
||||
self.register_buffer("expert_local_to_global", local_global)
|
||||
|
||||
return routing_tables
|
||||
|
||||
@staticmethod
|
||||
def ensure_round_robin_expert_routing_tables(
|
||||
global_num_experts: int,
|
||||
ep_size: int,
|
||||
ep_rank: int,
|
||||
local_num_experts: int,
|
||||
device: torch.device | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
device_kwargs = {"device": device} if device is not None else {}
|
||||
global_indices = torch.arange(
|
||||
global_num_experts, dtype=torch.long, **device_kwargs
|
||||
)
|
||||
owner = torch.remainder(global_indices, ep_size)
|
||||
local_index = torch.div(global_indices, ep_size, rounding_mode="floor")
|
||||
base = global_num_experts // ep_size
|
||||
remainder = global_num_experts % ep_size
|
||||
physical_offset = owner * base
|
||||
if remainder > 0:
|
||||
remainder_tensor = torch.tensor(
|
||||
remainder, dtype=torch.long, **device_kwargs
|
||||
)
|
||||
physical_offset = physical_offset + torch.minimum(owner, remainder_tensor)
|
||||
|
||||
global_to_physical = physical_offset + local_index
|
||||
physical_to_global = torch.empty_like(global_to_physical)
|
||||
physical_to_global[global_to_physical] = global_indices
|
||||
|
||||
local_global = torch.arange(
|
||||
ep_rank,
|
||||
global_num_experts,
|
||||
ep_size,
|
||||
dtype=torch.long,
|
||||
**device_kwargs,
|
||||
)
|
||||
if local_global.numel() != local_num_experts:
|
||||
local_global = local_global[:local_num_experts]
|
||||
|
||||
return (global_to_physical, physical_to_global, local_global)
|
||||
|
||||
def update_expert_map(self):
|
||||
# ep_size and ep_rank should already be updated
|
||||
assert self.expert_map is not None
|
||||
@ -711,12 +822,14 @@ class FusedMoE(CustomOp):
|
||||
ep_size=self.ep_size,
|
||||
ep_rank=self.ep_rank,
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_placement_strategy=self.expert_placement_strategy,
|
||||
num_fused_shared_experts=self.num_fused_shared_experts,
|
||||
return_expert_mask=self.rocm_aiter_fmoe_enabled,
|
||||
)
|
||||
self.local_num_experts = local_num_experts
|
||||
self.register_buffer("expert_map", expert_map)
|
||||
self.register_buffer("expert_mask", expert_mask)
|
||||
self._maybe_init_expert_routing_tables()
|
||||
if self.aiter_fmoe_shared_expert_enabled:
|
||||
self._init_aiter_shared_experts_topK_buffer(
|
||||
vllm_config=get_current_vllm_config(),
|
||||
|
||||
@ -108,11 +108,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
def allow_inplace(self) -> bool:
|
||||
return True
|
||||
|
||||
def maybe_make_prepare_finalize(self) -> FusedMoEPrepareAndFinalize | None:
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> FusedMoEPrepareAndFinalize | None:
|
||||
if self.rocm_aiter_moe_enabled:
|
||||
return None
|
||||
else:
|
||||
return super().maybe_make_prepare_finalize()
|
||||
return super().maybe_make_prepare_finalize(routing_tables)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
|
||||
@ -380,11 +380,14 @@ class CompressedTensorsW4A4MoeMethod(CompressedTensorsMoEMethod):
|
||||
(layer.w2_input_global_scale), requires_grad=False
|
||||
)
|
||||
|
||||
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
if self.use_marlin:
|
||||
return None
|
||||
elif not self.allow_flashinfer:
|
||||
return super().maybe_make_prepare_finalize()
|
||||
return super().maybe_make_prepare_finalize(routing_tables)
|
||||
|
||||
prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize(self.moe)
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
@ -890,11 +893,14 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
layer.w2_weight_scale
|
||||
)
|
||||
|
||||
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
if self.use_marlin or self.rocm_aiter_moe_enabled:
|
||||
return None
|
||||
else:
|
||||
return super().maybe_make_prepare_finalize()
|
||||
return super().maybe_make_prepare_finalize(routing_tables)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
|
||||
@ -1018,7 +1018,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
del layer.w13_input_scale
|
||||
del layer.w2_input_scale
|
||||
|
||||
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
if (
|
||||
self.rocm_aiter_moe_enabled
|
||||
or self.use_marlin
|
||||
@ -1039,7 +1042,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
else:
|
||||
return super().maybe_make_prepare_finalize()
|
||||
return super().maybe_make_prepare_finalize(routing_tables)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
|
||||
@ -373,6 +373,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
# TRT LLM not supported with all2all yet.
|
||||
if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM:
|
||||
@ -384,7 +385,7 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
else:
|
||||
return super().maybe_make_prepare_finalize()
|
||||
return super().maybe_make_prepare_finalize(routing_tables)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
@ -1179,7 +1180,10 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
" for ModelOptNvFp4FusedMoE."
|
||||
)
|
||||
|
||||
def maybe_make_prepare_finalize(self) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
def maybe_make_prepare_finalize(
|
||||
self,
|
||||
routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
|
||||
) -> mk.FusedMoEPrepareAndFinalize | None:
|
||||
if self.use_marlin or (
|
||||
self.allow_flashinfer
|
||||
and self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM
|
||||
@ -1196,7 +1200,7 @@ class ModelOptNvFp4FusedMoE(FusedMoEMethodBase):
|
||||
logger.debug_once("%s", prepare_finalize.__class__.__name__)
|
||||
return prepare_finalize
|
||||
else:
|
||||
return super().maybe_make_prepare_finalize()
|
||||
return super().maybe_make_prepare_finalize(routing_tables)
|
||||
|
||||
def select_gemm_impl(
|
||||
self,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user