From 4be02a37767f05a3fd27d66435d5cebea7a9bfe8 Mon Sep 17 00:00:00 2001
From: WeiQing Chen <40507679+david6666666@users.noreply.github.com>
Date: Thu, 7 Aug 2025 12:07:54 +0800
Subject: [PATCH] [Bugfix] EPLB load statistics problem (#22167)

Signed-off-by: ycyaw66 <497410282@qq.com>
Signed-off-by: David Chen <530634352@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
---
 vllm/distributed/eplb/eplb_state.py           | 50 +++++++++----------
 vllm/model_executor/layers/fused_moe/layer.py | 17 +------
 2 files changed, 26 insertions(+), 41 deletions(-)

diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index f64b516b0d04..c415d409f7fe 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -32,7 +32,7 @@ from dataclasses import dataclass
 from typing import Optional, Union
 
 import torch
-from torch.distributed import ProcessGroup, all_gather, all_reduce
+from torch.distributed import ProcessGroup, all_reduce
 
 from vllm.config import ParallelConfig
 from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
@@ -112,13 +112,21 @@ class EplbState:
     Expert load during this forward pass.
     We use the token count each expert processes as the load.
 
-    Shape: (num_moe_layers, num_local_physical_experts)
+    Shape: (num_moe_layers, num_physical_experts)
     """
     expert_load_window: torch.Tensor
     """
     A sliding window of expert load.
 
-    Shape: (window_size, num_moe_layers, num_local_physical_experts)
+    Shape: (window_size, num_moe_layers, num_physical_experts)
+
+    NOTE: The expert_load_view now records load for all physical experts
+    rather than just local experts. This ensures consistent load statistics
+    across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
+    The recorded load will be multiplied by dp_size when using naive all-to-all
+    due to each DP rank contributing the same token set to the calculation.
+    See:
+    https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
     """
     expert_load_window_step: int = 0
     """
@@ -232,14 +240,14 @@ class EplbState:
         ).contiguous()
 
         expert_load_pass = torch.zeros(
-            (model.num_moe_layers, model.num_local_physical_experts),
+            (model.num_moe_layers, model.num_physical_experts),
             dtype=torch.int32,
             device=device,
         )
         expert_load_window_size = parallel_config.eplb_window_size
         expert_load_window = torch.zeros(
             (expert_load_window_size, model.num_moe_layers,
-             model.num_local_physical_experts),
+             model.num_physical_experts),
             dtype=torch.int32,
             device=device,
         )
@@ -353,18 +361,18 @@ class EplbState:
         self.expert_load_pass.zero_()
 
         if log_stats:
-            # `num_tokens`: (num_moe_layers,)
-            num_tokens = self.expert_load_pass.sum(dim=-1)
+            # total_expert_load_pass: (num_moe_layers, num_physical_experts)
+            total_expert_load_pass = self.expert_load_pass.clone()
 
             # Collect load metrics from all ranks
             ep_group = get_ep_group().device_group
             assert ep_group is not None
-            num_tokens_list = [
-                torch.empty_like(num_tokens) for _ in range(ep_group.size())
-            ]
-            all_gather(num_tokens_list, num_tokens, group=ep_group)
-            # Stack to get (num_ranks, num_moe_layers)
-            num_tokens_per_rank = torch.stack(num_tokens_list).float()
+            all_reduce(total_expert_load_pass, group=ep_group)
+
+            # num_tokens_per_rank: (num_moe_layers, num_ranks)
+            num_tokens_per_rank = total_expert_load_pass.reshape(
+                total_expert_load_pass.shape[0], ep_group.size(),
+                -1).sum(dim=-1).float()
 
             # Compute balancedness ratio:
             # for each layer:
@@ -426,17 +434,7 @@ class EplbState:
                 "(profile)" if is_profile else "")
 
         if global_expert_load is None:
-            # This mapping is only used here, so we do not store it in the state
-            physical_expert_start = ep_rank * model.num_local_physical_experts
-            physical_expert_end = (physical_expert_start +
-                                   model.num_local_physical_experts)
-            # (num_moe_layers, num_local_physical_experts)
-            local_physical_to_logical_map = self.physical_to_logical_map[
-                :,
-                physical_expert_start:physical_expert_end,
-            ]
-
-            # Map the local physical expert load to global logical experts
+            # Map the physical expert load to global logical experts
             logical_expert_load_window = torch.zeros(
                 self.expert_load_window_size,
                 model.num_moe_layers,
@@ -446,7 +444,7 @@ class EplbState:
             )
             logical_expert_load_window.scatter_add_(
                 dim=-1,
-                index=local_physical_to_logical_map.unsqueeze(0).expand_as(
+                index=self.physical_to_logical_map.unsqueeze(0).expand_as(
                     self.expert_load_window).long(),
                 src=self.expert_load_window,
             )
@@ -618,4 +616,4 @@ def _node_count_with_rank_mapping(
             if is_same_node and node_assignment[other_rank] == 0:
                 node_assignment[other_rank] = next_node_id
 
-    return next_node_id
+    return next_node_id
\ No newline at end of file
diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index a4a6157fa4bf..72c2bc9a3d73 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -1430,22 +1430,9 @@ class FusedMoE(torch.nn.Module):
             # to the modular kernel, we can move this logic there
             # to achieve better efficiency.
 
-            # `expert_load_view`: (num_logical_experts,)
+            # `expert_load_view`: (num_physical_experts,)
 
-            # Mask out non-local experts
-            if expert_map is not None:
-                topk_ids_local = expert_map[topk_ids]
-                topk_ids_flatten = topk_ids_local.flatten()
-            else:
-                topk_ids_flatten = topk_ids.flatten()
-
-            # Should be equivalent to:
-            # ```
-            # topk_ids_masked = topk_ids_local[topk_ids_local >= 0]
-            # expert_load_view += topk_ids_masked.bincount(
-            #     minlength=expert_load_view.shape[0])
-            # ```
-            # We use `scatter_add_` since `bincount` cannot be compiled
+            topk_ids_flatten = topk_ids.flatten()
 
             # Performance optimization:
             # `masked_fill` is significantly faster than `masked_select`
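
Illustrative sketch (not part of the patch above): the snippet below mimics, in
a single process, how the per-pass load is now accumulated over all physical
experts and how per-rank token counts are derived from it. Tensor names follow
the patch; the toy sizes, the stand-in for the EP-group all_reduce, and the
mean/max balancedness formula are assumptions for illustration only.

import torch

num_moe_layers = 2
num_ranks = 4
num_physical_experts = 8  # global physical expert count, divisible by num_ranks

# Per-pass load view: one row per MoE layer, one column per *physical* expert
# (the patch widens this from num_local_physical_experts).
expert_load_pass = torch.zeros(num_moe_layers, num_physical_experts,
                               dtype=torch.int32)

# Pretend each layer routed 16 tokens to their top-2 physical experts.
topk_ids = torch.randint(0, num_physical_experts, (num_moe_layers, 16, 2))
for layer in range(num_moe_layers):
    topk_ids_flatten = topk_ids[layer].flatten()
    # Same accumulation style FusedMoE keeps: scatter_add_ of ones, which
    # torch.compile handles better than bincount.
    expert_load_pass[layer].scatter_add_(
        0, topk_ids_flatten.long(),
        torch.ones_like(topk_ids_flatten, dtype=torch.int32))

# In the real code: all_reduce(total_expert_load_pass, group=ep_group) sums the
# per-rank views; a single-process stand-in just clones the tensor.
total_expert_load_pass = expert_load_pass.clone()

# num_tokens_per_rank: (num_moe_layers, num_ranks), folding the contiguous
# block of physical experts owned by each rank into one bucket.
num_tokens_per_rank = total_expert_load_pass.reshape(
    num_moe_layers, num_ranks, -1).sum(dim=-1).float()

# Balancedness per layer as mean load over max load (1.0 means perfectly even).
balancedness = num_tokens_per_rank.mean(dim=-1) / num_tokens_per_rank.max(
    dim=-1).values.clamp(min=1e-6)
print(balancedness)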