diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py
index 2a283a6d12b9f..254cd2e10b8fb 100644
--- a/vllm/model_executor/layers/fused_moe/layer.py
+++ b/vllm/model_executor/layers/fused_moe/layer.py
@@ -591,22 +591,20 @@ def determine_expert_map(
     if ep_size == 1:
         return (global_num_experts, None)
 
-    local_num_experts = global_num_experts // ep_size
+    # Distribute experts as evenly as possible to each rank.
+    base_experts = global_num_experts // ep_size
+    remainder = global_num_experts % ep_size
+    if ep_rank < remainder:
+        local_num_experts = base_experts + 1
+    else:
+        local_num_experts = base_experts
 
     # Create a tensor of size num_experts filled with -1
     expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
     # Create a expert map for the local experts
-    if ep_rank < (ep_size - 1):
-        # Each non-last rank gets local_num_experts experts.
-        expert_map[ep_rank * local_num_experts:
-                        (ep_rank + 1) * local_num_experts] = \
-            torch.arange(0, local_num_experts, dtype=torch.int32)
-    else:
-        # All remaining experts are assigned to the last rank.
-        local_num_experts = (global_num_experts - ep_rank * local_num_experts)
-
-        expert_map[-local_num_experts:] = \
-            torch.arange(0, local_num_experts, dtype=torch.int32)
+    start_idx = ep_rank * base_experts + min(ep_rank, remainder)
+    expert_map[start_idx:start_idx + local_num_experts] = torch.arange(
+        0, local_num_experts, dtype=torch.int32)
     return (local_num_experts, expert_map)
 
 