[Bugfix] EPLB load statistics problem (#22167)

Signed-off-by: ycyaw66 <497410282@qq.com>
Signed-off-by: David Chen <530634352@qq.com>
Co-authored-by: ycyaw66 <497410282@qq.com>
This commit is contained in:
WeiQing Chen 2025-08-07 12:07:54 +08:00 committed by GitHub
parent f6278b6243
commit 4be02a3776
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 26 additions and 41 deletions

View File

@ -32,7 +32,7 @@ from dataclasses import dataclass
from typing import Optional, Union from typing import Optional, Union
import torch import torch
from torch.distributed import ProcessGroup, all_gather, all_reduce from torch.distributed import ProcessGroup, all_reduce
from vllm.config import ParallelConfig from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import (get_ep_group, get_node_count, from vllm.distributed.parallel_state import (get_ep_group, get_node_count,
@ -112,13 +112,21 @@ class EplbState:
Expert load during this forward pass. Expert load during this forward pass.
We use the token count each expert processes as the load. We use the token count each expert processes as the load.
Shape: (num_moe_layers, num_local_physical_experts) Shape: (num_moe_layers, num_physical_experts)
""" """
expert_load_window: torch.Tensor expert_load_window: torch.Tensor
""" """
A sliding window of expert load. A sliding window of expert load.
Shape: (window_size, num_moe_layers, num_local_physical_experts) Shape: (window_size, num_moe_layers, num_physical_experts)
NOTE: The expert_load_view now records load for all physical experts
rather than just local experts. This ensures consistent load statistics
across different dispatch methods (naive all-to-all, DeepEP, pplx-kernels).
The recorded load will be multiplied by dp_size when using naive all-to-all
due to each DP rank contributing the same token set to the calculation.
See:
https://github.com/vllm-project/vllm/pull/22167#pullrequestreview-3086143856
""" """
expert_load_window_step: int = 0 expert_load_window_step: int = 0
""" """
@ -232,14 +240,14 @@ class EplbState:
).contiguous() ).contiguous()
expert_load_pass = torch.zeros( expert_load_pass = torch.zeros(
(model.num_moe_layers, model.num_local_physical_experts), (model.num_moe_layers, model.num_physical_experts),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
expert_load_window_size = parallel_config.eplb_window_size expert_load_window_size = parallel_config.eplb_window_size
expert_load_window = torch.zeros( expert_load_window = torch.zeros(
(expert_load_window_size, model.num_moe_layers, (expert_load_window_size, model.num_moe_layers,
model.num_local_physical_experts), model.num_physical_experts),
dtype=torch.int32, dtype=torch.int32,
device=device, device=device,
) )
@ -353,18 +361,18 @@ class EplbState:
self.expert_load_pass.zero_() self.expert_load_pass.zero_()
if log_stats: if log_stats:
# `num_tokens`: (num_moe_layers,) # total_expert_load_pass: (num_moe_layers, num_physical_experts)
num_tokens = self.expert_load_pass.sum(dim=-1) total_expert_load_pass = self.expert_load_pass.clone()
# Collect load metrics from all ranks # Collect load metrics from all ranks
ep_group = get_ep_group().device_group ep_group = get_ep_group().device_group
assert ep_group is not None assert ep_group is not None
num_tokens_list = [ all_reduce(total_expert_load_pass, group=ep_group)
torch.empty_like(num_tokens) for _ in range(ep_group.size())
] # num_tokens_per_rank: (num_moe_layers, num_ranks)
all_gather(num_tokens_list, num_tokens, group=ep_group) num_tokens_per_rank = total_expert_load_pass.reshape(
# Stack to get (num_ranks, num_moe_layers) total_expert_load_pass.shape[0], ep_group.size(),
num_tokens_per_rank = torch.stack(num_tokens_list).float() -1).sum(dim=-1).float()
# Compute balancedness ratio: # Compute balancedness ratio:
# for each layer: # for each layer:
@ -426,17 +434,7 @@ class EplbState:
"(profile)" if is_profile else "") "(profile)" if is_profile else "")
if global_expert_load is None: if global_expert_load is None:
# This mapping is only used here, so we do not store it in the state # Map the physical expert load to global logical experts
physical_expert_start = ep_rank * model.num_local_physical_experts
physical_expert_end = (physical_expert_start +
model.num_local_physical_experts)
# (num_moe_layers, num_local_physical_experts)
local_physical_to_logical_map = self.physical_to_logical_map[
:,
physical_expert_start:physical_expert_end,
]
# Map the local physical expert load to global logical experts
logical_expert_load_window = torch.zeros( logical_expert_load_window = torch.zeros(
self.expert_load_window_size, self.expert_load_window_size,
model.num_moe_layers, model.num_moe_layers,
@ -446,7 +444,7 @@ class EplbState:
) )
logical_expert_load_window.scatter_add_( logical_expert_load_window.scatter_add_(
dim=-1, dim=-1,
index=local_physical_to_logical_map.unsqueeze(0).expand_as( index=self.physical_to_logical_map.unsqueeze(0).expand_as(
self.expert_load_window).long(), self.expert_load_window).long(),
src=self.expert_load_window, src=self.expert_load_window,
) )
@ -618,4 +616,4 @@ def _node_count_with_rank_mapping(
if is_same_node and node_assignment[other_rank] == 0: if is_same_node and node_assignment[other_rank] == 0:
node_assignment[other_rank] = next_node_id node_assignment[other_rank] = next_node_id
return next_node_id return next_node_id

View File

@ -1430,22 +1430,9 @@ class FusedMoE(torch.nn.Module):
# to the modular kernel, we can move this logic there # to the modular kernel, we can move this logic there
# to achieve better efficiency. # to achieve better efficiency.
# `expert_load_view`: (num_logical_experts,) # `expert_load_view`: (num_physical_experts,)
# Mask out non-local experts topk_ids_flatten = topk_ids.flatten()
if expert_map is not None:
topk_ids_local = expert_map[topk_ids]
topk_ids_flatten = topk_ids_local.flatten()
else:
topk_ids_flatten = topk_ids.flatten()
# Should be equivalent to:
# ```
# topk_ids_masked = topk_ids_local[topk_ids_local >= 0]
# expert_load_view += topk_ids_masked.bincount(
# minlength=expert_load_view.shape[0])
# ```
# We use `scatter_add_` since `bincount` cannot be compiled
# Performance optimization: # Performance optimization:
# `masked_fill` is significantly faster than `masked_select` # `masked_fill` is significantly faster than `masked_select`