def _ranks_by_expert(
    indices: np.ndarray,
    unique_experts: np.ndarray,
    num_local_experts: int,
) -> dict[int, list[int]]:
    """
    Group the ranks that hold each queried expert in one placement array.

    Args:
        indices: 1D array where ``indices[p]`` is the expert stored at
            position ``p``; position ``p`` belongs to rank
            ``p // num_local_experts``.
        unique_experts: The distinct expert indices to look up.
        num_local_experts: The number of local experts per rank.

    Returns:
        A dictionary with one entry per queried expert, mapping it to the
        distinct ranks holding it, ordered by first appearance in
        ``indices``. Experts that do not appear map to an empty list.
    """
    ranks_map: dict[int, list[int]] = {int(e): [] for e in unique_experts}

    # Positions come from `indices` itself, so arrays of different lengths
    # can be queried independently without a shared position table.
    positions = np.flatnonzero(np.isin(indices, unique_experts))
    if positions.size == 0:
        return ranks_map

    experts = indices[positions]
    ranks = positions // num_local_experts

    # `positions` is ascending, so a stable sort by expert keeps every
    # expert's entries in position order.
    order = np.argsort(experts, kind="stable")
    experts = experts[order]
    ranks = ranks[order]

    # Within one expert the ranks are non-decreasing, hence repeated
    # (expert, rank) pairs are adjacent: keep only the first of each run.
    keep = np.ones(experts.size, dtype=bool)
    keep[1:] = (experts[1:] != experts[:-1]) | (ranks[1:] != ranks[:-1])

    for expert, rank in zip(experts[keep].tolist(), ranks[keep].tolist()):
        ranks_map[int(expert)].append(int(rank))
    return ranks_map


def get_ep_ranks_with_experts_batch(
    expert_ids: np.ndarray,
    num_local_experts: int,
    old_indices: np.ndarray,
    new_indices: np.ndarray,
) -> tuple[dict[int, list[int]], dict[int, list[int]]]:
    """
    Get the ranks of the experts that need to be exchanged.

    Args:
        expert_ids: 1D array of expert indices to query. Duplicates are
            collapsed; each distinct expert gets one entry in each result.
        num_local_experts: The number of local experts.
        old_indices: The old indices of the experts (1D).
        new_indices: The new indices of the experts (1D).

    Returns:
        A tuple of two dictionaries mapping expert_id to:
        - ranks_to_send: The ranks that have this expert and need to send.
        - ranks_to_recv: The ranks that need to receive this expert, i.e.
          the ranks holding it in ``new_indices`` minus the ranks that
          already hold it in ``old_indices``.
        Every queried expert is present in both dictionaries (possibly
        with an empty list). Both dictionaries are empty when
        ``expert_ids`` is empty.
    """
    # Fast path: if no experts, return empty dicts
    if expert_ids.size == 0:
        return {}, {}

    unique_experts = np.unique(expert_ids)
    ranks_to_send_map = _ranks_by_expert(
        old_indices, unique_experts, num_local_experts
    )
    ranks_to_recv_map = _ranks_by_expert(
        new_indices, unique_experts, num_local_experts
    )

    # Remove ranks that have local copies to avoid unnecessary recv
    for expert, holders in ranks_to_send_map.items():
        if holders:
            held = set(holders)
            ranks_to_recv_map[expert] = [
                r for r in ranks_to_recv_map[expert] if r not in held
            ]

    return ranks_to_send_map, ranks_to_recv_map
get_ep_ranks_with_expert( - expert, - num_local_experts, - old_indices, - layer_new_indices, - ) + ranks_to_send = send_map[expert] + ranks_to_recv = recv_map[expert] if not ranks_to_send or not ranks_to_recv: continue num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) @@ -299,7 +343,7 @@ def move_to_buffer( src = ranks_to_send[recver_pos // num_dst_per_sender] else: src = ranks_to_send[recver_pos - remainder_start] - src_global = get_global_rank(ep_group, src) + src_global = rank_to_global[src] p2p_ops += [ P2POp( torch.distributed.irecv,