From 787384dd4af4976adb51f86924f80d129a361609 Mon Sep 17 00:00:00 2001 From: Robert Shaw Date: Thu, 28 Aug 2025 01:25:55 +0000 Subject: [PATCH] updated Signed-off-by: Robert Shaw --- vllm/model_executor/layers/fused_moe/layer.py | 116 ++++++++++-------- 1 file changed, 66 insertions(+), 50 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 54406a5a2d87f..d5a7e97fbe109 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1401,6 +1401,66 @@ class FusedMoE(CustomOp): self.logical_to_physical_map = logical_to_physical_map[moe_layer_idx] self.logical_replica_count = logical_replica_count[moe_layer_idx] + @staticmethod + @torch.compile(dynamic=True, + backend=current_platform.simple_compile_backend) + def handle_eplb( + topk_ids: torch.Tensor, + logical_replica_count: torch.Tensor, + logical_to_physical_map: torch.Tensor, + expert_load_view: torch.Tensor, + indices_type: torch.dtype, + ) -> torch.Tensor: + # 1. Convert the logical expert ids to physical expert ids + # Directly select a random replica for each logical expert + + # TODO: maybe optimize this by using specified kernels, + # or compute pseudo-random indices by modulo + + # In case `indices_type` is not `torch.long` or `torch.int`, + # e.g. `torch.uint32` as required by dispatch/combine kernels + topk_ids_long = topk_ids.long() + replica_indices = ( + torch.rand_like(topk_ids, dtype=torch.float) * + logical_replica_count[topk_ids_long]).long().unsqueeze(-1) + physical_ids = logical_to_physical_map[topk_ids_long].gather( + -1, replica_indices).squeeze(-1) + + topk_ids = physical_ids + + # 2. Record expert load metrics. + + # TODO(bowen): When using `FusedMoEModularKernel`, this + # can be done in a more unified way, since + # `FusedMoEPrepareAndFinalize` will return the expert + # token count, in some cases directly from the kernel. + # However, now there are many code paths not using + # the modular kernel, e.g. calling `fused_experts`, + # so we decide to keep the logic here. + # + # If later refactor moved all the MoE kernel calls + # to the modular kernel, we can move this logic there + # to achieve better efficiency. + + # `expert_load_view`: (num_physical_experts,) + + topk_ids_flatten = topk_ids.flatten() + + # Performance optimization: + # `masked_fill` is significantly faster than `masked_select` + invalid_mask = topk_ids_flatten < 0 + # Replace invalid expert ids with 0 (just a dummy position) + # to avoid out-of-bounds errors in scatter_add_ + index = topk_ids_flatten.masked_fill_(invalid_mask, 0) + # `src` is the valid mask, which is 1 for valid and 0 for invalid + src = ~invalid_mask + + expert_load_view.scatter_add_(dim=0, + index=index.long(), + src=src.to(expert_load_view)) + + return topk_ids.to(dtype=indices_type) + @staticmethod def select_experts( hidden_states: torch.Tensor, @@ -1480,56 +1540,12 @@ class FusedMoE(CustomOp): assert expert_load_view is not None assert logical_to_physical_map is not None assert logical_replica_count is not None - - # 1. Convert the logical expert ids to physical expert ids - # Directly select a random replica for each logical expert - - # TODO: maybe optimize this by using specified kernels, - # or compute pseudo-random indices by modulo - - # In case `indices_type` is not `torch.long` or `torch.int`, - # e.g. `torch.uint32` as required by dispatch/combine kernels - topk_ids_long = topk_ids.long() - replica_indices = ( - torch.rand_like(topk_ids, dtype=torch.float) * - logical_replica_count[topk_ids_long]).long().unsqueeze(-1) - physical_ids = logical_to_physical_map[topk_ids_long].gather( - -1, replica_indices).squeeze(-1) - - topk_ids = physical_ids - - # 2. Record expert load metrics. - - # TODO(bowen): When using `FusedMoEModularKernel`, this - # can be done in a more unified way, since - # `FusedMoEPrepareAndFinalize` will return the expert - # token count, in some cases directly from the kernel. - # However, now there are many code paths not using - # the modular kernel, e.g. calling `fused_experts`, - # so we decide to keep the logic here. - # - # If later refactor moved all the MoE kernel calls - # to the modular kernel, we can move this logic there - # to achieve better efficiency. - - # `expert_load_view`: (num_physical_experts,) - - topk_ids_flatten = topk_ids.flatten() - - # Performance optimization: - # `masked_fill` is significantly faster than `masked_select` - invalid_mask = topk_ids_flatten < 0 - # Replace invalid expert ids with 0 (just a dummy position) - # to avoid out-of-bounds errors in scatter_add_ - index = topk_ids_flatten.masked_fill_(invalid_mask, 0) - # `src` is the valid mask, which is 1 for valid and 0 for invalid - src = ~invalid_mask - - expert_load_view.scatter_add_(dim=0, - index=index.long(), - src=src.to(expert_load_view)) - - topk_ids = topk_ids.to(dtype=indices_type) + topk_ids = FusedMoE.handle_eplb( + topk_ids=topk_ids, + logical_replica_count=logical_replica_count, + logical_to_physical_map=logical_to_physical_map, + expert_load_view=expert_load_view, + indices_type=indices_type) assert topk_ids.dtype == indices_type or indices_type is None