From f4df2af9465a32b4ac2a42e58eba04a1188c224c Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 24 Nov 2025 18:58:46 +0000 Subject: [PATCH 01/30] Wip Signed-off-by: ilmarkov --- vllm/distributed/eplb/rebalance_execute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 376dad8a72ef1..046f3895970c4 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -436,7 +436,7 @@ def rearrange_expert_weights_inplace( is_unchanged=is_unchanged, is_received_locally=is_received_locally, experts_recv_loc=experts_recv_loc, - new_indices=new_global_expert_indices[layer].tolist(), + new_indices=new_global_expert_indices_cpu[layer].tolist(), ep_group=ep_group, ) From a46c72ac71d0422604cb02516f51429bcaff6b8e Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 25 Nov 2025 14:34:22 +0000 Subject: [PATCH 02/30] Optimize weight rearrange with numpy Signed-off-by: ilmarkov --- tests/distributed/eplb_utils.py | 3 +- tests/distributed/test_eplb_execute.py | 21 +- vllm/distributed/eplb/async_worker.py | 2 +- vllm/distributed/eplb/eplb_state.py | 46 +- vllm/distributed/eplb/rebalance_execute.py | 543 +++++++++++++-------- vllm/envs.py | 5 + 6 files changed, 394 insertions(+), 226 deletions(-) diff --git a/tests/distributed/eplb_utils.py b/tests/distributed/eplb_utils.py index 27a63e0215148..9d06e705968bd 100644 --- a/tests/distributed/eplb_utils.py +++ b/tests/distributed/eplb_utils.py @@ -15,7 +15,7 @@ from vllm.utils.system_utils import update_environment_variables mp.set_start_method("spawn", force=True) -def distributed_run(fn, world_size, *args): +def distributed_run(fn, world_size, *args, max_grouped_layers=1): number_of_processes = world_size processes: list[mp.Process] = [] for i in range(number_of_processes): @@ -26,6 +26,7 @@ def distributed_run(fn, world_size, *args): env["LOCAL_WORLD_SIZE"] = str(number_of_processes) env["MASTER_ADDR"] = "localhost" env["MASTER_PORT"] = "12345" + env["VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS"] = str(max_grouped_layers) p = mp.Process(target=fn, args=(env, world_size, *args)) processes.append(p) p.start() diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py index 781dfd44c1ef6..1bf231500d9aa 100644 --- a/tests/distributed/test_eplb_execute.py +++ b/tests/distributed/test_eplb_execute.py @@ -286,15 +286,17 @@ def _test_async_transfer_layer_without_mtp_worker( device, old_indices, ) + old_indices_cpu = old_indices.cpu() + new_indices_cpu = new_indices.cpu() expert_buffer = [torch.empty_like(w) for w in expert_weights[0]] cuda_stream = torch.cuda.Stream(device=device) for layer_idx in range(num_layers): - is_unchanged, is_received_locally, experts_recv_loc = asyncio.run( + is_unchanged, is_received_locally, recv_metadata = asyncio.run( transfer_layer( - old_global_expert_indices=old_indices, - new_global_expert_indices=new_indices, + old_global_expert_indices=old_indices_cpu, + new_global_expert_indices=new_indices_cpu, expert_weights=expert_weights, expert_weights_buffer=expert_buffer, ep_group=ep_group, @@ -302,15 +304,14 @@ def _test_async_transfer_layer_without_mtp_worker( cuda_stream=cuda_stream, ) ) - cuda_stream.synchronize() move_from_buffer( - expert_weights=expert_weights[layer_idx], - expert_weights_buffer=expert_buffer, + weights_group=[expert_weights[layer_idx]], + buffers_group=[expert_buffer], is_unchanged=is_unchanged, is_received_locally=is_received_locally, - 
experts_recv_loc=experts_recv_loc, - new_indices=new_indices[layer_idx].tolist(), + recv_metadata=recv_metadata, + new_indices_group=new_indices_cpu[layer_idx : layer_idx + 1], ep_group=ep_group, ) @@ -426,8 +427,9 @@ def _test_rearrange_expert_weights_with_redundancy( (4, 8, 8, 16), ], ) +@pytest.mark.parametrize("group_layers", [1, 2]) def test_rearrange_expert_weights_with_redundancy( - world_size, num_layers, num_local_experts, num_logical_experts + world_size, num_layers, num_local_experts, num_logical_experts, group_layers ): """Test the functionality of rearranging expert weights with redundancy.""" @@ -439,6 +441,7 @@ def test_rearrange_expert_weights_with_redundancy( num_layers, num_local_experts, num_logical_experts, + max_grouped_layers=group_layers, ) diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py index e4b4fc92eeaaa..9d7366996e3b2 100644 --- a/vllm/distributed/eplb/async_worker.py +++ b/vllm/distributed/eplb/async_worker.py @@ -89,7 +89,7 @@ async def transfer_run_periodically( ( model_state.is_unchanged, model_state.is_received_locally, - model_state.experts_recv_loc, + model_state.recv_metadata, ) = await transfer_layer( old_global_expert_indices=model_state.physical_to_logical_map, new_global_expert_indices=model_state.new_physical_to_logical_map, diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 9f8798a96a2fc..0dd3999c1a65a 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -31,6 +31,7 @@ import time from collections.abc import Sequence from dataclasses import dataclass +import numpy as np import torch from torch.distributed import ProcessGroup, all_reduce @@ -164,20 +165,24 @@ class EplbModelState: """ Whether the async EPLB needs to poll peers for buffer readiness. """ - is_unchanged: list[bool] + is_unchanged: np.ndarray """ intermediate variable between `move_to_buffer` and `move_to_workspace`. The size is same as the num of physical experts in the current layer. """ - is_received_locally: list[bool] + is_received_locally: np.ndarray """ intermediate variable between `move_to_buffer` and `move_to_workspace`. The size is same as the num of physical experts in the current layer. """ - experts_recv_loc: dict[int, int] + recv_metadata: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] """ intermediate variable between `move_to_buffer` and `move_to_workspace`. - The size is same as the num of physical experts in the current layer. 
+ The tuple contains: + - recv_primary_mask: np.ndarray, shape (group_size, num_local_experts) + - recv_counts: np.ndarray, shape (group_size,) + - recv_expert_ids: np.ndarray, shape (group_size, num_local_experts) + - recv_dst_rows: np.ndarray, shape (group_size, num_local_experts) """ is_async_enabled: bool """ @@ -498,9 +503,9 @@ class EplbState: layer_to_transfer=0, rebalanced=False, pending_global_ready_check=False, - is_unchanged=[], - is_received_locally=[], - experts_recv_loc={}, + is_unchanged=np.array([]), + is_received_locally=np.array([]), + recv_metadata=(np.array([]), np.array([]), np.array([]), np.array([])), is_async_enabled=self.is_async, cuda_device_index=self.cuda_device_index, new_physical_to_logical_map=new_physical_to_logical_map, @@ -847,8 +852,6 @@ class EplbState: time_end - time_start, ) else: - device = eplb_model_state.physical_to_logical_map.device - new_physical = new_physical_to_logical_map.to(device) max_slots = eplb_model_state.logical_to_physical_map.shape[-1] padded_logical = torch.nn.functional.pad( new_logical_to_physical_map, @@ -859,7 +862,10 @@ class EplbState: eplb_model_state.logical_replica_count.device ) - eplb_model_state.new_physical_to_logical_map = new_physical + # Move map to cpu in advance + eplb_model_state.new_physical_to_logical_map = ( + new_physical_to_logical_map.cpu() + ) eplb_model_state.new_logical_to_physical_map = padded_logical eplb_model_state.new_logical_replica_count = new_replica @@ -958,17 +964,21 @@ class EplbState: stream = torch.cuda.current_stream(device=device_index) stream.wait_event(model_state.buffer_ready_event) model_state.buffer_ready_event = None + weights_group = [ + model_state.model.expert_weights[model_state.layer_to_transfer] + ] + buffers_group = [model_state.expert_buffer] move_from_buffer( - expert_weights=model_state.model.expert_weights[ - model_state.layer_to_transfer - ], - expert_weights_buffer=model_state.expert_buffer, + weights_group=weights_group, + buffers_group=buffers_group, is_unchanged=model_state.is_unchanged, is_received_locally=model_state.is_received_locally, - experts_recv_loc=model_state.experts_recv_loc, - new_indices=model_state.new_physical_to_logical_map[ - model_state.layer_to_transfer - ].tolist(), + recv_metadata=model_state.recv_metadata, + new_indices_group=model_state.new_physical_to_logical_map[ + model_state.layer_to_transfer : model_state.layer_to_transfer + 1 + ] + .cpu() + .numpy(), ep_group=ep_group, ) transferred_layer = model_state.layer_to_transfer diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 046f3895970c4..aa9f77f3ca5c4 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -6,9 +6,10 @@ The actual execution of the rearrangement. This involves the exchange of expert weights between GPUs. 
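+
+The transfer runs in two phases: `move_to_buffer` stages expert rows into
+temporary buffers (local copies plus batched P2P sends and receives), and
+`move_from_buffer` copies the staged rows back into the destination expert
+weights.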
""" -from collections.abc import Iterable, MutableSequence, Sequence +from collections.abc import Iterable, Sequence from functools import partial +import numpy as np import torch from torch.distributed import ( P2POp, @@ -18,6 +19,11 @@ from torch.distributed import ( get_global_rank, ) +import vllm.envs as envs +from vllm.logger import init_logger + +logger = init_logger(__name__) + def idx_local_to_global( local_idx: int, @@ -54,9 +60,9 @@ def global_idx_to_rank( def get_ep_ranks_with_expert( idx: int, num_local_experts: int, - old_indices: Sequence[int], - new_indices: Sequence[int], -) -> tuple[MutableSequence[int], MutableSequence[int]]: + old_indices: np.ndarray, + new_indices: np.ndarray, +) -> tuple[list[int], list[int]]: """ Get the ranks of the experts that need to be exchanged. @@ -71,161 +77,227 @@ def get_ep_ranks_with_expert( - The ranks of the experts that need to be sent. - The ranks of the experts that need to be received. """ - global2rank = partial( - global_idx_to_rank, - local_cnt=num_local_experts, - ) - - ranks_to_send: list[int] = [] - ranks_to_recv: list[int] = [] - - for i, e in enumerate(old_indices): - if e == idx: - rank = global2rank(i) - if not ranks_to_send or ranks_to_send[-1] != rank: - ranks_to_send.append(rank) - - for i, e in enumerate(new_indices): - if e == idx: - rank = global2rank(i) - if not ranks_to_recv or ranks_to_recv[-1] != rank: - ranks_to_recv.append(rank) - - # Remove those ranks that can get this expert locally. + # Indices where expert idx appears + old_pos = np.nonzero(old_indices == idx)[0] + new_pos = np.nonzero(new_indices == idx)[0] + # Map positions to ranks + if old_pos.size > 0: + old_ranks = old_pos // num_local_experts + uniq_send, first_idx_send = np.unique(old_ranks, return_index=True) + order_send = np.argsort(first_idx_send) + ranks_to_send = uniq_send[order_send].astype(int).tolist() + else: + ranks_to_send = [] + if new_pos.size > 0: + new_ranks = new_pos // num_local_experts + uniq_recv, first_idx_recv = np.unique(new_ranks, return_index=True) + order_recv = np.argsort(first_idx_recv) + ranks_to_recv = uniq_recv[order_recv].astype(int).tolist() + else: + ranks_to_recv = [] + # Remove ranks that have local copies to avoid unnecessary recv ranks_to_send_set = set(ranks_to_send) - ranks_to_recv_actual = [ - rank for rank in ranks_to_recv if rank not in ranks_to_send_set - ] - + ranks_to_recv_actual = [r for r in ranks_to_recv if r not in ranks_to_send_set] return ranks_to_send, ranks_to_recv_actual def move_to_buffer( num_local_experts: int, - old_indices: Sequence[int], - new_indices: Sequence[int], - expert_weights: Iterable[torch.Tensor], - expert_weights_buffer: Sequence[torch.Tensor], + old_indices_group: np.ndarray, + new_indices_group: np.ndarray, + expert_weights_group: Sequence[Iterable[torch.Tensor]], + buffers_group: Sequence[Sequence[torch.Tensor]], cuda_stream: torch.cuda.Stream | None, ep_group: ProcessGroup, -) -> tuple[list[bool], list[bool], dict[int, int]]: +) -> tuple[ + np.ndarray, np.ndarray, tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] +]: """ - Perform expert weights rearrangement of one layer. + Perform expert weights rearrangement of a group of layers. """ + assert len(old_indices_group) == len(new_indices_group) == len(expert_weights_group) + group_size = len(old_indices_group) ep_rank = ep_group.rank() - local2global = partial( - idx_local_to_global, - local_cnt=num_local_experts, - ep_rank=ep_rank, - ) - # 0. Do nothing for experts that did not change. 
-    is_unchanged = [
-        old_indices[local2global(i)] == new_indices[local2global(i)]
-        for i in range(num_local_experts)
-    ]
+    # Pre-allocate per-layer compact maps/masks (numpy)
+    is_unchanged = np.zeros((group_size, num_local_experts), dtype=np.bool_)
+    is_received_locally = np.zeros((group_size, num_local_experts), dtype=np.bool_)
+    recv_primary_mask = np.zeros((group_size, num_local_experts), dtype=np.bool_)
+    send_counts = np.zeros(group_size, dtype=np.int32)
+    send_expert_ids = np.full((group_size, num_local_experts), -1, dtype=np.int64)
+    send_src_rows = np.full((group_size, num_local_experts), -1, dtype=np.int32)
+    recv_counts = np.zeros(group_size, dtype=np.int32)
+    recv_expert_ids = np.full((group_size, num_local_experts), -1, dtype=np.int64)
+    recv_dst_rows = np.full((group_size, num_local_experts), -1, dtype=np.int32)
+    base = ep_rank * num_local_experts
+    local_rows = np.arange(num_local_experts, dtype=np.int32)
+    local_global = base + local_rows
-    # 1. Perform weight copy inside the local rank.
-    is_received_locally = is_unchanged[:]
-    for src in range(num_local_experts):
-        src_global = local2global(src)
+    # Build masks and expert maps per layer
+    for layer_idx in range(group_size):
+        old_indices = old_indices_group[layer_idx]
+        layer_new_indices = new_indices_group[layer_idx]
+
+        old_local_expert_ids = old_indices[local_global]
+        new_local_expert_ids = layer_new_indices[local_global]
+
+        # Unchanged per-dst mask
+        unchanged_mask = old_local_expert_ids == new_local_expert_ids
+        is_unchanged[layer_idx, :] = unchanged_mask
+
+        # Local receive eligibility
+        new_valid = new_local_expert_ids != -1
+        can_recv_local = np.isin(
+            new_local_expert_ids, old_local_expert_ids, assume_unique=False
+        )
+        is_local_recv = np.logical_or(
+            unchanged_mask, np.logical_and(new_valid, can_recv_local)
+        )
+        is_received_locally[layer_idx, :] = is_local_recv
+
+        # Send map: first src row per unique expert present locally in old mapping
+        valid_old = old_local_expert_ids != -1
+        if np.any(valid_old):
+            uniq_experts, first_idx = np.unique(
+                old_local_expert_ids[valid_old], return_index=True
+            )
+            filtered_rows = local_rows[valid_old]
+            src_rows = filtered_rows[first_idx]
+            layer_send_count = int(uniq_experts.shape[0])
+            send_counts[layer_idx] = layer_send_count
+            send_expert_ids[layer_idx, :layer_send_count] = uniq_experts
+            send_src_rows[layer_idx, :layer_send_count] = src_rows
+        else:
+            send_counts[layer_idx] = 0
+
+        # Recv map: primary dst per unique expert needed remotely
+        need_recv_mask = np.logical_and(~is_local_recv, new_valid)
+        if np.any(need_recv_mask):
+            desired_experts = new_local_expert_ids[need_recv_mask]
+            desired_dsts = local_rows[need_recv_mask]
+            uniq_recv_experts, uniq_indices = np.unique(
+                desired_experts, return_index=True
+            )
+            dst_rows = desired_dsts[uniq_indices]
+            layer_recv_count = int(uniq_recv_experts.shape[0])
+            recv_counts[layer_idx] = layer_recv_count
+            recv_expert_ids[layer_idx, :layer_recv_count] = uniq_recv_experts
+            recv_dst_rows[layer_idx, :layer_recv_count] = dst_rows
+            recv_primary_mask[layer_idx, dst_rows] = True
+        else:
+            recv_counts[layer_idx] = 0
+
+    # 1. 
Local moves into tmp buffers + for layer_idx in range(group_size): + layer_is_unchanged = is_unchanged[layer_idx, :] + layer_is_received_locally = is_received_locally[layer_idx, :] + layer_new_indices = new_indices_group[layer_idx] + layer_send_count = int(send_counts[layer_idx]) + layer_send_experts = send_expert_ids[layer_idx, :layer_send_count] + layer_send_srcs = send_src_rows[layer_idx, :layer_send_count] + local2global = partial( + idx_local_to_global, + local_cnt=num_local_experts, + ep_rank=ep_rank, + ) + layer_weights_list = list(expert_weights_group[layer_idx]) + layer_buffers_list = list(buffers_group[layer_idx]) for dst in range(num_local_experts): + if layer_is_unchanged[dst] or not layer_is_received_locally[dst]: + continue dst_global = local2global(dst) - if is_received_locally[dst]: + expert = layer_new_indices[dst_global] + if expert == -1: continue - if old_indices[src_global] == -1 or new_indices[dst_global] == -1: + matches = np.nonzero(layer_send_experts == expert)[0] + if matches.size == 0: continue - if old_indices[src_global] == new_indices[dst_global]: - is_received_locally[dst] = True - for weight, buffer in zip(expert_weights, expert_weights_buffer): - with torch.cuda.stream(cuda_stream): - buffer[dst].copy_(weight[src], non_blocking=True) - + src_local = int(layer_send_srcs[matches[0]]) + for w, b in zip(layer_weights_list, layer_buffers_list): + b[dst].copy_(w[src_local]) p2p_ops: list[P2POp] = [] - # 2. Initiate sending of weights. - experts_send_loc: dict[int, int] = {} - for src in range(num_local_experts): - expert = old_indices[local2global(src)] - if expert == -1: + # 2. Post sends per layer + for layer_idx in range(group_size): + old_indices = old_indices_group[layer_idx] + layer_new_indices = new_indices_group[layer_idx] + layer_weights_list = list(expert_weights_group[layer_idx]) + layer_send_count = int(send_counts[layer_idx]) + if layer_send_count == 0: continue - if expert in experts_send_loc: + experts = send_expert_ids[layer_idx, :layer_send_count] + srcs = send_src_rows[layer_idx, :layer_send_count] + order = np.argsort(experts, kind="stable") + experts = experts[order] + srcs = srcs[order] + for expert, src in zip(experts.tolist(), srcs.tolist()): + ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( + expert, + num_local_experts, + old_indices, + layer_new_indices, + ) + if not ranks_to_send or not ranks_to_recv: + continue + num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) + sender_pos = ranks_to_send.index(ep_rank) + recv_begin = sender_pos * num_dst_per_sender + recv_end = recv_begin + num_dst_per_sender + recv_ranks = ranks_to_recv[recv_begin:recv_end] + remainder_start = len(ranks_to_send) * num_dst_per_sender + recver_pos = remainder_start + sender_pos + if recver_pos < len(ranks_to_recv): + recv_ranks.append(ranks_to_recv[recver_pos]) + for dst in recv_ranks: + dst_global = get_global_rank(ep_group, dst) + p2p_ops += [ + P2POp( + torch.distributed.isend, + w[src], + dst_global, + ) + for w in layer_weights_list + ] + + # 3. 
Post recvs per layer + for layer_idx in range(group_size): + old_indices = old_indices_group[layer_idx] + layer_new_indices = new_indices_group[layer_idx] + layer_buffers_list = list(buffers_group[layer_idx]) + layer_recv_count = int(recv_counts[layer_idx]) + if layer_recv_count == 0: continue - experts_send_loc[expert] = src - - # We need to sort here to match send/recv - for expert, src in sorted(experts_send_loc.items()): - ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( - expert, - num_local_experts, - old_indices, - new_indices, - ) - - # Calculate the ranks to send by this rank - num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) - sender_pos = ranks_to_send.index(ep_rank) - recv_begin = sender_pos * num_dst_per_sender - recv_end = recv_begin + num_dst_per_sender - recv_ranks = ranks_to_recv[recv_begin:recv_end] - - # Tackle remainders - remainder_start = len(ranks_to_send) * num_dst_per_sender - recver_pos = remainder_start + sender_pos - if recver_pos < len(ranks_to_recv): - recv_ranks.append(ranks_to_recv[recver_pos]) - - for dst in recv_ranks: - dst_global = get_global_rank(ep_group, dst) + experts = recv_expert_ids[layer_idx, :layer_recv_count] + dsts = recv_dst_rows[layer_idx, :layer_recv_count] + order = np.argsort(experts, kind="stable") + experts = experts[order] + dsts = dsts[order] + for expert, dst in zip(experts.tolist(), dsts.tolist()): + ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( + expert, + num_local_experts, + old_indices, + layer_new_indices, + ) + if not ranks_to_send or not ranks_to_recv: + continue + num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) + recver_pos = ranks_to_recv.index(ep_rank) + remainder_start = len(ranks_to_send) * num_dst_per_sender + if recver_pos < remainder_start: + src = ranks_to_send[recver_pos // num_dst_per_sender] + else: + src = ranks_to_send[recver_pos - remainder_start] + src_global = get_global_rank(ep_group, src) p2p_ops += [ P2POp( - torch.distributed.isend, - weight[src], - dst_global, + torch.distributed.irecv, + b[dst], + src_global, ) - for weight in expert_weights + for b in layer_buffers_list ] - # 3. Initiate receiving of weights. - experts_recv_loc: dict[int, int] = {} - for dst in range(num_local_experts): - if is_received_locally[dst]: - continue - expert = new_indices[local2global(dst)] - if expert == -1: - continue - if expert in experts_recv_loc: - continue - experts_recv_loc[expert] = dst - - # We need to sort here to match send/recv - for expert, dst in sorted(experts_recv_loc.items()): - ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( - expert, - num_local_experts, - old_indices, - new_indices, - ) - - # Calculate the rank to recv by this rank - num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) - recver_pos = ranks_to_recv.index(ep_rank) - remainder_start = len(ranks_to_send) * num_dst_per_sender - if recver_pos < remainder_start: - src = ranks_to_send[recver_pos // num_dst_per_sender] - else: - src = ranks_to_send[recver_pos - remainder_start] - - src_global = get_global_rank(ep_group, src) - p2p_ops += [ - P2POp( - torch.distributed.irecv, - weight[dst], - src_global, - ) - for weight in expert_weights_buffer - ] - # 4. Execute the P2P operations. The real communication happens here. 
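+    # The send loop (2.) and recv loop (3.) both sort their experts, so every
+    # pair of ranks enumerates its matching (expert, peer) transfers in the
+    # same order; isend/irecv ops between the same two ranks are then paired
+    # by queue position inside `batch_isend_irecv`.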
if p2p_ops and cuda_stream is not None: with torch.cuda.stream(cuda_stream): @@ -237,38 +309,98 @@ def move_to_buffer( for req in reqs: req.wait() # wait for the communication to finish - return is_unchanged, is_received_locally, experts_recv_loc + return ( + is_unchanged, + is_received_locally, + (recv_primary_mask, recv_counts, recv_expert_ids, recv_dst_rows), + ) def move_from_buffer( - expert_weights: Iterable[torch.Tensor], - expert_weights_buffer: list[torch.Tensor], - is_unchanged: list[bool], - is_received_locally: list[bool], - experts_recv_loc: dict[int, int], - new_indices: Sequence[int], + weights_group: Sequence[Iterable[torch.Tensor]], + buffers_group: Sequence[Sequence[torch.Tensor]], + is_unchanged: np.ndarray, + is_received_locally: np.ndarray, + recv_metadata: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], + new_indices_group: np.ndarray, ep_group: ProcessGroup, ) -> None: + assert ( + len(weights_group) + == len(buffers_group) + == len(is_unchanged) + == len(is_received_locally) + == len(recv_metadata[0]) + == len(new_indices_group) + ), "Unmatching layer group size" ep_rank = ep_group.rank() - num_local_experts = len(is_unchanged) - - local2global = partial( - idx_local_to_global, local_cnt=num_local_experts, ep_rank=ep_rank - ) - - for dst in range(num_local_experts): - if is_unchanged[dst]: + group_size = len(is_unchanged) + recv_primary_mask, recv_counts, recv_expert_ids, recv_dst_rows = recv_metadata + num_local_experts = is_unchanged.shape[1] + # Mask for rows to copy back from buffers: + # copy if locally received OR remote primary recv + copy_mask = np.logical_or(is_received_locally, recv_primary_mask) + # Copy back local buffered rows into destination weights + for layer_idx in range(group_size): + layer_is_unchanged = is_unchanged[layer_idx, :] + layer_copy_mask = copy_mask[layer_idx, :] + weights_list = list(weights_group[layer_idx]) + buffers_list = list(buffers_group[layer_idx]) + # rows to copy = (~unchanged) & copy_mask + dest_mask_np = np.logical_and(~layer_is_unchanged, layer_copy_mask) + if not bool(dest_mask_np.any()): continue - if is_received_locally[dst]: - for weight, buffer in zip(expert_weights, expert_weights_buffer): - weight[dst].copy_(buffer[dst], non_blocking=True) - else: - expert = new_indices[local2global(dst)] - if expert == -1: - continue - src = experts_recv_loc[expert] - for weight, buffer in zip(expert_weights, expert_weights_buffer): - weight[dst].copy_(buffer[src], non_blocking=True) + dest_indices = np.nonzero(dest_mask_np)[0].tolist() + for dst in dest_indices: + for w, b in zip(weights_list, buffers_list): + w[dst].copy_(b[dst]) + + # Duplicate remote received rows to non-primary duplicate dsts + for layer_idx in range(group_size): + layer_is_unchanged = is_unchanged[layer_idx, :] + layer_is_received_locally = is_received_locally[layer_idx, :] + new_indices = new_indices_group[layer_idx] + weights_list = list(weights_group[layer_idx]) + + count_recv = int(recv_counts[layer_idx]) + if count_recv == 0: + # No remote primaries on this layer → no remote duplicates to materialize + continue + # Local view of desired expert ids per local row + base = ep_rank * num_local_experts + local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)] + # Duplicate rows mask: need remote, not primary, and valid expert id + duplicate_mask = np.logical_and( + np.logical_and(~layer_is_unchanged, ~layer_is_received_locally), + np.logical_and(~recv_primary_mask[layer_idx, :], local_experts != -1), + ) + if not 
bool(duplicate_mask.any()): + continue + dup_dst_rows = np.nonzero(duplicate_mask)[0] + dup_experts = local_experts[dup_dst_rows] + + # Build primary mapping arrays (expert -> primary dst) and vector-match + prim_experts = recv_expert_ids[layer_idx, :count_recv] + prim_dsts = recv_dst_rows[layer_idx, :count_recv] + order = np.argsort(prim_experts, kind="stable") + prim_experts_sorted = prim_experts[order] + prim_dsts_sorted = prim_dsts[order] + pos = np.searchsorted(prim_experts_sorted, dup_experts) + # Filter to experts that have a matching primary entry + valid = np.logical_and( + pos < prim_experts_sorted.shape[0], + prim_experts_sorted[np.minimum(pos, prim_experts_sorted.shape[0] - 1)] + == dup_experts, + ) + if not bool(valid.any()): + continue + matched_dst_rows = dup_dst_rows[valid] + matched_src_rows = prim_dsts_sorted[pos[valid]] + + # Perform row copies per (dst, src) pair without tensor indexing + for dst, src in zip(matched_dst_rows.tolist(), matched_src_rows.tolist()): + for w in weights_list: + w[dst].copy_(w[src]) async def transfer_layer( @@ -281,7 +413,9 @@ async def transfer_layer( layer: int = 0, cuda_stream: torch.cuda.Stream | None = None, rank_mapping: dict[int, int] | None = None, -) -> tuple[list[bool], list[bool], dict[int, int]]: +) -> tuple[ + np.ndarray, np.ndarray, tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] +]: """ Rearranges the expert weights in place according to the new expert indices. @@ -322,20 +456,20 @@ async def transfer_layer( num_local_physical_experts = next(iter(expert_weights[0])).shape[0] assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts) assert num_physical_experts == ep_size * num_local_physical_experts - # A buffer to hold the expert weights in one layer during the exchange. - # NOTE: Currently we assume the same weights across different layers - # have the same shape. - is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer( + old_global_expert_indices_np = old_global_expert_indices.cpu().numpy() + new_global_expert_indices_np = new_global_expert_indices.cpu().numpy() + + is_unchanged, is_received_locally, recv_metadata = move_to_buffer( num_local_experts=num_local_physical_experts, - old_indices=old_global_expert_indices[layer].tolist(), - new_indices=new_global_expert_indices[layer].tolist(), - expert_weights=expert_weights[layer], - expert_weights_buffer=expert_weights_buffer, + old_indices_group=old_global_expert_indices_np[layer : layer + 1], + new_indices_group=new_global_expert_indices_np[layer : layer + 1], + expert_weights_group=[expert_weights[layer]], + buffers_group=[expert_weights_buffer], cuda_stream=cuda_stream, ep_group=ep_group, ) - return is_unchanged, is_received_locally, experts_recv_loc + return is_unchanged, is_received_locally, recv_metadata def rearrange_expert_weights_inplace( @@ -391,54 +525,69 @@ def rearrange_expert_weights_inplace( ep_size = ep_group.size() assert num_physical_experts == ep_size * num_local_physical_experts - # A buffer to hold the expert weights in one layer during the exchange. + # Max number of layers to group for communication + max_group_layers = envs.VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS + max_group_layers = max(min(max_group_layers, num_moe_layers), 1) + logger.info_once( + f"EPLB Sync: rearrange max_group_layers: {max_group_layers}", scope="global" + ) + + first_layer_weights = list(expert_weights[0]) + # Buffers to hold the expert weights during the exchange. 
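+    # One buffer set is allocated per grouped layer, so peak staging memory
+    # grows linearly with VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS.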
# NOTE: Currently we assume the same weights across different layers # have the same shape. - expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]] - + weights_buffers: list[list[torch.Tensor]] = [ + [torch.empty_like(w) for w in first_layer_weights] + for _ in range(max_group_layers) + ] if is_profile: - # Maximum send size is to send all local experts to all ranks, - # So we use a dummy `all_gather` to reserve enough communication buffer - for weight, buffer in zip(expert_weights[0], expert_weights_buffer): - # A `/dev/null`-like buffer to avoid real memory allocation - dummy_recv_buffer = [buffer for _ in range(ep_size)] - # NOTE(bowen): Needed this barrier to avoid OOM during actual - # execution. I'm not very sure why this is needed - torch.distributed.barrier() - all_gather( - dummy_recv_buffer, - weight, - group=ep_group, - ) + # Reserve communication buffers via a minimal dummy all_gather on first layer + for layer_idx in range(max_group_layers): + for weight, buffer in zip(expert_weights[0], weights_buffers[layer_idx]): + dummy_recv_buffer = [buffer for _ in range(ep_size)] + torch.distributed.barrier() + all_gather( + dummy_recv_buffer, + weight, + group=ep_group, + ) return - old_global_expert_indices_cpu = old_global_expert_indices.cpu() - new_global_expert_indices_cpu = new_global_expert_indices.cpu() - # NOTE(bowen): We need this synchronize to run, but I don't know why. # If you figure out the reason, please let me know -- thank you! torch.cuda.synchronize() - for layer in range(num_moe_layers): - is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer( + old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy() + new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy() + + start = 0 + while start < num_moe_layers: + end = min(start + max_group_layers, num_moe_layers) + old_group = old_global_expert_indices_cpu[start:end] + new_group = new_global_expert_indices_cpu[start:end] + weights_group = [expert_weights[i] for i in range(start, end)] + buffers_group = weights_buffers[: (end - start)] + + is_unchanged, is_received_locally, recv_metadata = move_to_buffer( num_local_experts=num_local_physical_experts, - old_indices=old_global_expert_indices_cpu[layer].tolist(), - new_indices=new_global_expert_indices_cpu[layer].tolist(), - expert_weights=expert_weights[layer], - expert_weights_buffer=expert_weights_buffer, + old_indices_group=old_group, + new_indices_group=new_group, + expert_weights_group=weights_group, + buffers_group=buffers_group, cuda_stream=None, ep_group=ep_group, ) move_from_buffer( - expert_weights=expert_weights[layer], - expert_weights_buffer=expert_weights_buffer, + weights_group=weights_group, + buffers_group=buffers_group, is_unchanged=is_unchanged, is_received_locally=is_received_locally, - experts_recv_loc=experts_recv_loc, - new_indices=new_global_expert_indices_cpu[layer].tolist(), + recv_metadata=recv_metadata, + new_indices_group=new_group, ep_group=ep_group, ) + start = end def _map_old_expert_indices_with_rank_mapping( diff --git a/vllm/envs.py b/vllm/envs.py index 56558548d3981..df812f0e0a445 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -232,6 +232,7 @@ if TYPE_CHECKING: VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool = False + VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS: int = 1 def get_default_cache_root(): @@ -1526,6 +1527,10 @@ environment_variables: dict[str, Callable[[], Any]] = { 
"VLLM_USE_V2_MODEL_RUNNER": lambda: bool( int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) ), + # Max number of layers to group in synchronous EPLB weight communication. + "VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS": lambda: int( + os.getenv("VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS", "1") + ), } # --8<-- [end:env-vars-definition] From 561b42729936c2c19606fa82c2963822b1f591e8 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 25 Nov 2025 15:27:59 +0000 Subject: [PATCH 03/30] Add preserve expert on the same slot within gpu optimization Signed-off-by: ilmarkov --- tests/distributed/test_eplb_algo.py | 138 +++++++++++++++++++++++- vllm/distributed/eplb/eplb_state.py | 1 + vllm/distributed/eplb/rebalance_algo.py | 108 ++++++++++++++++++- 3 files changed, 245 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_eplb_algo.py b/tests/distributed/test_eplb_algo.py index 79805a7cce53b..2f292820f3d7d 100644 --- a/tests/distributed/test_eplb_algo.py +++ b/tests/distributed/test_eplb_algo.py @@ -4,7 +4,10 @@ import pytest import torch -from vllm.distributed.eplb.rebalance_algo import rebalance_experts +from vllm.distributed.eplb.rebalance_algo import ( + preserve_intragpu_slots, + rebalance_experts, +) def test_basic_rebalance(): @@ -306,3 +309,136 @@ if __name__ == "__main__": print(phy2log) test_basic_rebalance() + + +def _make_phyrank_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor: + """Create phyrank from phy2log""" + pr = torch.zeros_like(phy2log) + for layer in range(phy2log.shape[0]): + seen: dict[int, int] = {} + row = phy2log[layer].tolist() + for i, expert in enumerate(row): + r = seen.get(expert, 0) + pr[layer, i] = r + seen[expert] = r + 1 + return pr + + +def _validate_intragpu_rearrangement( + old_global_expert_indices: torch.Tensor, + new_phy2log: torch.Tensor, + new_phyrank: torch.Tensor, + post_phy2log: torch.Tensor, + post_phyrank: torch.Tensor, + num_gpus: int, + slots_per_gpu: int, +): + # Per-GPU checks + for gpu_idx in range(num_gpus): + start = gpu_idx * slots_per_gpu + end = start + slots_per_gpu + old_seg = old_global_expert_indices[0, start:end] + new_seg = new_phy2log[0, start:end] + new_rnk = new_phyrank[0, start:end] + post_seg = post_phy2log[0, start:end] + post_rnk = post_phyrank[0, start:end] + + # Pairwise equality for (expert, rank) pairs to ensure nothing is lost + def sorted_pairs(seg: torch.Tensor, rnk: torch.Tensor): + pairs = list(zip(seg.tolist(), rnk.tolist())) + pairs.sort() + return pairs + + assert sorted_pairs(post_seg, post_rnk) == sorted_pairs(new_seg, new_rnk), ( + f"Per-GPU pairs of (expert,rank) must match new mapping for GPU {gpu_idx}" + ) + + # For experts that remain on the same GPU, the old slot is preserved + # for at least one occurrence; rank at that slot must be valid for that expert + old_list = old_seg.tolist() + new_list = new_seg.tolist() + post_list = post_seg.tolist() + remained = set(old_list) & set(new_list) + new_ranks_for_expert: dict[int, list[int]] = {} + for v, r in zip(new_list, new_rnk.tolist()): + new_ranks_for_expert.setdefault(v, []).append(r) + for expert in remained: + old_pos = old_list.index(expert) + assert post_list[old_pos] == expert, ( + f"Expert {expert} on GPU {gpu_idx} should stay at old slot {old_pos}" + ) + # Rank at preserved slot must be one of the ranks + # the expert has in new mapping + assert post_rnk.tolist()[old_pos] in new_ranks_for_expert[expert], ( + f"Rank for expert {expert} at preserved slot on GPU {gpu_idx} " + "must come from new mapping" + ) + + +def test_preserve_intragpu_slots_simple(): + 
"""Experts that stay on a GPU keep their old slots; incoming not lost.""" + # Setup: 2 GPUs, 4 slots each, 1 layer + num_gpus = 2 + slots_per_gpu = 4 + # Old mapping: GPU0 -> [0,1,2,3], GPU1 -> [4,5,6,7] + old_global_expert_indices = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) + # New mapping shuffles within GPU0 and brings 4,5 into GPU0. + # GPU0 new -> [1,5,0,4] (0 and 1 remain on GPU0 but at different slots) + # GPU1 new -> [6,2,7,3] (6 and 7 remain on GPU1, 2 and 3 move in) + phy2log = torch.tensor([[1, 5, 0, 4, 6, 2, 7, 3]]) + # Derive phyrank from replica occurrence order per expert + phyrank = _make_phyrank_from_phy2log(phy2log) + + post_phy2log, post_phyrank = preserve_intragpu_slots( + phy2log, phyrank, num_gpus, old_global_expert_indices + ) + + # Shapes preserved + assert post_phy2log.shape == phy2log.shape + assert post_phyrank.shape == phyrank.shape + + _validate_intragpu_rearrangement( + old_global_expert_indices, + phy2log, + phyrank, + post_phy2log, + post_phyrank, + num_gpus, + slots_per_gpu, + ) + + +def test_preserve_intragpu_slots_with_duplicates(): + """Test preserve intragpu slots with duplicates""" + # Setup: 2 GPUs, 5 slots each (total 10 physical experts), 1 layer + num_gpus = 2 + slots_per_gpu = 5 + # Old mapping: + # GPU0 -> [0, 1, 0, 2, 3] (expert 0 duplicated) + # GPU1 -> [4, 5, 6, 1, 2] + old_global_expert_indices = torch.tensor([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]]) + # New mapping reorders within GPUs and moves some experts across GPUs, + # while still including duplicates: + # GPU0 new -> [0, 5, 4, 0, 1] (expert 0 duplicated, 4/5 incoming) + # GPU1 new -> [6, 2, 3, 1, 2] (expert 2 duplicated) + phy2log = torch.tensor([[0, 5, 4, 0, 1, 6, 2, 3, 1, 2]]) + # Derive ranks so duplicates have ranks [0,1,...] by occurrence + phyrank = _make_phyrank_from_phy2log(phy2log) + + post_phy2log, post_phyrank = preserve_intragpu_slots( + phy2log, phyrank, num_gpus, old_global_expert_indices + ) + + # Shapes preserved + assert post_phy2log.shape == phy2log.shape + assert post_phyrank.shape == phyrank.shape + + _validate_intragpu_rearrangement( + old_global_expert_indices, + phy2log, + phyrank, + post_phy2log, + post_phyrank, + num_gpus, + slots_per_gpu, + ) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 0dd3999c1a65a..c768cc9a0593b 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -795,6 +795,7 @@ class EplbState: num_groups, num_nodes, num_gpus, + eplb_model_state.physical_to_logical_map, ) if not eplb_model_state.is_async_enabled or is_profile: diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py index e6645e524cc3e..72fcefaddf522 100644 --- a/vllm/distributed/eplb/rebalance_algo.py +++ b/vllm/distributed/eplb/rebalance_algo.py @@ -197,12 +197,110 @@ def rebalance_experts_hierarchical( return pphy2log, pphyrank, logcnt +def preserve_intragpu_slots( + phy2log: torch.Tensor, + phyrank: torch.Tensor, + num_gpus: int, + old_global_expert_indices: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Reorder the new mapping per GPU so that experts that remain on the same GPU + keep their previous slot positions when possible. Incoming experts to that GPU + fill any remaining available slots. This is applied only when the number of GPUs + is unchanged and the slots per GPU remain the same between the old and new mappings. 
+ """ + device = phy2log.device + new_num_phy = phy2log.shape[1] + old_num_phy = old_global_expert_indices.shape[1] + if ( + num_gpus <= 0 + or new_num_phy % num_gpus != 0 + or old_num_phy % num_gpus != 0 + or (new_num_phy // num_gpus) != (old_num_phy // num_gpus) + ): + return phy2log, phyrank + + # Move to CPU and convert to NumPy for processing + phy2log_np = phy2log.cpu().numpy() + phyrank_np = phyrank.cpu().numpy() + old_np = old_global_expert_indices.cpu().numpy() + + slots_per_gpu = new_num_phy // num_gpus + num_layers = phy2log_np.shape[0] + + post_phy2log_np = phy2log_np.copy() + post_phyrank_np = phyrank_np.copy() + + for gpu_idx in range(num_gpus): + start = gpu_idx * slots_per_gpu + end = start + slots_per_gpu + # Segments across all layers for this GPU + old_seg = old_np[:, start:end] # [L, S] + new_seg = phy2log_np[:, start:end] # [L, S] + new_rnk = phyrank_np[:, start:end] # [L, S] + + used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool) + preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool) + + # First pass: preserve same-logical experts in their previous slots + for pos in range(slots_per_gpu): + # matches: [L, S], True where new_seg has the same logical value + # as the old slot 'pos' and not used + matches = (new_seg == old_seg[:, pos][:, None]) & (~used_new_indices) + has_any = matches.any(axis=1) + if np.any(has_any): + first_idx = np.argmax(matches, axis=1) + rows = np.nonzero(has_any)[0] + cols = first_idx[rows] + post_phy2log_np[rows, start + pos] = new_seg[rows, cols] + post_phyrank_np[rows, start + pos] = new_rnk[rows, cols] + used_new_indices[rows, cols] = True + preserved_positions[rows, pos] = True + + # Second pass: fill remaining slots with remaining new experts + remaining_mask = ~used_new_indices # [L, S] + fill_mask = ~preserved_positions # [L, S] + if remaining_mask.any() and fill_mask.any(): + idx_base = np.broadcast_to( + np.arange(slots_per_gpu), (num_layers, slots_per_gpu) + ) + large = slots_per_gpu + 1 + remaining_priority = np.where(remaining_mask, idx_base, large) + fill_priority = np.where(fill_mask, idx_base, large) + # Sort to get per-row ordered indices of True positions + remaining_indices = np.argsort(remaining_priority, axis=1) + fill_indices = np.argsort(fill_priority, axis=1) + # How many to fill per row + remaining_counts = remaining_mask.sum(axis=1) + fill_counts = fill_mask.sum(axis=1) + take_counts = np.minimum(remaining_counts, fill_counts) + # Assign per row + for layer_idx in range(num_layers): + k = int(take_counts[layer_idx]) + if k <= 0: + continue + src_pos = remaining_indices[layer_idx, :k] + dst_pos = fill_indices[layer_idx, :k] + post_phy2log_np[layer_idx, start + dst_pos] = new_seg[ + layer_idx, src_pos + ] + post_phyrank_np[layer_idx, start + dst_pos] = new_rnk[ + layer_idx, src_pos + ] + + # Convert back to torch and move to original device + post_phy2log = torch.from_numpy(post_phy2log_np).to(device) + post_phyrank = torch.from_numpy(post_phyrank_np).to(device) + return post_phy2log, post_phyrank + + def rebalance_experts( weight: torch.Tensor, num_replicas: int, num_groups: int, num_nodes: int, num_gpus: int, + old_global_expert_indices: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """ Entry point for expert-parallelism load balancer. 
@@ -239,6 +337,14 @@ def rebalance_experts( phy2log, phyrank, logcnt = rebalance_experts_hierarchical( weight, num_replicas, 1, 1, num_gpus ) + + # Optional postprocessing to preserve slots for experts moving within the same GPU + # Only apply when the number of GPUs and slots per GPU remain unchanged. + # Helps to avoid unnecessary weight copying when experts move within the same GPU. + if old_global_expert_indices is not None: + phy2log, phyrank = preserve_intragpu_slots( + phy2log, phyrank, num_gpus, old_global_expert_indices + ) num_redundant_experts = num_replicas - num_logical_experts maxlogcnt = num_redundant_experts + 1 log2phy: torch.Tensor = torch.full( @@ -257,4 +363,4 @@ def rebalance_experts( return phy2log, log2phy, logcnt -__all__ = ["rebalance_experts"] +__all__ = ["rebalance_experts", "preserve_intragpu_slots"] From 30bab971c02f34971da93f9834f3789ff48a2511 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 26 Nov 2025 14:11:19 +0000 Subject: [PATCH 04/30] Edit config and fix config post_init Signed-off-by: ilmarkov --- vllm/config/parallel.py | 4 ++++ vllm/distributed/eplb/eplb_state.py | 10 ++++++++-- vllm/distributed/eplb/rebalance_execute.py | 6 +++--- vllm/engine/arg_utils.py | 8 ++++---- 4 files changed, 19 insertions(+), 9 deletions(-) diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py index 7ba1da5db3849..44b89c3d24cbe 100644 --- a/vllm/config/parallel.py +++ b/vllm/config/parallel.py @@ -60,6 +60,10 @@ class EPLBConfig: Log the balancedness each step of expert parallelism. This is turned off by default since it will cause communication overhead. """ + log_balancedness_interval: int = 1 + """ + Interval for logging the balancedness. + """ use_async: bool = False """ Whether to use non-blocking EPLB. diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index c768cc9a0593b..3ee421ed3d1cf 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -549,7 +549,12 @@ class EplbState: for eplb_model_state in self.model_states.values(): eplb_model_state.expert_load_pass.zero_() - if log_stats: + if ( + log_stats + and self.expert_rearrangement_step + % self.parallel_config.eplb_config.log_balancedness_interval + == 0 + ): # Sync the expert load pass for each model (main and drafter). # expert_load_pass: (num_moe_layers, num_physical_experts) expert_load_pass_list = self._sync_load_pass() @@ -581,9 +586,10 @@ class EplbState: if ep_group.rank() == 0: logger.info( - "EPLB step: %d for model %s: avg_tokens=%.2f, " + "EPLB step: %d/%d for model %s: avg_tokens=%.2f, " "max_tokens=%d, balancedness=%.4f", self.expert_rearrangement_step, + self.expert_rearrangement_step_interval, eplb_model_state.model_name, avg_tokens, max_tokens, diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index aa9f77f3ca5c4..5bc111cf02756 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -528,9 +528,6 @@ def rearrange_expert_weights_inplace( # Max number of layers to group for communication max_group_layers = envs.VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS max_group_layers = max(min(max_group_layers, num_moe_layers), 1) - logger.info_once( - f"EPLB Sync: rearrange max_group_layers: {max_group_layers}", scope="global" - ) first_layer_weights = list(expert_weights[0]) # Buffers to hold the expert weights during the exchange. 
@@ -552,6 +549,9 @@ def rearrange_expert_weights_inplace( group=ep_group, ) return + logger.info_once( + f"EPLB Sync: rearrange max_group_layers: {max_group_layers}", scope="global" + ) # NOTE(bowen): We need this synchronize to run, but I don't know why. # If you figure out the reason, please let me know -- thank you! diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 696ff3a1f4024..8fbfcac7d2cd1 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -419,10 +419,10 @@ class EngineArgs: ) _api_process_count: int = ParallelConfig._api_process_count _api_process_rank: int = ParallelConfig._api_process_rank - num_redundant_experts: int = EPLBConfig.num_redundant_experts - eplb_window_size: int = EPLBConfig.window_size - eplb_step_interval: int = EPLBConfig.step_interval - eplb_log_balancedness: bool = EPLBConfig.log_balancedness + num_redundant_experts: int | None = None + eplb_window_size: int | None = None + eplb_step_interval: int | None = None + eplb_log_balancedness: bool | None = None max_parallel_loading_workers: int | None = ( ParallelConfig.max_parallel_loading_workers ) From c6f14d1a2715278cfcfaa53508ac34d7eefda653 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 26 Nov 2025 17:31:13 +0000 Subject: [PATCH 05/30] Optimize after codex review Signed-off-by: ilmarkov --- vllm/distributed/eplb/rebalance_execute.py | 55 +++++++++++++--------- 1 file changed, 33 insertions(+), 22 deletions(-) diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 5bc111cf02756..83e1e675ba639 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -7,7 +7,6 @@ This involves the exchange of expert weights between GPUs. """ from collections.abc import Iterable, Sequence -from functools import partial import numpy as np import torch @@ -123,6 +122,10 @@ def move_to_buffer( is_unchanged = np.zeros((group_size, num_local_experts), dtype=np.bool_) is_received_locally = np.zeros((group_size, num_local_experts), dtype=np.bool_) recv_primary_mask = np.zeros((group_size, num_local_experts), dtype=np.bool_) + # Cache desired new expert ids per local row, for all layers + new_local_expert_ids_mat = np.full( + (group_size, num_local_experts), -1, dtype=np.int64 + ) send_counts = np.zeros(group_size, dtype=np.int32) send_expert_ids = np.full((group_size, num_local_experts), -1, dtype=np.int64) send_src_rows = np.full((group_size, num_local_experts), -1, dtype=np.int32) @@ -140,6 +143,7 @@ def move_to_buffer( old_local_expert_ids = old_indices[local_global] new_local_expert_ids = layer_new_indices[local_global] + new_local_expert_ids_mat[layer_idx, :] = new_local_expert_ids # Unchanged per-dst mask unchanged_mask = old_local_expert_ids == new_local_expert_ids @@ -187,34 +191,41 @@ def move_to_buffer( else: recv_counts[layer_idx] = 0 + # Precompute per-layer destination mask that actually needs local buffering: + # need change, received locally, and valid target expert id + eligible_local_buffer_mask = np.logical_and( + np.logical_and(~is_unchanged, is_received_locally), + new_local_expert_ids_mat != -1, + ) + # 1. 
Local moves into tmp buffers for layer_idx in range(group_size): - layer_is_unchanged = is_unchanged[layer_idx, :] - layer_is_received_locally = is_received_locally[layer_idx, :] - layer_new_indices = new_indices_group[layer_idx] layer_send_count = int(send_counts[layer_idx]) + if layer_send_count <= 0: + continue + layer_send_experts = send_expert_ids[layer_idx, :layer_send_count] layer_send_srcs = send_src_rows[layer_idx, :layer_send_count] - local2global = partial( - idx_local_to_global, - local_cnt=num_local_experts, - ep_rank=ep_rank, - ) layer_weights_list = list(expert_weights_group[layer_idx]) layer_buffers_list = list(buffers_group[layer_idx]) - for dst in range(num_local_experts): - if layer_is_unchanged[dst] or not layer_is_received_locally[dst]: - continue - dst_global = local2global(dst) - expert = layer_new_indices[dst_global] - if expert == -1: - continue - matches = np.nonzero(layer_send_experts == expert)[0] - if matches.size == 0: - continue - src_local = int(layer_send_srcs[matches[0]]) - for w, b in zip(layer_weights_list, layer_buffers_list): - b[dst].copy_(w[src_local]) + new_local_expert_ids = new_local_expert_ids_mat[layer_idx, :] + + # Only consider destination rows that are eligible for local buffering + eligible_mask = eligible_local_buffer_mask[layer_idx, :] + if not bool(eligible_mask.any()): + continue + + dest_indices = np.nonzero(eligible_mask)[0].tolist() + # Build a map from expert_id to its source row. + expert_to_src_map = dict(zip(layer_send_experts, layer_send_srcs)) + + for dst in dest_indices: + expert = new_local_expert_ids[dst] + src_local = expert_to_src_map.get(expert, -1) + if src_local != -1: + for w, b in zip(layer_weights_list, layer_buffers_list): + b[dst].copy_(w[src_local]) + p2p_ops: list[P2POp] = [] # 2. Post sends per layer From 08083749bee09b52794d8db0d57254504ac29427 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 26 Nov 2025 21:50:09 +0000 Subject: [PATCH 06/30] Vectorize get_ep_ranks_with_experts Signed-off-by: ilmarkov --- vllm/distributed/eplb/rebalance_execute.py | 194 +++++++++++++-------- 1 file changed, 119 insertions(+), 75 deletions(-) diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 83e1e675ba639..a7e7c402aac5c 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -24,80 +24,111 @@ from vllm.logger import init_logger logger = init_logger(__name__) -def idx_local_to_global( - local_idx: int, - local_cnt: int, - ep_rank: int, -) -> int: - """ - Convert a local expert index to a global expert index. - """ - return ep_rank * local_cnt + local_idx - - -def idx_global_to_local( - global_idx: int, - local_cnt: int, - ep_rank: int, -) -> int: - """ - Convert a global expert index to a local expert index. - """ - return global_idx - ep_rank * local_cnt - - -def global_idx_to_rank( - global_idx: int, - local_cnt: int, -) -> int: - """ - Convert a global expert index to a rank index. - """ - return global_idx // local_cnt - - -def get_ep_ranks_with_expert( - idx: int, +def get_ep_ranks_with_experts_batch( + expert_ids: np.ndarray, num_local_experts: int, old_indices: np.ndarray, new_indices: np.ndarray, -) -> tuple[list[int], list[int]]: +) -> tuple[dict[int, list[int]], dict[int, list[int]]]: """ Get the ranks of the experts that need to be exchanged. Args: - idx: The index of the expert. + expert_ids: 1D array of expert indices to query. num_local_experts: The number of local experts. 
old_indices: The old indices of the experts. new_indices: The new indices of the experts. Returns: - A tuple of two lists: - - The ranks of the experts that need to be sent. - - The ranks of the experts that need to be received. + A tuple of two dictionaries mapping expert_id to: + - ranks_to_send: The ranks that have this expert and need to send. + - ranks_to_recv: The ranks that need to receive this expert. """ - # Indices where expert idx appears - old_pos = np.nonzero(old_indices == idx)[0] - new_pos = np.nonzero(new_indices == idx)[0] - # Map positions to ranks - if old_pos.size > 0: - old_ranks = old_pos // num_local_experts - uniq_send, first_idx_send = np.unique(old_ranks, return_index=True) - order_send = np.argsort(first_idx_send) - ranks_to_send = uniq_send[order_send].astype(int).tolist() - else: - ranks_to_send = [] - if new_pos.size > 0: - new_ranks = new_pos // num_local_experts - uniq_recv, first_idx_recv = np.unique(new_ranks, return_index=True) - order_recv = np.argsort(first_idx_recv) - ranks_to_recv = uniq_recv[order_recv].astype(int).tolist() - else: - ranks_to_recv = [] - # Remove ranks that have local copies to avoid unnecessary recv - ranks_to_send_set = set(ranks_to_send) - ranks_to_recv_actual = [r for r in ranks_to_recv if r not in ranks_to_send_set] - return ranks_to_send, ranks_to_recv_actual + ranks_to_send_map: dict[int, list[int]] = {} + ranks_to_recv_map: dict[int, list[int]] = {} + + # Fast path: if no experts, return empty dicts + if expert_ids.size == 0: + return ranks_to_send_map, ranks_to_recv_map + + unique_experts = np.unique(expert_ids) + num_positions = len(old_indices) + position_indices = np.arange(num_positions, dtype=np.int32) + + # Vectorized approach: find all positions matching any query expert in one pass + # Use np.isin to get boolean masks for all relevant positions at once + old_relevant_mask = np.isin(old_indices, unique_experts) + new_relevant_mask = np.isin(new_indices, unique_experts) + + # Process old_indices (send ranks) + if np.any(old_relevant_mask): + old_relevant_positions = position_indices[old_relevant_mask] + old_relevant_experts = old_indices[old_relevant_mask] + old_relevant_ranks = old_relevant_positions // num_local_experts + + # Sort by expert first, then by position (to maintain first-appearance order) + sort_order = np.lexsort((old_relevant_positions, old_relevant_experts)) + sorted_experts = old_relevant_experts[sort_order] + sorted_ranks = old_relevant_ranks[sort_order] + + # Find boundaries where expert changes + expert_boundaries = np.concatenate( + [[0], np.where(np.diff(sorted_experts) != 0)[0] + 1, [len(sorted_experts)]] + ) + + # For each expert, extract unique ranks in order of first appearance + for i in range(len(expert_boundaries) - 1): + start, end = expert_boundaries[i], expert_boundaries[i + 1] + expert = int(sorted_experts[start]) + expert_ranks = sorted_ranks[start:end] + + # Get unique ranks preserving order + _, unique_idx = np.unique(expert_ranks, return_index=True) + unique_ranks = expert_ranks[np.sort(unique_idx)] + ranks_to_send_map[expert] = unique_ranks.tolist() + + # Process new_indices (recv ranks) + if np.any(new_relevant_mask): + new_relevant_positions = position_indices[new_relevant_mask] + new_relevant_experts = new_indices[new_relevant_mask] + new_relevant_ranks = new_relevant_positions // num_local_experts + + # Sort by expert first, then by position + sort_order = np.lexsort((new_relevant_positions, new_relevant_experts)) + sorted_experts = new_relevant_experts[sort_order] + 
sorted_ranks = new_relevant_ranks[sort_order] + + # Find boundaries where expert changes + expert_boundaries = np.concatenate( + [[0], np.where(np.diff(sorted_experts) != 0)[0] + 1, [len(sorted_experts)]] + ) + + # For each expert, extract unique ranks and exclude local copies + for i in range(len(expert_boundaries) - 1): + start, end = expert_boundaries[i], expert_boundaries[i + 1] + expert = int(sorted_experts[start]) + expert_ranks = sorted_ranks[start:end] + + # Get unique ranks preserving order + _, unique_idx = np.unique(expert_ranks, return_index=True) + unique_ranks = expert_ranks[np.sort(unique_idx)] + + # Remove ranks that have local copies (in send map) + send_ranks_set = set(ranks_to_send_map.get(expert, [])) + recv_ranks_actual = [ + int(r) for r in unique_ranks if r not in send_ranks_set + ] + ranks_to_recv_map[expert] = recv_ranks_actual + + # Handle experts that only appear in old (send only) or new (recv only) + for expert in unique_experts: + expert = int(expert) + if expert not in ranks_to_send_map: + ranks_to_send_map[expert] = [] + if expert not in ranks_to_recv_map: + ranks_to_recv_map[expert] = [] + + return ranks_to_send_map, ranks_to_recv_map def move_to_buffer( @@ -228,6 +259,10 @@ def move_to_buffer( p2p_ops: list[P2POp] = [] + # Pre-compute global ranks mapping + ep_size = ep_group.size() + rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)} + # 2. Post sends per layer for layer_idx in range(group_size): old_indices = old_indices_group[layer_idx] @@ -241,13 +276,17 @@ def move_to_buffer( order = np.argsort(experts, kind="stable") experts = experts[order] srcs = srcs[order] + + send_map, recv_map = get_ep_ranks_with_experts_batch( + experts, + num_local_experts, + old_indices, + layer_new_indices, + ) + for expert, src in zip(experts.tolist(), srcs.tolist()): - ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( - expert, - num_local_experts, - old_indices, - layer_new_indices, - ) + ranks_to_send = send_map[expert] + ranks_to_recv = recv_map[expert] if not ranks_to_send or not ranks_to_recv: continue num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) @@ -260,7 +299,7 @@ def move_to_buffer( if recver_pos < len(ranks_to_recv): recv_ranks.append(ranks_to_recv[recver_pos]) for dst in recv_ranks: - dst_global = get_global_rank(ep_group, dst) + dst_global = rank_to_global[dst] p2p_ops += [ P2POp( torch.distributed.isend, @@ -283,13 +322,18 @@ def move_to_buffer( order = np.argsort(experts, kind="stable") experts = experts[order] dsts = dsts[order] + + # Batch query all experts for this layer + send_map, recv_map = get_ep_ranks_with_experts_batch( + experts, + num_local_experts, + old_indices, + layer_new_indices, + ) + for expert, dst in zip(experts.tolist(), dsts.tolist()): - ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert( - expert, - num_local_experts, - old_indices, - layer_new_indices, - ) + ranks_to_send = send_map[expert] + ranks_to_recv = recv_map[expert] if not ranks_to_send or not ranks_to_recv: continue num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send) @@ -299,7 +343,7 @@ def move_to_buffer( src = ranks_to_send[recver_pos // num_dst_per_sender] else: src = ranks_to_send[recver_pos - remainder_start] - src_global = get_global_rank(ep_group, src) + src_global = rank_to_global[src] p2p_ops += [ P2POp( torch.distributed.irecv, From b4628728043a63f3574a309ed36f146c4cff12b3 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 9 Dec 2025 12:21:31 +0000 Subject: [PATCH 07/30] Fix pre-commit 
Signed-off-by: ilmarkov
---
 vllm/distributed/eplb/policy/abstract.py | 5 ++++-
 vllm/distributed/eplb/policy/default.py | 4 +++-
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/vllm/distributed/eplb/policy/abstract.py b/vllm/distributed/eplb/policy/abstract.py
index 40ed621c84892..f4435f11bd57b 100644
--- a/vllm/distributed/eplb/policy/abstract.py
+++ b/vllm/distributed/eplb/policy/abstract.py
@@ -16,6 +16,7 @@ class AbstractEplbPolicy(ABC):
         num_groups: int,
         num_nodes: int,
         num_ranks: int,
+        old_global_expert_indices: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Entry point for expert-parallelism load balancer.
@@ -28,7 +29,9 @@ class AbstractEplbPolicy(ABC):
             num_groups: number of expert groups
             num_nodes: number of server nodes
             num_ranks: number of ranks, must be a multiple of `num_nodes`
-
+            old_global_expert_indices: [layers, num_replicas], the old global
+                expert indices. Used to avoid unnecessary weight copying
+                for experts moving within one rank.
         Returns:
             physical_to_logical_map: [layers, num_replicas], the expert index
                 of each replica
diff --git a/vllm/distributed/eplb/policy/default.py b/vllm/distributed/eplb/policy/default.py
index 82fd1b94acaea..970a1614933ee 100644
--- a/vllm/distributed/eplb/policy/default.py
+++ b/vllm/distributed/eplb/policy/default.py
@@ -328,7 +328,9 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
             num_nodes: number of server nodes, where the intra-node network
                 (e.g, NVLink) is faster
             num_ranks: number of ranks, must be a multiple of `num_nodes`
-
+            old_global_expert_indices: [layers, num_replicas], the old global
+                expert indices. Used to avoid unnecessary weight copying
+                for experts moving within one rank.
 
         Returns:
             phy2log: [layers, num_replicas], the expert index of each replica

From cfac6b3f648fec000d0786a7bc2dad02b436ab82 Mon Sep 17 00:00:00 2001
From: ilmarkov
Date: Tue, 9 Dec 2025 21:28:32 +0000
Subject: [PATCH 08/30] Remove leftover EPLB config fields from EngineArgs

Signed-off-by: ilmarkov
---
 vllm/engine/arg_utils.py | 4 ----
 1 file changed, 4 deletions(-)

diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index d0bfab4e6569f..2f307a7ccf16d 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -423,10 +423,6 @@ class EngineArgs:
     )
     _api_process_count: int = ParallelConfig._api_process_count
     _api_process_rank: int = ParallelConfig._api_process_rank
-    num_redundant_experts: int | None = None
-    eplb_window_size: int | None = None
-    eplb_step_interval: int | None = None
-    eplb_log_balancedness: bool | None = None
     max_parallel_loading_workers: int | None = (
         ParallelConfig.max_parallel_loading_workers
     )

From 6b2a1de50061f1e55fce723e346f5b92b1efd233 Mon Sep 17 00:00:00 2001
From: ilmarkov
Date: Wed, 10 Dec 2025 12:44:29 +0000
Subject: [PATCH 09/30] Updates after review

Signed-off-by: ilmarkov
---
 vllm/distributed/eplb/eplb_state.py | 33 ++++---
 vllm/distributed/eplb/rebalance_execute.py | 101 +++++++++++++++++----
 2 files changed, 105 insertions(+), 29 deletions(-)

diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index 0fc1f03e1947c..575425bcf81c6 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -47,7 +47,11 @@ from vllm.model_executor.models.interfaces import MixtureOfExperts
 
 from .async_worker import start_async_worker
 from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy
-from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace
+from .rebalance_execute
import ( + RecvMetadata, + move_from_buffer, + rearrange_expert_weights_inplace, +) logger = init_logger(__name__) @@ -175,14 +179,9 @@ class EplbModelState: intermediate variable between `move_to_buffer` and `move_to_workspace`. The size is same as the num of physical experts in the current layer. """ - recv_metadata: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] + recv_metadata: RecvMetadata """ intermediate variable between `move_to_buffer` and `move_to_workspace`. - The tuple contains: - - recv_primary_mask: np.ndarray, shape (group_size, num_local_experts) - - recv_counts: np.ndarray, shape (group_size,) - - recv_expert_ids: np.ndarray, shape (group_size, num_local_experts) - - recv_dst_rows: np.ndarray, shape (group_size, num_local_experts) """ is_async_enabled: bool """ @@ -514,7 +513,12 @@ class EplbState: pending_global_ready_check=False, is_unchanged=np.array([]), is_received_locally=np.array([]), - recv_metadata=(np.array([]), np.array([]), np.array([]), np.array([])), + recv_metadata=RecvMetadata( + recv_primary_mask=np.array([]), + recv_counts=np.array([]), + recv_expert_ids=np.array([]), + recv_dst_rows=np.array([]), + ), is_async_enabled=self.is_async, cuda_device_index=self.cuda_device_index, new_physical_to_logical_map=new_physical_to_logical_map, @@ -985,17 +989,20 @@ class EplbState: model_state.model.expert_weights[model_state.layer_to_transfer] ] buffers_group = [model_state.expert_buffer] + new_indices_group = ( + model_state.new_physical_to_logical_map[ + model_state.layer_to_transfer : model_state.layer_to_transfer + 1 + ] + .cpu() + .numpy() + ) move_from_buffer( weights_group=weights_group, buffers_group=buffers_group, is_unchanged=model_state.is_unchanged, is_received_locally=model_state.is_received_locally, recv_metadata=model_state.recv_metadata, - new_indices_group=model_state.new_physical_to_logical_map[ - model_state.layer_to_transfer : model_state.layer_to_transfer + 1 - ] - .cpu() - .numpy(), + new_indices_group=new_indices_group, ep_group=ep_group, ) transferred_layer = model_state.layer_to_transfer diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index a7e7c402aac5c..4a8590a027213 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -7,6 +7,7 @@ This involves the exchange of expert weights between GPUs. 
""" from collections.abc import Iterable, Sequence +from dataclasses import dataclass import numpy as np import torch @@ -24,6 +25,26 @@ from vllm.logger import init_logger logger = init_logger(__name__) +@dataclass +class RecvMetadata: + """Metadata describing remote receives during EPLB rebalancing.""" + + recv_primary_mask: np.ndarray + """Mask of (layer_group_size, num_local_experts) + indicating primary experts received.""" + recv_counts: np.ndarray + """Number of received experts for each layer.""" + recv_expert_ids: np.ndarray + """Expert ids (layer_group_size, num_local_experts) of remote primary experts.""" + recv_dst_rows: np.ndarray + """Target expert indices (layer_group_size, num_local_experts) + in local tensors to send.""" + + +# Type alias for the result of move_to_buffer or transfer_layer +MoveToBufferResult = tuple[np.ndarray, np.ndarray, RecvMetadata] + + def get_ep_ranks_with_experts_batch( expert_ids: np.ndarray, num_local_experts: int, @@ -139,11 +160,28 @@ def move_to_buffer( buffers_group: Sequence[Sequence[torch.Tensor]], cuda_stream: torch.cuda.Stream | None, ep_group: ProcessGroup, -) -> tuple[ - np.ndarray, np.ndarray, tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] -]: +) -> MoveToBufferResult: """ - Perform expert weights rearrangement of a group of layers. + Rearranges expert weights across a group of layers + during mixture-of-experts (MoE) expert parallel rebalancing. + + Args: + num_local_experts: Number of local experts. + old_indices_group: (num_layers, num_experts_total) ndarray of current + (old) global-to-local expert assignments. + new_indices_group: (num_layers, num_experts_total) ndarray of desired + (new) global-to-local assignments after rebalance. + expert_weights_group: Original expert weights for each layer. + buffers_group: List of per-layer intermediate buffers (one per tensor). + cuda_stream: CUDA stream for async copies (can be None for sync mode). + ep_group: Distributed process group for expert parallel comms. + + Returns: + is_unchanged (np.ndarray): (num_layers, num_local_experts), True where an + expert row is unchanged after rebalance. + is_received_locally (np.ndarray): (num_layers, num_local_experts), True + where a row can be updated from local data. + RecvMetadata: Metadata needed for completing remote weight transfers. 
""" assert len(old_indices_group) == len(new_indices_group) == len(expert_weights_group) group_size = len(old_indices_group) @@ -214,10 +252,10 @@ def move_to_buffer( desired_experts, return_index=True ) dst_rows = desired_dsts[uniq_indices] - layer_send_count = int(uniq_recv_experts.shape[0]) - recv_counts[layer_idx] = layer_send_count - recv_expert_ids[layer_idx, :layer_send_count] = uniq_recv_experts - recv_dst_rows[layer_idx, :layer_send_count] = dst_rows + layer_recv_count = int(uniq_recv_experts.shape[0]) + recv_counts[layer_idx] = layer_recv_count + recv_expert_ids[layer_idx, :layer_recv_count] = uniq_recv_experts + recv_dst_rows[layer_idx, :layer_recv_count] = dst_rows recv_primary_mask[layer_idx, dst_rows] = True else: recv_counts[layer_idx] = 0 @@ -367,7 +405,12 @@ def move_to_buffer( return ( is_unchanged, is_received_locally, - (recv_primary_mask, recv_counts, recv_expert_ids, recv_dst_rows), + RecvMetadata( + recv_primary_mask=recv_primary_mask, + recv_counts=recv_counts, + recv_expert_ids=recv_expert_ids, + recv_dst_rows=recv_dst_rows, + ), ) @@ -376,21 +419,42 @@ def move_from_buffer( buffers_group: Sequence[Sequence[torch.Tensor]], is_unchanged: np.ndarray, is_received_locally: np.ndarray, - recv_metadata: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray], + recv_metadata: RecvMetadata, new_indices_group: np.ndarray, ep_group: ProcessGroup, ) -> None: + """ + Copies expert weights from communication buffers back to the target weight tensors, + after EPLB rebalancing. + + Args: + weights_group: Groups of consecutive MoE layers, each containing one or more + weight tensors. + buffers_group: Intermediate buffers matching weights_group.. + is_unchanged: (num_layers, num_local_experts), True + where an expert row is unchanged after rebalance. + is_received_locally: (num_layers, num_local_experts), True + where a row can be updated from local data. + recv_metadata: RecvMetadata containing remote receive metadata. + new_indices_group: np.ndarray giving for each layer the mapping from local rows + to desired (possibly global) expert id, after rebalance. + ep_group: torch.distributed.ProcessGroup for expert parallel communication + domain. + """ assert ( len(weights_group) == len(buffers_group) == len(is_unchanged) == len(is_received_locally) - == len(recv_metadata[0]) + == len(recv_metadata.recv_primary_mask) == len(new_indices_group) ), "Unmatching layer group size" ep_rank = ep_group.rank() group_size = len(is_unchanged) - recv_primary_mask, recv_counts, recv_expert_ids, recv_dst_rows = recv_metadata + recv_primary_mask = recv_metadata.recv_primary_mask + recv_counts = recv_metadata.recv_counts + recv_expert_ids = recv_metadata.recv_expert_ids + recv_dst_rows = recv_metadata.recv_dst_rows num_local_experts = is_unchanged.shape[1] # Mask for rows to copy back from buffers: # copy if locally received OR remote primary recv @@ -468,9 +532,7 @@ async def transfer_layer( layer: int = 0, cuda_stream: torch.cuda.Stream | None = None, rank_mapping: dict[int, int] | None = None, -) -> tuple[ - np.ndarray, np.ndarray, tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray] -]: +) -> MoveToBufferResult: """ Rearranges the expert weights in place according to the new expert indices. @@ -488,6 +550,13 @@ async def transfer_layer( is_profile (bool): If `True`, do not perform any actual weight copy. This is used during profile run, where we only perform dummy communications to reserve enough memory for the buffers. 
+ + Returns: + is_unchanged (np.ndarray): (1, num_local_experts), True where expert + is left unchanged. + is_received_locally (np.ndarray): (1, num_local_experts), True where expert + is not copied locally. + RecvMetadata: Metadata needed for completing remote weight transfers. """ ep_size = ep_group.size() if rank_mapping is not None: @@ -733,4 +802,4 @@ def _map_new_expert_indices_with_rank_mapping( return mapped_expert_indices -__all__ = ["transfer_layer", "move_from_buffer"] +__all__ = ["transfer_layer", "move_from_buffer", "RecvMetadata", "MoveToBufferResult"] From fc54d760a6957c218941a9f94f99d64fa6c8a544 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 11 Dec 2025 09:20:33 +0000 Subject: [PATCH 10/30] Correct eplb state logs Signed-off-by: ilmarkov --- vllm/distributed/eplb/eplb_state.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 575425bcf81c6..3c71755350dc5 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -27,7 +27,6 @@ physical experts. """ import threading -import time from collections.abc import Sequence from dataclasses import dataclass @@ -699,11 +698,14 @@ class EplbState: ep_group = get_ep_group().device_group ep_rank = ep_group.rank() - time_start = None + start_event = None + end_event = None is_main_rank = ep_rank == 0 if is_main_rank: - torch.cuda.synchronize() - time_start = time.perf_counter() + if not self.is_async or is_profile: + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() logger.info( "Rearranging experts %s %s...", "(async mode)" if self.is_async else "sync mode", @@ -864,13 +866,15 @@ class EplbState: new_logical_replica_count ) if is_main_rank: - assert time_start is not None - torch.cuda.synchronize() - time_end = time.perf_counter() + assert start_event is not None + assert end_event is not None + end_event.record() + end_event.synchronize() + gpu_elapsed = start_event.elapsed_time(end_event) / 1000.0 logger.info( - "Rearranged experts%sin %.2f seconds.", + "Rearranged experts %s in %.2f s.", " (profile) " if is_profile else " ", - time_end - time_start, + gpu_elapsed, ) else: max_slots = eplb_model_state.logical_to_physical_map.shape[-1] @@ -1010,7 +1014,7 @@ class EplbState: # After the main thread consumes, advance layer_to_transfer model_state.layer_to_transfer += 1 model_state.ep_buffer_ready = 0 - logger.info( + logger.debug( "model %s successfully move_to_workspace layer %d", model_state.model_name, transferred_layer, From f28720db88db6439da2016e16ada7153ca9b605d Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 11 Dec 2025 10:38:32 +0000 Subject: [PATCH 11/30] Remove layer grouping Signed-off-by: ilmarkov --- tests/distributed/eplb_utils.py | 3 +- tests/distributed/test_eplb_execute.py | 10 +- vllm/distributed/eplb/eplb_state.py | 20 +- vllm/distributed/eplb/rebalance_execute.py | 411 ++++++++------------- vllm/envs.py | 5 - 5 files changed, 168 insertions(+), 281 deletions(-) diff --git a/tests/distributed/eplb_utils.py b/tests/distributed/eplb_utils.py index 9d06e705968bd..27a63e0215148 100644 --- a/tests/distributed/eplb_utils.py +++ b/tests/distributed/eplb_utils.py @@ -15,7 +15,7 @@ from vllm.utils.system_utils import update_environment_variables mp.set_start_method("spawn", force=True) -def distributed_run(fn, world_size, *args, max_grouped_layers=1): +def distributed_run(fn, world_size, 
*args): number_of_processes = world_size processes: list[mp.Process] = [] for i in range(number_of_processes): @@ -26,7 +26,6 @@ def distributed_run(fn, world_size, *args, max_grouped_layers=1): env["LOCAL_WORLD_SIZE"] = str(number_of_processes) env["MASTER_ADDR"] = "localhost" env["MASTER_PORT"] = "12345" - env["VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS"] = str(max_grouped_layers) p = mp.Process(target=fn, args=(env, world_size, *args)) processes.append(p) p.start() diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py index 1bf231500d9aa..dfb15311202e5 100644 --- a/tests/distributed/test_eplb_execute.py +++ b/tests/distributed/test_eplb_execute.py @@ -306,12 +306,12 @@ def _test_async_transfer_layer_without_mtp_worker( ) cuda_stream.synchronize() move_from_buffer( - weights_group=[expert_weights[layer_idx]], - buffers_group=[expert_buffer], + expert_weights=expert_weights[layer_idx], + expert_weights_buffers=expert_buffer, is_unchanged=is_unchanged, is_received_locally=is_received_locally, recv_metadata=recv_metadata, - new_indices_group=new_indices_cpu[layer_idx : layer_idx + 1], + new_indices=new_indices_cpu[layer_idx], ep_group=ep_group, ) @@ -427,9 +427,8 @@ def _test_rearrange_expert_weights_with_redundancy( (4, 8, 8, 16), ], ) -@pytest.mark.parametrize("group_layers", [1, 2]) def test_rearrange_expert_weights_with_redundancy( - world_size, num_layers, num_local_experts, num_logical_experts, group_layers + world_size, num_layers, num_local_experts, num_logical_experts ): """Test the functionality of rearranging expert weights with redundancy.""" @@ -441,7 +440,6 @@ def test_rearrange_expert_weights_with_redundancy( num_layers, num_local_experts, num_logical_experts, - max_grouped_layers=group_layers, ) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 3c71755350dc5..111479f2ee2f4 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -514,7 +514,7 @@ class EplbState: is_received_locally=np.array([]), recv_metadata=RecvMetadata( recv_primary_mask=np.array([]), - recv_counts=np.array([]), + recv_count=0, recv_expert_ids=np.array([]), recv_dst_rows=np.array([]), ), @@ -989,24 +989,22 @@ class EplbState: stream = torch.cuda.current_stream(device=device_index) stream.wait_event(model_state.buffer_ready_event) model_state.buffer_ready_event = None - weights_group = [ - model_state.model.expert_weights[model_state.layer_to_transfer] + expert_weights = model_state.model.expert_weights[ + model_state.layer_to_transfer ] - buffers_group = [model_state.expert_buffer] - new_indices_group = ( - model_state.new_physical_to_logical_map[ - model_state.layer_to_transfer : model_state.layer_to_transfer + 1 - ] + expert_weights_buffer = model_state.expert_buffer + new_indices = ( + model_state.new_physical_to_logical_map[model_state.layer_to_transfer] .cpu() .numpy() ) move_from_buffer( - weights_group=weights_group, - buffers_group=buffers_group, + expert_weights=expert_weights, + expert_weights_buffers=expert_weights_buffer, is_unchanged=model_state.is_unchanged, is_received_locally=model_state.is_received_locally, recv_metadata=model_state.recv_metadata, - new_indices_group=new_indices_group, + new_indices=new_indices, ep_group=ep_group, ) transferred_layer = model_state.layer_to_transfer diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 4a8590a027213..e37292362582d 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ 
b/vllm/distributed/eplb/rebalance_execute.py @@ -19,7 +19,6 @@ from torch.distributed import ( get_global_rank, ) -import vllm.envs as envs from vllm.logger import init_logger logger = init_logger(__name__) @@ -30,15 +29,13 @@ class RecvMetadata: """Metadata describing remote receives during EPLB rebalancing.""" recv_primary_mask: np.ndarray - """Mask of (layer_group_size, num_local_experts) - indicating primary experts received.""" - recv_counts: np.ndarray - """Number of received experts for each layer.""" + """Mask of (num_local_experts,) indicating primary experts received.""" + recv_count: int + """Number of received experts for the layer.""" recv_expert_ids: np.ndarray - """Expert ids (layer_group_size, num_local_experts) of remote primary experts.""" + """Expert ids (num_local_experts,) of remote primary experts.""" recv_dst_rows: np.ndarray - """Target expert indices (layer_group_size, num_local_experts) - in local tensors to send.""" + """Target expert indices (num_local_experts,) in local tensors to send.""" # Type alias for the result of move_to_buffer or transfer_layer @@ -154,146 +151,105 @@ def get_ep_ranks_with_experts_batch( def move_to_buffer( num_local_experts: int, - old_indices_group: np.ndarray, - new_indices_group: np.ndarray, - expert_weights_group: Sequence[Iterable[torch.Tensor]], - buffers_group: Sequence[Sequence[torch.Tensor]], + old_indices: np.ndarray, + new_indices: np.ndarray, + expert_weights: Iterable[torch.Tensor], + expert_weights_buffers: Sequence[torch.Tensor], cuda_stream: torch.cuda.Stream | None, ep_group: ProcessGroup, ) -> MoveToBufferResult: """ - Rearranges expert weights across a group of layers - during mixture-of-experts (MoE) expert parallel rebalancing. + Rearranges expert weights during EPLB rebalancing. Args: num_local_experts: Number of local experts. - old_indices_group: (num_layers, num_experts_total) ndarray of current - (old) global-to-local expert assignments. - new_indices_group: (num_layers, num_experts_total) ndarray of desired - (new) global-to-local assignments after rebalance. - expert_weights_group: Original expert weights for each layer. - buffers_group: List of per-layer intermediate buffers (one per tensor). + old_indices: (num_experts_total,) ndarray of current (old) + global-to-local expert assignments. + new_indices: (num_experts_total,) ndarray of desired (new) + global-to-local assignments after rebalance. + expert_weights: Original expert weights for the layer. + expert_weights_buffers: Intermediate buffers (one per tensor). cuda_stream: CUDA stream for async copies (can be None for sync mode). ep_group: Distributed process group for expert parallel comms. Returns: - is_unchanged (np.ndarray): (num_layers, num_local_experts), True where an - expert row is unchanged after rebalance. - is_received_locally (np.ndarray): (num_layers, num_local_experts), True - where a row can be updated from local data. + is_unchanged (np.ndarray): (num_local_experts,), True where an expert row + is unchanged after rebalance. + is_received_locally (np.ndarray): (num_local_experts,), True where a row + can be updated from local data. RecvMetadata: Metadata needed for completing remote weight transfers. 
""" - assert len(old_indices_group) == len(new_indices_group) == len(expert_weights_group) - group_size = len(old_indices_group) + assert old_indices.shape == new_indices.shape ep_rank = ep_group.rank() - # Pre-allocate per-layer compact maps/masks (numpy) - is_unchanged = np.zeros((group_size, num_local_experts), dtype=np.bool_) - is_received_locally = np.zeros((group_size, num_local_experts), dtype=np.bool_) - recv_primary_mask = np.zeros((group_size, num_local_experts), dtype=np.bool_) - # Cache desired new expert ids per local row, for all layers - new_local_expert_ids_mat = np.full( - (group_size, num_local_experts), -1, dtype=np.int64 - ) - send_counts = np.zeros(group_size, dtype=np.int32) - send_expert_ids = np.full((group_size, num_local_experts), -1, dtype=np.int64) - send_src_rows = np.full((group_size, num_local_experts), -1, dtype=np.int32) - recv_counts = np.zeros(group_size, dtype=np.int32) - recv_expert_ids = np.full((group_size, num_local_experts), -1, dtype=np.int64) - recv_dst_rows = np.full((group_size, num_local_experts), -1, dtype=np.int32) + recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_) + send_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64) + send_src_rows = np.full((num_local_experts,), -1, dtype=np.int32) + recv_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64) + recv_dst_rows = np.full((num_local_experts,), -1, dtype=np.int32) + base = ep_rank * num_local_experts local_rows = np.arange(num_local_experts, dtype=np.int32) local_global = base + local_rows - # Build masks and expert maps per layer - for layer_idx in range(group_size): - old_indices = old_indices_group[layer_idx] - layer_new_indices = new_indices_group[layer_idx] + old_local_expert_ids = old_indices[local_global] + new_local_expert_ids = new_indices[local_global] - old_local_expert_ids = old_indices[local_global] - new_local_expert_ids = layer_new_indices[local_global] - new_local_expert_ids_mat[layer_idx, :] = new_local_expert_ids + # Unchanged mask + is_unchanged = old_local_expert_ids == new_local_expert_ids - # Unchanged per-dst mask - unchanged_mask = old_local_expert_ids == new_local_expert_ids - is_unchanged[layer_idx, :] = unchanged_mask + # Local receive eligibility + new_valid = new_local_expert_ids != -1 + can_recv_local = np.isin( + new_local_expert_ids, old_local_expert_ids, assume_unique=False + ) + is_received_locally = np.logical_or( + is_unchanged, np.logical_and(new_valid, can_recv_local) + ) - # Local receive eligibility - new_valid = new_local_expert_ids != -1 - can_recv_local = np.isin( - new_local_expert_ids, old_local_expert_ids, assume_unique=False + # Send map: first src row per unique expert present locally in old mapping + send_count = 0 + valid_old = old_local_expert_ids != -1 + if np.any(valid_old): + uniq_experts, first_idx = np.unique( + old_local_expert_ids[valid_old], return_index=True ) - is_local_recv = np.logical_or( - unchanged_mask, np.logical_and(new_valid, can_recv_local) - ) - is_received_locally[layer_idx, :] = is_local_recv + filtered_rows = local_rows[valid_old] + src_rows = filtered_rows[first_idx] + send_count = int(uniq_experts.shape[0]) + send_expert_ids[:send_count] = uniq_experts + send_src_rows[:send_count] = src_rows - # Send map: first src row per unique expert present locally in old mapping - valid_old = old_local_expert_ids != -1 - if np.any(valid_old): - uniq_experts, first_idx = np.unique( - old_local_expert_ids[valid_old], return_index=True - ) - filtered_rows = local_rows[valid_old] - src_rows = 
filtered_rows[first_idx] - layer_send_count = int(uniq_experts.shape[0]) - send_counts[layer_idx] = layer_send_count - send_expert_ids[layer_idx, :layer_send_count] = uniq_experts - send_src_rows[layer_idx, :layer_send_count] = src_rows - else: - send_counts[layer_idx] = 0 + # Recv map: primary dst per unique expert needed remotely + recv_count = 0 + need_recv_mask = np.logical_and(~is_received_locally, new_valid) + if np.any(need_recv_mask): + desired_experts = new_local_expert_ids[need_recv_mask] + desired_dsts = local_rows[need_recv_mask] + uniq_recv_experts, uniq_indices = np.unique(desired_experts, return_index=True) + dst_rows = desired_dsts[uniq_indices] + recv_count = int(uniq_recv_experts.shape[0]) + recv_expert_ids[:recv_count] = uniq_recv_experts + recv_dst_rows[:recv_count] = dst_rows + recv_primary_mask[dst_rows] = True - # Recv map: primary dst per unique expert needed remotely - need_recv_mask = np.logical_and(~is_local_recv, new_valid) - if np.any(need_recv_mask): - desired_experts = new_local_expert_ids[need_recv_mask] - desired_dsts = local_rows[need_recv_mask] - uniq_recv_experts, uniq_indices = np.unique( - desired_experts, return_index=True - ) - dst_rows = desired_dsts[uniq_indices] - layer_recv_count = int(uniq_recv_experts.shape[0]) - recv_counts[layer_idx] = layer_recv_count - recv_expert_ids[layer_idx, :layer_recv_count] = uniq_recv_experts - recv_dst_rows[layer_idx, :layer_recv_count] = dst_rows - recv_primary_mask[layer_idx, dst_rows] = True - else: - recv_counts[layer_idx] = 0 - - # Precompute per-layer destination mask that actually needs local buffering: - # need change, received locally, and valid target expert id eligible_local_buffer_mask = np.logical_and( np.logical_and(~is_unchanged, is_received_locally), - new_local_expert_ids_mat != -1, + new_local_expert_ids != -1, ) # 1. Local moves into tmp buffers - for layer_idx in range(group_size): - layer_send_count = int(send_counts[layer_idx]) - if layer_send_count <= 0: - continue - - layer_send_experts = send_expert_ids[layer_idx, :layer_send_count] - layer_send_srcs = send_src_rows[layer_idx, :layer_send_count] - layer_weights_list = list(expert_weights_group[layer_idx]) - layer_buffers_list = list(buffers_group[layer_idx]) - new_local_expert_ids = new_local_expert_ids_mat[layer_idx, :] - - # Only consider destination rows that are eligible for local buffering - eligible_mask = eligible_local_buffer_mask[layer_idx, :] - if not bool(eligible_mask.any()): - continue - - dest_indices = np.nonzero(eligible_mask)[0].tolist() - # Build a map from expert_id to its source row. - expert_to_src_map = dict(zip(layer_send_experts, layer_send_srcs)) - + if bool(eligible_local_buffer_mask.any()) and send_count > 0: + dest_indices = np.nonzero(eligible_local_buffer_mask)[0].tolist() + expert_to_src_map = dict( + zip(send_expert_ids[:send_count], send_src_rows[:send_count]) + ) for dst in dest_indices: expert = new_local_expert_ids[dst] src_local = expert_to_src_map.get(expert, -1) if src_local != -1: - for w, b in zip(layer_weights_list, layer_buffers_list): - b[dst].copy_(w[src_local]) + for w, b in zip(expert_weights, expert_weights_buffers): + b[dst].copy_(w[src_local], non_blocking=True) p2p_ops: list[P2POp] = [] @@ -301,16 +257,10 @@ def move_to_buffer( ep_size = ep_group.size() rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)} - # 2. 
Post sends per layer - for layer_idx in range(group_size): - old_indices = old_indices_group[layer_idx] - layer_new_indices = new_indices_group[layer_idx] - layer_weights_list = list(expert_weights_group[layer_idx]) - layer_send_count = int(send_counts[layer_idx]) - if layer_send_count == 0: - continue - experts = send_expert_ids[layer_idx, :layer_send_count] - srcs = send_src_rows[layer_idx, :layer_send_count] + # 2. Post sends + if send_count > 0: + experts = send_expert_ids[:send_count] + srcs = send_src_rows[:send_count] order = np.argsort(experts, kind="stable") experts = experts[order] srcs = srcs[order] @@ -319,7 +269,7 @@ def move_to_buffer( experts, num_local_experts, old_indices, - layer_new_indices, + new_indices, ) for expert, src in zip(experts.tolist(), srcs.tolist()): @@ -344,29 +294,22 @@ def move_to_buffer( w[src], dst_global, ) - for w in layer_weights_list + for w in expert_weights ] - # 3. Post recvs per layer - for layer_idx in range(group_size): - old_indices = old_indices_group[layer_idx] - layer_new_indices = new_indices_group[layer_idx] - layer_buffers_list = list(buffers_group[layer_idx]) - layer_recv_count = int(recv_counts[layer_idx]) - if layer_recv_count == 0: - continue - experts = recv_expert_ids[layer_idx, :layer_recv_count] - dsts = recv_dst_rows[layer_idx, :layer_recv_count] + # 3. Post recvs + if recv_count > 0: + experts = recv_expert_ids[:recv_count] + dsts = recv_dst_rows[:recv_count] order = np.argsort(experts, kind="stable") experts = experts[order] dsts = dsts[order] - # Batch query all experts for this layer send_map, recv_map = get_ep_ranks_with_experts_batch( experts, num_local_experts, old_indices, - layer_new_indices, + new_indices, ) for expert, dst in zip(experts.tolist(), dsts.tolist()): @@ -388,7 +331,7 @@ def move_to_buffer( b[dst], src_global, ) - for b in layer_buffers_list + for b in expert_weights_buffers ] # 4. Execute the P2P operations. The real communication happens here. @@ -407,7 +350,7 @@ def move_to_buffer( is_received_locally, RecvMetadata( recv_primary_mask=recv_primary_mask, - recv_counts=recv_counts, + recv_count=recv_count, recv_expert_ids=recv_expert_ids, recv_dst_rows=recv_dst_rows, ), @@ -415,111 +358,82 @@ def move_to_buffer( def move_from_buffer( - weights_group: Sequence[Iterable[torch.Tensor]], - buffers_group: Sequence[Sequence[torch.Tensor]], + expert_weights: Iterable[torch.Tensor], + expert_weights_buffers: list[torch.Tensor], is_unchanged: np.ndarray, is_received_locally: np.ndarray, recv_metadata: RecvMetadata, - new_indices_group: np.ndarray, + new_indices: np.ndarray, ep_group: ProcessGroup, ) -> None: """ - Copies expert weights from communication buffers back to the target weight tensors, + Copies expert weights from communication buffers back to the target weight tensors after EPLB rebalancing. Args: - weights_group: Groups of consecutive MoE layers, each containing one or more - weight tensors. - buffers_group: Intermediate buffers matching weights_group.. - is_unchanged: (num_layers, num_local_experts), True - where an expert row is unchanged after rebalance. - is_received_locally: (num_layers, num_local_experts), True - where a row can be updated from local data. + expert_weights: List of weight tensors for a MoE layer. + expert_weights_buffers: Intermediate buffers matching ``expert_weights``. + is_unchanged: (num_local_experts,), True where an expert row is unchanged. + is_received_locally: (num_local_experts,), True where a row is updated locally. 
recv_metadata: RecvMetadata containing remote receive metadata. - new_indices_group: np.ndarray giving for each layer the mapping from local rows - to desired (possibly global) expert id, after rebalance. + new_indices: (num_experts_total,) mapping from local rows to desired + (possibly global) expert id, after rebalance. ep_group: torch.distributed.ProcessGroup for expert parallel communication domain. """ - assert ( - len(weights_group) - == len(buffers_group) - == len(is_unchanged) - == len(is_received_locally) - == len(recv_metadata.recv_primary_mask) - == len(new_indices_group) - ), "Unmatching layer group size" ep_rank = ep_group.rank() - group_size = len(is_unchanged) recv_primary_mask = recv_metadata.recv_primary_mask - recv_counts = recv_metadata.recv_counts + recv_count = recv_metadata.recv_count recv_expert_ids = recv_metadata.recv_expert_ids recv_dst_rows = recv_metadata.recv_dst_rows - num_local_experts = is_unchanged.shape[1] + num_local_experts = is_unchanged.shape[0] + # Mask for rows to copy back from buffers: # copy if locally received OR remote primary recv copy_mask = np.logical_or(is_received_locally, recv_primary_mask) - # Copy back local buffered rows into destination weights - for layer_idx in range(group_size): - layer_is_unchanged = is_unchanged[layer_idx, :] - layer_copy_mask = copy_mask[layer_idx, :] - weights_list = list(weights_group[layer_idx]) - buffers_list = list(buffers_group[layer_idx]) - # rows to copy = (~unchanged) & copy_mask - dest_mask_np = np.logical_and(~layer_is_unchanged, layer_copy_mask) - if not bool(dest_mask_np.any()): - continue + dest_mask_np = np.logical_and(~is_unchanged, copy_mask) + if bool(dest_mask_np.any()): dest_indices = np.nonzero(dest_mask_np)[0].tolist() for dst in dest_indices: - for w, b in zip(weights_list, buffers_list): + for w, b in zip(expert_weights, expert_weights_buffers): w[dst].copy_(b[dst]) # Duplicate remote received rows to non-primary duplicate dsts - for layer_idx in range(group_size): - layer_is_unchanged = is_unchanged[layer_idx, :] - layer_is_received_locally = is_received_locally[layer_idx, :] - new_indices = new_indices_group[layer_idx] - weights_list = list(weights_group[layer_idx]) + if recv_count == 0: + return - count_recv = int(recv_counts[layer_idx]) - if count_recv == 0: - # No remote primaries on this layer → no remote duplicates to materialize - continue - # Local view of desired expert ids per local row - base = ep_rank * num_local_experts - local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)] - # Duplicate rows mask: need remote, not primary, and valid expert id - duplicate_mask = np.logical_and( - np.logical_and(~layer_is_unchanged, ~layer_is_received_locally), - np.logical_and(~recv_primary_mask[layer_idx, :], local_experts != -1), - ) - if not bool(duplicate_mask.any()): - continue - dup_dst_rows = np.nonzero(duplicate_mask)[0] - dup_experts = local_experts[dup_dst_rows] + base = ep_rank * num_local_experts + local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)] + duplicate_mask = np.logical_and( + np.logical_and(~is_unchanged, ~is_received_locally), + np.logical_and(~recv_primary_mask, local_experts != -1), + ) + if not bool(duplicate_mask.any()): + return - # Build primary mapping arrays (expert -> primary dst) and vector-match - prim_experts = recv_expert_ids[layer_idx, :count_recv] - prim_dsts = recv_dst_rows[layer_idx, :count_recv] - order = np.argsort(prim_experts, kind="stable") - prim_experts_sorted = prim_experts[order] - 
prim_dsts_sorted = prim_dsts[order] - pos = np.searchsorted(prim_experts_sorted, dup_experts) - # Filter to experts that have a matching primary entry - valid = np.logical_and( - pos < prim_experts_sorted.shape[0], - prim_experts_sorted[np.minimum(pos, prim_experts_sorted.shape[0] - 1)] - == dup_experts, - ) - if not bool(valid.any()): - continue - matched_dst_rows = dup_dst_rows[valid] - matched_src_rows = prim_dsts_sorted[pos[valid]] + dup_dst_rows = np.nonzero(duplicate_mask)[0] + dup_experts = local_experts[dup_dst_rows] - # Perform row copies per (dst, src) pair without tensor indexing - for dst, src in zip(matched_dst_rows.tolist(), matched_src_rows.tolist()): - for w in weights_list: - w[dst].copy_(w[src]) + prim_experts = recv_expert_ids[:recv_count] + prim_dsts = recv_dst_rows[:recv_count] + order = np.argsort(prim_experts, kind="stable") + prim_experts_sorted = prim_experts[order] + prim_dsts_sorted = prim_dsts[order] + pos = np.searchsorted(prim_experts_sorted, dup_experts) + valid = np.logical_and( + pos < prim_experts_sorted.shape[0], + prim_experts_sorted[np.minimum(pos, prim_experts_sorted.shape[0] - 1)] + == dup_experts, + ) + if not bool(valid.any()): + return + + matched_dst_rows = dup_dst_rows[valid] + matched_src_rows = prim_dsts_sorted[pos[valid]] + + for dst, src in zip(matched_dst_rows.tolist(), matched_src_rows.tolist()): + for w in expert_weights: + w[dst].copy_(w[src]) async def transfer_layer( @@ -555,7 +469,7 @@ async def transfer_layer( is_unchanged (np.ndarray): (1, num_local_experts), True where expert is left unchanged. is_received_locally (np.ndarray): (1, num_local_experts), True where expert - is not copied locally. + can be received locally. RecvMetadata: Metadata needed for completing remote weight transfers. """ ep_size = ep_group.size() @@ -586,10 +500,10 @@ async def transfer_layer( is_unchanged, is_received_locally, recv_metadata = move_to_buffer( num_local_experts=num_local_physical_experts, - old_indices_group=old_global_expert_indices_np[layer : layer + 1], - new_indices_group=new_global_expert_indices_np[layer : layer + 1], - expert_weights_group=[expert_weights[layer]], - buffers_group=[expert_weights_buffer], + old_indices=old_global_expert_indices_np[layer], + new_indices=new_global_expert_indices_np[layer], + expert_weights=expert_weights[layer], + expert_weights_buffers=expert_weights_buffer, cuda_stream=cuda_stream, ep_group=ep_group, ) @@ -649,33 +563,24 @@ def rearrange_expert_weights_inplace( ep_size = ep_group.size() assert num_physical_experts == ep_size * num_local_physical_experts - # Max number of layers to group for communication - max_group_layers = envs.VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS - max_group_layers = max(min(max_group_layers, num_moe_layers), 1) - first_layer_weights = list(expert_weights[0]) # Buffers to hold the expert weights during the exchange. # NOTE: Currently we assume the same weights across different layers # have the same shape. 
- weights_buffers: list[list[torch.Tensor]] = [ - [torch.empty_like(w) for w in first_layer_weights] - for _ in range(max_group_layers) + weights_buffer: list[torch.Tensor] = [ + torch.empty_like(w) for w in first_layer_weights ] if is_profile: # Reserve communication buffers via a minimal dummy all_gather on first layer - for layer_idx in range(max_group_layers): - for weight, buffer in zip(expert_weights[0], weights_buffers[layer_idx]): - dummy_recv_buffer = [buffer for _ in range(ep_size)] - torch.distributed.barrier() - all_gather( - dummy_recv_buffer, - weight, - group=ep_group, - ) + for weight, buffer in zip(expert_weights[0], weights_buffer): + dummy_recv_buffer = [buffer for _ in range(ep_size)] + torch.distributed.barrier() + all_gather( + dummy_recv_buffer, + weight, + group=ep_group, + ) return - logger.info_once( - f"EPLB Sync: rearrange max_group_layers: {max_group_layers}", scope="global" - ) # NOTE(bowen): We need this synchronize to run, but I don't know why. # If you figure out the reason, please let me know -- thank you! @@ -684,34 +589,26 @@ def rearrange_expert_weights_inplace( old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy() new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy() - start = 0 - while start < num_moe_layers: - end = min(start + max_group_layers, num_moe_layers) - old_group = old_global_expert_indices_cpu[start:end] - new_group = new_global_expert_indices_cpu[start:end] - weights_group = [expert_weights[i] for i in range(start, end)] - buffers_group = weights_buffers[: (end - start)] - + for layer_idx in range(num_moe_layers): is_unchanged, is_received_locally, recv_metadata = move_to_buffer( num_local_experts=num_local_physical_experts, - old_indices_group=old_group, - new_indices_group=new_group, - expert_weights_group=weights_group, - buffers_group=buffers_group, + old_indices=old_global_expert_indices_cpu[layer_idx], + new_indices=new_global_expert_indices_cpu[layer_idx], + expert_weights=expert_weights[layer_idx], + expert_weights_buffers=weights_buffer, cuda_stream=None, ep_group=ep_group, ) move_from_buffer( - weights_group=weights_group, - buffers_group=buffers_group, + expert_weights=expert_weights[layer_idx], + expert_weights_buffers=weights_buffer, is_unchanged=is_unchanged, is_received_locally=is_received_locally, recv_metadata=recv_metadata, - new_indices_group=new_group, + new_indices=new_global_expert_indices_cpu[layer_idx], ep_group=ep_group, ) - start = end def _map_old_expert_indices_with_rank_mapping( diff --git a/vllm/envs.py b/vllm/envs.py index 86d3925710227..8246109eb73af 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -244,7 +244,6 @@ if TYPE_CHECKING: VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_USE_V2_MODEL_RUNNER: bool = False - VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS: int = 1 def get_default_cache_root(): @@ -1564,10 +1563,6 @@ environment_variables: dict[str, Callable[[], Any]] = { "VLLM_USE_V2_MODEL_RUNNER": lambda: bool( int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) ), - # Max number of layers to group in synchronous EPLB weight communication. 
- "VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS": lambda: int( - os.getenv("VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS", "1") - ), } # --8<-- [end:env-vars-definition] From ab0ca861da579ad1c5b3d2562162a12a6e9af73a Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 11 Dec 2025 11:35:10 +0000 Subject: [PATCH 12/30] Futher optimize rearrange Signed-off-by: ilmarkov --- vllm/distributed/eplb/rebalance_execute.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index e37292362582d..4c0528bbd22bc 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -233,10 +233,7 @@ def move_to_buffer( recv_dst_rows[:recv_count] = dst_rows recv_primary_mask[dst_rows] = True - eligible_local_buffer_mask = np.logical_and( - np.logical_and(~is_unchanged, is_received_locally), - new_local_expert_ids != -1, - ) + eligible_local_buffer_mask = np.logical_and(~is_unchanged, is_received_locally) # 1. Local moves into tmp buffers if bool(eligible_local_buffer_mask.any()) and send_count > 0: @@ -396,7 +393,7 @@ def move_from_buffer( dest_indices = np.nonzero(dest_mask_np)[0].tolist() for dst in dest_indices: for w, b in zip(expert_weights, expert_weights_buffers): - w[dst].copy_(b[dst]) + w[dst].copy_(b[dst], non_blocking=True) # Duplicate remote received rows to non-primary duplicate dsts if recv_count == 0: @@ -433,7 +430,7 @@ def move_from_buffer( for dst, src in zip(matched_dst_rows.tolist(), matched_src_rows.tolist()): for w in expert_weights: - w[dst].copy_(w[src]) + w[dst].copy_(w[src], non_blocking=True) async def transfer_layer( From 208b51bcd8ae2325da625b99a28d122381e54d72 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 11 Dec 2025 11:59:20 +0000 Subject: [PATCH 13/30] Upd Signed-off-by: ilmarkov --- vllm/distributed/eplb/rebalance_execute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 021f084937e10..641fad9f4e788 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -699,4 +699,4 @@ def _map_new_expert_indices_with_rank_mapping( return mapped_expert_indices -__all__ = ["transfer_layer", "move_from_buffer", "RecvMetadata", "MoveToBufferResult"] +__all__ = ["transfer_layer", "move_from_buffer", "RecvMetadata"] From a5ecdc18c017d50fbfcf91438e12369e457509b9 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 11 Dec 2025 12:56:54 +0000 Subject: [PATCH 14/30] Add comments Signed-off-by: ilmarkov --- vllm/distributed/eplb/policy/default.py | 29 +++++++++++++++---------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/vllm/distributed/eplb/policy/default.py b/vllm/distributed/eplb/policy/default.py index 970a1614933ee..172f6f044f89e 100644 --- a/vllm/distributed/eplb/policy/default.py +++ b/vllm/distributed/eplb/policy/default.py @@ -263,31 +263,36 @@ class DefaultEplbPolicy(AbstractEplbPolicy): has_any = matches.any(axis=1) if np.any(has_any): first_idx = np.argmax(matches, axis=1) - rows = np.nonzero(has_any)[0] - cols = first_idx[rows] - post_phy2log_np[rows, start + pos] = new_seg[rows, cols] - post_phyrank_np[rows, start + pos] = new_rnk[rows, cols] - used_new_indices[rows, cols] = True - preserved_positions[rows, pos] = True + layer_indices = np.nonzero(has_any)[0] + matched_new_positions = first_idx[layer_indices] + post_phy2log_np[layer_indices, start + pos] = new_seg[ + 
layer_indices, matched_new_positions + ] + post_phyrank_np[layer_indices, start + pos] = new_rnk[ + layer_indices, matched_new_positions + ] + used_new_indices[layer_indices, matched_new_positions] = True + preserved_positions[layer_indices, pos] = True # Second pass: fill remaining slots with remaining new experts remaining_mask = ~used_new_indices # [L, S] fill_mask = ~preserved_positions # [L, S] if remaining_mask.any() and fill_mask.any(): - idx_base = np.broadcast_to( - np.arange(slots_per_gpu), (num_layers, slots_per_gpu) - ) + idx_base = np.tile(np.arange(slots_per_gpu), (num_layers, 1)) + # Sentinel value for unavailable positions. large = slots_per_gpu + 1 + # Priorities: keep original index for available spots, set sentinel + # for unavailable; lower is earlier. remaining_priority = np.where(remaining_mask, idx_base, large) fill_priority = np.where(fill_mask, idx_base, large) - # Sort to get per-row ordered indices of True positions + # Sort to get ordered indices of available src/dst positions per layer. remaining_indices = np.argsort(remaining_priority, axis=1) fill_indices = np.argsort(fill_priority, axis=1) - # How many to fill per row + # Fill count per layer (cannot exceed either side). remaining_counts = remaining_mask.sum(axis=1) fill_counts = fill_mask.sum(axis=1) take_counts = np.minimum(remaining_counts, fill_counts) - # Assign per row + # Assign remaining new experts to remaining slots per layer. for layer_idx in range(num_layers): k = int(take_counts[layer_idx]) if k <= 0: From 040ae89c5e9d970224568a005e2b6ea0a9bd59ab Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 12 Dec 2025 14:45:00 +0000 Subject: [PATCH 15/30] Address review comments Signed-off-by: ilmarkov --- vllm/distributed/eplb/policy/default.py | 90 +++++++++++----------- vllm/distributed/eplb/rebalance_execute.py | 16 ++-- 2 files changed, 51 insertions(+), 55 deletions(-) diff --git a/vllm/distributed/eplb/policy/default.py b/vllm/distributed/eplb/policy/default.py index 172f6f044f89e..d4f3b65cdc0c5 100644 --- a/vllm/distributed/eplb/policy/default.py +++ b/vllm/distributed/eplb/policy/default.py @@ -93,7 +93,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): Returns: phy2log: [X, num_phy], logical expert id of each physical expert - rank: [X, num_phy], the replica rank + replica_idx: [X, num_phy], the index of the replica for each logical expert logcnt: [X, num_log], number of replicas for each logical expert """ n, num_log = weight.shape @@ -101,15 +101,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy): assert num_redundant >= 0 device = weight.device phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) - rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) + replica_idx = torch.zeros(n, num_phy, dtype=torch.int64, device=device) logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) arangen = torch.arange(n, dtype=torch.int64, device=device) for i in range(num_log, num_phy): redundant_indices = (weight / logcnt).max(dim=-1).indices phy2log[:, i] = redundant_indices - rank[:, i] = logcnt[arangen, redundant_indices] + replica_idx[:, i] = logcnt[arangen, redundant_indices] logcnt[arangen, redundant_indices] += 1 - return phy2log, rank, logcnt + return phy2log, replica_idx, logcnt @classmethod def rebalance_experts_hierarchical( @@ -132,7 +132,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): Returns: phy2log: [layers, num_replicas], the expert index of each replica - log2phy: [layers, num_logical_experts, X], + pphy_replicas_idx: [layers, 
num_logical_experts, X], the replica indices for each expert logcnt: [layers, num_logical_experts], number of physical replicas for each logical expert @@ -177,7 +177,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): tokens_per_mlog = weight.gather(-1, mlog2log).view( -1, num_logical_experts // num_nodes ) - phy2mlog, phyrank, mlogcnt = cls.replicate_experts( + phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts( tokens_per_mlog, num_physical_experts // num_nodes ) @@ -203,15 +203,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ).view(1, -1, 1) ).flatten(-2) pphy2log = mlog2log.gather(-1, pphy2mlog) - pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) + pphy_replicas_idx = replicas_idx.gather(-1, pphy2phy).view(num_layers, -1) logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) - return pphy2log, pphyrank, logcnt + return pphy2log, pphy_replicas_idx, logcnt @classmethod def preserve_intragpu_slots( cls, phy2log: torch.Tensor, - phyrank: torch.Tensor, + phy_replicas_idx: torch.Tensor, num_ranks: int, old_global_expert_indices: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -223,56 +223,52 @@ class DefaultEplbPolicy(AbstractEplbPolicy): the old and new mappings. """ device = phy2log.device - new_num_phy = phy2log.shape[1] - old_num_phy = old_global_expert_indices.shape[1] - if ( - num_ranks <= 0 - or new_num_phy % num_ranks != 0 - or old_num_phy % num_ranks != 0 - or (new_num_phy // num_ranks) != (old_num_phy // num_ranks) - ): - return phy2log, phyrank + num_phy_experts = phy2log.shape[1] + if num_ranks <= 0 or num_phy_experts % num_ranks != 0: + return phy2log, phy_replicas_idx # Move to CPU and convert to NumPy for processing - phy2log_np = phy2log.cpu().numpy() - phyrank_np = phyrank.cpu().numpy() - old_np = old_global_expert_indices.cpu().numpy() + new_phy2log_np = phy2log.cpu().numpy() + replicas_idx_np = phy_replicas_idx.cpu().numpy() + old_phy2log_np = old_global_expert_indices.cpu().numpy() - slots_per_gpu = new_num_phy // num_ranks - num_layers = phy2log_np.shape[0] + slots_per_gpu = num_phy_experts // num_ranks + num_layers = new_phy2log_np.shape[0] - post_phy2log_np = phy2log_np.copy() - post_phyrank_np = phyrank_np.copy() + post_phy2log_np = new_phy2log_np.copy() + post_phy_replicas_idx_np = replicas_idx_np.copy() for gpu_idx in range(num_ranks): start = gpu_idx * slots_per_gpu end = start + slots_per_gpu - # Segments across all layers for this GPU - old_seg = old_np[:, start:end] # [L, S] - new_seg = phy2log_np[:, start:end] # [L, S] - new_rnk = phyrank_np[:, start:end] # [L, S] + # Experts across all layers for this GPU + old_local = old_phy2log_np[:, start:end] # [layers, slots] + new_local = new_phy2log_np[:, start:end] # [layers, slots] + new_ridx = replicas_idx_np[:, start:end] # [layers, slots] used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool) preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool) # First pass: preserve same-logical experts in their previous slots - for pos in range(slots_per_gpu): - # matches: [L, S], True where new_seg has the same logical value - # as the old slot 'pos' and not used - matches = (new_seg == old_seg[:, pos][:, None]) & (~used_new_indices) + for slot_idx in range(slots_per_gpu): + # matches: [layers, slots], True where new local experts have + # the same logical value as the old from 'slot_idx' and not checked yet + matches = (new_local == old_local[:, slot_idx][:, None]) & ( + ~used_new_indices + ) has_any = matches.any(axis=1) if np.any(has_any): first_idx = 
np.argmax(matches, axis=1) layer_indices = np.nonzero(has_any)[0] matched_new_positions = first_idx[layer_indices] - post_phy2log_np[layer_indices, start + pos] = new_seg[ - layer_indices, matched_new_positions - ] - post_phyrank_np[layer_indices, start + pos] = new_rnk[ + post_phy2log_np[layer_indices, start + slot_idx] = new_local[ layer_indices, matched_new_positions ] + post_phy_replicas_idx_np[layer_indices, start + slot_idx] = ( + new_ridx[layer_indices, matched_new_positions] + ) used_new_indices[layer_indices, matched_new_positions] = True - preserved_positions[layer_indices, pos] = True + preserved_positions[layer_indices, slot_idx] = True # Second pass: fill remaining slots with remaining new experts remaining_mask = ~used_new_indices # [L, S] @@ -299,17 +295,17 @@ class DefaultEplbPolicy(AbstractEplbPolicy): continue src_pos = remaining_indices[layer_idx, :k] dst_pos = fill_indices[layer_idx, :k] - post_phy2log_np[layer_idx, start + dst_pos] = new_seg[ + post_phy2log_np[layer_idx, start + dst_pos] = new_local[ layer_idx, src_pos ] - post_phyrank_np[layer_idx, start + dst_pos] = new_rnk[ + post_phy_replicas_idx_np[layer_idx, start + dst_pos] = new_ridx[ layer_idx, src_pos ] # Convert back to torch and move to original device post_phy2log = torch.from_numpy(post_phy2log_np).to(device) - post_phyrank = torch.from_numpy(post_phyrank_np).to(device) - return post_phy2log, post_phyrank + post_phy_replicas_idx = torch.from_numpy(post_phy_replicas_idx_np).to(device) + return post_phy2log, post_phy_replicas_idx @classmethod def rebalance_experts( @@ -348,12 +344,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy): weight = weight.float() if num_groups % num_nodes == 0: # use hierarchical load-balance policy - phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical( + phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical( weight, num_replicas, num_groups, num_nodes, num_ranks ) else: # use global load-balance policy - phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical( + phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical( weight, num_replicas, 1, 1, num_ranks ) # Optional postprocessing to preserve slots for experts moving @@ -362,8 +358,8 @@ class DefaultEplbPolicy(AbstractEplbPolicy): # Helps to avoid unnecessary weight copying when experts move # within the same GPU. if old_global_expert_indices is not None: - phy2log, phyrank = cls.preserve_intragpu_slots( - phy2log, phyrank, num_ranks, old_global_expert_indices + phy2log, phy_replicas_idx = cls.preserve_intragpu_slots( + phy2log, phy_replicas_idx, num_ranks, old_global_expert_indices ) num_redundant_experts = num_replicas - num_logical_experts maxlogcnt = num_redundant_experts + 1 @@ -375,7 +371,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ) log2phy.view(num_layers, -1).scatter_( -1, - phy2log * maxlogcnt + phyrank, + phy2log * maxlogcnt + phy_replicas_idx, torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand( num_layers, -1 ), diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 641fad9f4e788..717d9f6d88793 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -361,24 +361,23 @@ def move_from_buffer( is_received_locally: np.ndarray, recv_metadata: RecvMetadata, new_indices: np.ndarray, - ep_group: ProcessGroup, + ep_rank: int, ) -> None: """ Copies expert weights from communication buffers back to the target weight tensors after EPLB rebalancing. 
     Args:
-        expert_weights: List of weight tensors for a MoE layer.
-        expert_weights_buffers: Intermediate buffers matching ``expert_weights``.
+        expert_weights: List of the actual MoE layer weights used during execution.
+        expert_weights_buffers: Intermediate buffers containing the expert weights
+            after the transfer is completed.
         is_unchanged: (num_local_experts,), True where an expert row is unchanged.
         is_received_locally: (num_local_experts,), True where a row is updated locally.
         recv_metadata: RecvMetadata containing remote receive metadata.
         new_indices: (num_experts_total,) mapping from local rows to desired
            (possibly global) expert id, after rebalance.
-        ep_group: torch.distributed.ProcessGroup for expert parallel communication
-            domain.
+        ep_rank: Rank of the process in the expert parallel group.
     """
-    ep_rank = ep_group.rank()
     recv_primary_mask = recv_metadata.recv_primary_mask
     recv_count = recv_metadata.recv_count
     recv_expert_ids = recv_metadata.recv_expert_ids
@@ -395,16 +394,17 @@ def move_from_buffer(
         dest_indices = np.nonzero(dest_mask_np)[0].tolist()
         for dst in dest_indices:
             for w, b in zip(expert_weights, expert_weights_buffers):
                 w[dst].copy_(b[dst], non_blocking=True)
-    # Duplicate remote received rows to non-primary duplicate dsts
     if recv_count == 0:
         return
 
+    # Duplicate remote received rows to non-primary duplicate dsts
     base = ep_rank * num_local_experts
     local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)]
     duplicate_mask = np.logical_and(
         np.logical_and(~is_unchanged, ~is_received_locally),
         np.logical_and(~recv_primary_mask, local_experts != -1),
     )
+    # All received experts are unique in the destination, so no need to copy duplicates
     if not bool(duplicate_mask.any()):
         return
@@ -607,7 +607,7 @@ def rearrange_expert_weights_inplace(
         is_received_locally=is_received_locally,
         recv_metadata=recv_metadata,
         new_indices=new_global_expert_indices_cpu[layer_idx],
-        ep_group=ep_group,
+        ep_rank=ep_group.rank(),
     )

From 9a41b9130e37c362217a3bcf6e04709fd7b6eb63 Mon Sep 17 00:00:00 2001
From: ilmarkov
Date: Fri, 12 Dec 2025 14:52:49 +0000
Subject: [PATCH 16/30] Fix precommit

Signed-off-by: ilmarkov
---
 vllm/distributed/eplb/eplb_state.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index 111479f2ee2f4..961544290fed5 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -1005,7 +1005,7 @@ class EplbState:
                 is_received_locally=model_state.is_received_locally,
                 recv_metadata=model_state.recv_metadata,
                 new_indices=new_indices,
-                ep_group=ep_group,
+                ep_rank=ep_group.rank(),
             )
             transferred_layer = model_state.layer_to_transfer
             self._update_layer_mapping_from_new(model_state, transferred_layer)

From 7d0ab7d4b91292003da743afd4a19e78c14f71c7 Mon Sep 17 00:00:00 2001
From: ilmarkov
Date: Mon, 15 Dec 2025 10:46:20 +0000
Subject: [PATCH 17/30] Refactor tests and address nits

Signed-off-by: ilmarkov
---
 tests/distributed/test_eplb_algo.py | 131 +++++++++++++-----------
 vllm/distributed/eplb/policy/default.py | 4 +-
 2 files changed, 71 insertions(+), 64 deletions(-)

diff --git a/tests/distributed/test_eplb_algo.py b/tests/distributed/test_eplb_algo.py
index 9356ef1533488..c14cf5efed226 100644
--- a/tests/distributed/test_eplb_algo.py
+++ b/tests/distributed/test_eplb_algo.py
@@ -312,8 +312,8 @@ if __name__ == "__main__":
     test_basic_rebalance()
 
 
-def _make_phyrank_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor:
-    """Create phyrank from phy2log"""
+def _make_phy_replicas_idx_from_phy2log(phy2log: torch.Tensor) ->
torch.Tensor: + """Create replicas indices mapping from phy2log""" pr = torch.zeros_like(phy2log) for layer in range(phy2log.shape[0]): seen: dict[int, int] = {} @@ -328,9 +328,9 @@ def _make_phyrank_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor: def _validate_intragpu_rearrangement( old_global_expert_indices: torch.Tensor, new_phy2log: torch.Tensor, - new_phyrank: torch.Tensor, + new_phy_replicas_idx: torch.Tensor, post_phy2log: torch.Tensor, - post_phyrank: torch.Tensor, + post_phy_replicas_idx: torch.Tensor, num_ranks: int, slots_per_gpu: int, ): @@ -340,9 +340,9 @@ def _validate_intragpu_rearrangement( end = start + slots_per_gpu old_seg = old_global_expert_indices[0, start:end] new_seg = new_phy2log[0, start:end] - new_rnk = new_phyrank[0, start:end] + new_rnk = new_phy_replicas_idx[0, start:end] post_seg = post_phy2log[0, start:end] - post_rnk = post_phyrank[0, start:end] + post_rnk = post_phy_replicas_idx[0, start:end] # Pairwise equality for (expert, rank) pairs to ensure nothing is lost def sorted_pairs(seg: torch.Tensor, rnk: torch.Tensor): @@ -376,70 +376,77 @@ def _validate_intragpu_rearrangement( ) -def test_preserve_intragpu_slots_simple(): +@pytest.mark.parametrize( + "num_ranks, slots_per_gpu, old_phy2log, new_phy2log", + [ + pytest.param( + # Setup: 2 GPUs, 4 slots each, 1 layer + # Old mapping: GPU0 -> [0,1,2,3], GPU1 -> [4,5,6,7] + # New mapping shuffles within GPU0 and brings 4,5 into GPU0. + # GPU0 new -> [1,5,0,4]; GPU1 new -> [6,2,7,3] + 2, + 4, + torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]), + torch.tensor([[1, 5, 0, 4, 6, 2, 7, 3]]), + id="simple", + ), + pytest.param( + # Setup: 2 GPUs, 5 slots each (total 10 physical experts), 1 layer + # Old mapping: + # GPU0 -> [0, 1, 0, 2, 3] (expert 0 duplicated) + # GPU1 -> [4, 5, 6, 1, 2] + # New mapping reorders within GPUs and moves some experts across GPUs, + # while still including duplicates: + # GPU0 new -> [0, 5, 4, 0, 1] (expert 0 duplicated, 4/5 incoming) + # GPU1 new -> [6, 2, 3, 2, 1] (expert 2 duplicated) + 2, + 5, + torch.tensor([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]]), + torch.tensor([[0, 5, 4, 0, 1, 6, 2, 3, 2, 1]]), + id="duplicates", + ), + pytest.param( + # Setup: 3 GPUs, 4 slots each (total 12 physical experts), 1 layer + # Old mapping: + # GPU0 -> [0, 1, 2, 3] + # GPU1 -> [0, 1, 2, 3] + # GPU2 -> [0, 1, 2, 3] + # New mapping decides to use one expert on 2 GPUs and shuffles + # experts on the third GPU, + # GPU0 new -> [0, 0, 0, 0] + # GPU1 new -> [0, 0, 0, 0] + # GPU2 new -> [1, 2, 3, 0] + 3, + 4, + torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]), + torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0]]), + id="skewed_expert", + ), + ], +) +def test_preserve_intragpu_slots( + num_ranks: int, + slots_per_gpu: int, + old_phy2log: torch.Tensor, + new_phy2log: torch.Tensor, +): """Experts that stay on a GPU keep their old slots; incoming not lost.""" - # Setup: 2 GPUs, 4 slots each, 1 layer - num_ranks = 2 - slots_per_gpu = 4 - # Old mapping: GPU0 -> [0,1,2,3], GPU1 -> [4,5,6,7] - old_global_expert_indices = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) - # New mapping shuffles within GPU0 and brings 4,5 into GPU0. 
- # GPU0 new -> [1,5,0,4] (0 and 1 remain on GPU0 but at different slots) - # GPU1 new -> [6,2,7,3] (6 and 7 remain on GPU1, 2 and 3 move in) - phy2log = torch.tensor([[1, 5, 0, 4, 6, 2, 7, 3]]) - # Derive phyrank from replica occurrence order per expert - phyrank = _make_phyrank_from_phy2log(phy2log) + phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(new_phy2log) - post_phy2log, post_phyrank = DefaultEplbPolicy.preserve_intragpu_slots( - phy2log, phyrank, num_ranks, old_global_expert_indices + post_phy2log, post_phy_replicas_idx = DefaultEplbPolicy.preserve_intragpu_slots( + new_phy2log, phy_replicas_idx, num_ranks, old_phy2log ) # Shapes preserved - assert post_phy2log.shape == phy2log.shape - assert post_phyrank.shape == phyrank.shape + assert post_phy2log.shape == new_phy2log.shape + assert post_phy_replicas_idx.shape == phy_replicas_idx.shape _validate_intragpu_rearrangement( - old_global_expert_indices, - phy2log, - phyrank, + old_phy2log, + new_phy2log, + phy_replicas_idx, post_phy2log, - post_phyrank, - num_ranks, - slots_per_gpu, - ) - - -def test_preserve_intragpu_slots_with_duplicates(): - """Test preserve intragpu slots with duplicates""" - # Setup: 2 GPUs, 5 slots each (total 10 physical experts), 1 layer - num_ranks = 2 - slots_per_gpu = 5 - # Old mapping: - # GPU0 -> [0, 1, 0, 2, 3] (expert 0 duplicated) - # GPU1 -> [4, 5, 6, 1, 2] - old_global_expert_indices = torch.tensor([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]]) - # New mapping reorders within GPUs and moves some experts across GPUs, - # while still including duplicates: - # GPU0 new -> [0, 5, 4, 0, 1] (expert 0 duplicated, 4/5 incoming) - # GPU1 new -> [6, 2, 3, 1, 2] (expert 2 duplicated) - phy2log = torch.tensor([[0, 5, 4, 0, 1, 6, 2, 3, 1, 2]]) - # Derive ranks so duplicates have ranks [0,1,...] by occurrence - phyrank = _make_phyrank_from_phy2log(phy2log) - - post_phy2log, post_phyrank = DefaultEplbPolicy.preserve_intragpu_slots( - phy2log, phyrank, num_ranks, old_global_expert_indices - ) - - # Shapes preserved - assert post_phy2log.shape == phy2log.shape - assert post_phyrank.shape == phyrank.shape - - _validate_intragpu_rearrangement( - old_global_expert_indices, - phy2log, - phyrank, - post_phy2log, - post_phyrank, + post_phy_replicas_idx, num_ranks, slots_per_gpu, ) diff --git a/vllm/distributed/eplb/policy/default.py b/vllm/distributed/eplb/policy/default.py index d4f3b65cdc0c5..ebbbc6db1d9f9 100644 --- a/vllm/distributed/eplb/policy/default.py +++ b/vllm/distributed/eplb/policy/default.py @@ -271,8 +271,8 @@ class DefaultEplbPolicy(AbstractEplbPolicy): preserved_positions[layer_indices, slot_idx] = True # Second pass: fill remaining slots with remaining new experts - remaining_mask = ~used_new_indices # [L, S] - fill_mask = ~preserved_positions # [L, S] + remaining_mask = ~used_new_indices # [layers, slots] + fill_mask = ~preserved_positions # [layers, slots] if remaining_mask.any() and fill_mask.any(): idx_base = np.tile(np.arange(slots_per_gpu), (num_layers, 1)) # Sentinel value for unavailable positions. 
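
The parametrized cases above pin down the slot-preservation contract: experts that
stay on a GPU keep their old slots, and incoming experts fill whatever is left. The
idea is easy to see in a standalone sketch (illustrative only: one layer, one GPU
segment, and a hypothetical helper name; the patched implementation is vectorized
across layers and GPU segments):

import numpy as np

def preserve_slots_one_gpu(old_seg: np.ndarray, new_seg: np.ndarray) -> np.ndarray:
    """Two-pass reorder for one GPU segment: experts present in both mappings
    keep their old slot; the free slots then receive the incoming experts."""
    post = np.full(old_seg.shape, -1, dtype=np.int64)
    used = np.zeros(new_seg.shape, dtype=bool)  # consumed entries of new_seg
    # First pass: an expert that stays on this GPU keeps its old slot.
    for slot, expert in enumerate(old_seg):
        candidates = np.nonzero((new_seg == expert) & ~used)[0]
        if candidates.size:
            post[slot] = expert
            used[candidates[0]] = True
    # Second pass: place the still-unmatched new experts into the free slots.
    free = np.nonzero(post == -1)[0]
    post[free] = new_seg[~used]
    return post

# Mirrors the "simple" case: GPU0 goes from [0, 1, 2, 3] to [1, 5, 0, 4].
print(preserve_slots_one_gpu(np.array([0, 1, 2, 3]), np.array([1, 5, 0, 4])))
# -> [0 1 5 4]: experts 0 and 1 keep their slots, 5 and 4 fill the free ones.
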
From 389f86e0c58e1f60a608a241820cbb2302eac45e Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 11 Dec 2025 16:41:33 +0000 Subject: [PATCH 18/30] Convert eplb rebalance to numpy Signed-off-by: ilmarkov --- tests/distributed/test_eplb_algo.py | 31 ++-- vllm/distributed/eplb/policy/default.py | 197 ++++++++++++------------ 2 files changed, 113 insertions(+), 115 deletions(-) diff --git a/tests/distributed/test_eplb_algo.py b/tests/distributed/test_eplb_algo.py index c14cf5efed226..6fe44fc218016 100644 --- a/tests/distributed/test_eplb_algo.py +++ b/tests/distributed/test_eplb_algo.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import numpy as np import pytest import torch @@ -312,9 +313,9 @@ if __name__ == "__main__": test_basic_rebalance() -def _make_phy_replicas_idx_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor: - """Create replicas indices mapping from phy2log""" - pr = torch.zeros_like(phy2log) +def _make_phy_replicas_idx_from_phy2log(phy2log: np.ndarray) -> np.ndarray: + """Create replicas indices mapping from phy2log.""" + pr = np.zeros_like(phy2log, dtype=np.int64) for layer in range(phy2log.shape[0]): seen: dict[int, int] = {} row = phy2log[layer].tolist() @@ -326,11 +327,11 @@ def _make_phy_replicas_idx_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor: def _validate_intragpu_rearrangement( - old_global_expert_indices: torch.Tensor, - new_phy2log: torch.Tensor, - new_phy_replicas_idx: torch.Tensor, - post_phy2log: torch.Tensor, - post_phy_replicas_idx: torch.Tensor, + old_global_expert_indices: np.ndarray, + new_phy2log: np.ndarray, + new_phy_replicas_idx: np.ndarray, + post_phy2log: np.ndarray, + post_phy_replicas_idx: np.ndarray, num_ranks: int, slots_per_gpu: int, ): @@ -345,7 +346,7 @@ def _validate_intragpu_rearrangement( post_rnk = post_phy_replicas_idx[0, start:end] # Pairwise equality for (expert, rank) pairs to ensure nothing is lost - def sorted_pairs(seg: torch.Tensor, rnk: torch.Tensor): + def sorted_pairs(seg, rnk): pairs = list(zip(seg.tolist(), rnk.tolist())) pairs.sort() return pairs @@ -386,8 +387,8 @@ def _validate_intragpu_rearrangement( # GPU0 new -> [1,5,0,4]; GPU1 new -> [6,2,7,3] 2, 4, - torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]), - torch.tensor([[1, 5, 0, 4, 6, 2, 7, 3]]), + np.array([[0, 1, 2, 3, 4, 5, 6, 7]]), + np.array([[1, 5, 0, 4, 6, 2, 7, 3]]), id="simple", ), pytest.param( @@ -401,8 +402,8 @@ def _validate_intragpu_rearrangement( # GPU1 new -> [6, 2, 3, 2, 1] (expert 2 duplicated) 2, 5, - torch.tensor([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]]), - torch.tensor([[0, 5, 4, 0, 1, 6, 2, 3, 2, 1]]), + np.array([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]]), + np.array([[0, 5, 4, 0, 1, 6, 2, 3, 2, 1]]), id="duplicates", ), pytest.param( @@ -418,8 +419,8 @@ def _validate_intragpu_rearrangement( # GPU2 new -> [1, 2, 3, 0] 3, 4, - torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]), - torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0]]), + np.array([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]), + np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0]]), id="skewed_expert", ), ], diff --git a/vllm/distributed/eplb/policy/default.py b/vllm/distributed/eplb/policy/default.py index ebbbc6db1d9f9..d6f030a391eea 100644 --- a/vllm/distributed/eplb/policy/default.py +++ b/vllm/distributed/eplb/policy/default.py @@ -21,8 +21,8 @@ from .abstract import AbstractEplbPolicy class DefaultEplbPolicy(AbstractEplbPolicy): @classmethod def balanced_packing( - cls, weight: torch.Tensor, num_packs: int - ) -> tuple[torch.Tensor, 
torch.Tensor]:
+        cls, weight: np.ndarray, num_packs: int
+    ) -> tuple[np.ndarray, np.ndarray]:
         """
         Pack n weighted objects to m packs, such that each bin contains exactly
         n/m objects and the weights of all packs are as balanced as possible.
@@ -39,19 +39,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         assert num_groups % num_packs == 0
         groups_per_pack = num_groups // num_packs
 
-        device = weight.device
-
         if groups_per_pack == 1:
-            pack_index = torch.arange(
-                weight.size(-1), dtype=torch.int64, device=device
-            ).expand(weight.shape)
-            rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
-            return pack_index, rank_in_pack
-
-        weight_np = weight.cpu().numpy()
+            pack_index_np = np.tile(
+                np.arange(num_groups, dtype=np.int64), (num_layers, 1)
+            )
+            rank_in_pack_np = np.zeros_like(pack_index_np, dtype=np.int64)
+            return pack_index_np, rank_in_pack_np
 
         # Sort and get indices in descending order
-        indices_np = np.argsort(-weight_np, axis=-1)
+        indices = np.argsort(-weight, axis=-1)
 
         pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
         rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
@@ -61,7 +57,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
             pack_weights = [0.0] * num_packs
             pack_items = [0] * num_packs
 
-            for group in indices_np[i]:
+            for group in indices[i]:
                 # Find a pack with capacity that has the lowest weight
                 pack = min(
                     (j for j in range(num_packs) if pack_items[j] < groups_per_pack),
                     key=pack_weights.__getitem__,
                 )
@@ -71,18 +67,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
                 assert pack_items[pack] < groups_per_pack
                 pack_index_np[i, group] = pack
                 rank_in_pack_np[i, group] = pack_items[pack]
-                pack_weights[pack] += weight_np[i, group]
+                pack_weights[pack] += weight[i, group]
                 pack_items[pack] += 1
 
-        pack_index = torch.from_numpy(pack_index_np).to(device)
-        rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
-
-        return pack_index, rank_in_pack
+        return pack_index_np, rank_in_pack_np
 
     @classmethod
     def replicate_experts(
-        cls, weight: torch.Tensor, num_phy: int
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        cls, weight: np.ndarray, num_phy: int
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
         """
         Replicate `num_log` experts to `num_phy` replicas, such that the
         maximum load of all replicas is minimized.
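
Before the diff reaches the body of replicate_experts, the greedy rule it implements
is worth seeing in isolation: each extra physical slot is granted to the logical
expert whose load per existing replica is currently the highest. A minimal sketch
under that assumption (standalone illustration, not the patched function):

import numpy as np

# One layer, 4 logical experts, 6 physical slots -> 2 redundant replicas.
weight = np.array([8.0, 4.0, 2.0, 1.0])
logcnt = np.ones(4, dtype=np.int64)  # replicas per logical expert so far
phy2log = list(range(4))             # the first num_log slots are the originals
for _ in range(2):
    hot = int(np.argmax(weight / logcnt))  # hottest per-replica load wins
    phy2log.append(hot)
    logcnt[hot] += 1
print(phy2log, logcnt.tolist())  # [0, 1, 2, 3, 0, 0] [3, 1, 1, 1]
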
@@ -99,13 +92,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy): n, num_log = weight.shape num_redundant = num_phy - num_log assert num_redundant >= 0 - device = weight.device - phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) - replica_idx = torch.zeros(n, num_phy, dtype=torch.int64, device=device) - logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) - arangen = torch.arange(n, dtype=torch.int64, device=device) + phy2log = np.tile(np.arange(num_phy, dtype=np.int64), (n, 1)) + replica_idx = np.zeros((n, num_phy), dtype=np.int64) + logcnt = np.ones((n, num_log), dtype=np.int64) + arangen = np.arange(n, dtype=np.int64) for i in range(num_log, num_phy): - redundant_indices = (weight / logcnt).max(dim=-1).indices + redundant_indices = np.argmax(weight / logcnt, axis=-1) phy2log[:, i] = redundant_indices replica_idx[:, i] = logcnt[arangen, redundant_indices] logcnt[arangen, redundant_indices] += 1 @@ -114,12 +106,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy): @classmethod def rebalance_experts_hierarchical( cls, - weight: torch.Tensor, + weight: np.ndarray, num_physical_experts: int, num_groups: int, num_nodes: int, num_gpus: int, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + ) -> tuple[np.ndarray, np.ndarray, np.ndarray]: """ Parameters: weight: [num_moe_layers, num_logical_experts] @@ -146,35 +138,33 @@ class DefaultEplbPolicy(AbstractEplbPolicy): assert num_physical_experts % num_gpus == 0 phy_experts_per_gpu = num_physical_experts // num_gpus - def inverse(perm: torch.Tensor) -> torch.Tensor: - inv = torch.empty_like(perm) - inv.scatter_( - 1, - perm, - torch.arange( - perm.size(1), dtype=torch.int64, device=perm.device - ).expand(perm.shape), - ) + def inverse(perm: np.ndarray) -> np.ndarray: + inv = np.empty_like(perm) + row_idx = np.arange(perm.shape[0])[:, None] + col_idx = np.arange(perm.shape[1], dtype=np.int64) + inv[row_idx, perm] = col_idx return inv # Step 1: pack groups to nodes - tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1) + tokens_per_group = weight.reshape(num_layers, num_groups, group_size).sum( + axis=-1 + ) group_pack_index, group_rank_in_pack = cls.balanced_packing( tokens_per_group, num_nodes ) + # Map each logical expert into a node-local ordering based on packed groups. log2mlog = ( ( - (group_pack_index * groups_per_node + group_rank_in_pack) * group_size - ).unsqueeze(-1) - + torch.arange( - group_size, dtype=torch.int64, device=group_pack_index.device + (group_pack_index * groups_per_node + group_rank_in_pack)[..., None] + * group_size ) - ).flatten(-2) + + np.arange(group_size, dtype=np.int64) + ).reshape(num_layers, num_logical_experts) mlog2log = inverse(log2mlog) # Step 2: construct redundant experts within nodes - # [num_layers * num_nodes, num_logical_experts // num_nodes] - tokens_per_mlog = weight.gather(-1, mlog2log).view( + # Reorder weights into the node-local layout so replication is done per node. + tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=1).reshape( -1, num_logical_experts // num_nodes ) phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts( @@ -182,39 +172,43 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ) # Step 3: pack physical_experts to GPUs - # [num_layers * num_nodes, num_physical_experts // num_nodes] - tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog) + # Effective per-physical load = logical load divided by replica count. 
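+        # (an expert replicated k times contributes weight / k to each of its slots)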
+ tokens_per_phy = np.take_along_axis(tokens_per_mlog / mlogcnt, phy2mlog, axis=1) pack_index, rank_in_pack = cls.balanced_packing( tokens_per_phy, num_gpus // num_nodes ) phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack pphy2phy = inverse(phy2pphy) - pphy2mlog = phy2mlog.gather( - -1, pphy2phy - ) # [num_layers * num_nodes, num_log_per_nodes] + # Reorder node-local logical indices into the post-packing physical order. + pphy2mlog = np.take_along_axis(phy2mlog, pphy2phy, axis=1) pphy2mlog = ( - pphy2mlog.view(num_layers, num_nodes, -1) - + torch.arange( + pphy2mlog.reshape(num_layers, num_nodes, -1) + + np.arange( 0, num_logical_experts, num_logical_experts // num_nodes, - device=group_pack_index.device, - ).view(1, -1, 1) - ).flatten(-2) - pphy2log = mlog2log.gather(-1, pphy2mlog) - pphy_replicas_idx = replicas_idx.gather(-1, pphy2phy).view(num_layers, -1) - logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) + dtype=np.int64, + )[None, :, None] + ).reshape(num_layers, -1) + # Map node-local logical indices back to global logical expert ids. + pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=1) + # Reorder replica ranks to the post-packing physical ordering. + pphy_replicas_idx = np.take_along_axis(replicas_idx, pphy2phy, axis=1).reshape( + num_layers, -1 + ) + # Convert replica counts back to the original logical ordering. + logcnt = np.take_along_axis(mlogcnt.reshape(num_layers, -1), log2mlog, axis=1) return pphy2log, pphy_replicas_idx, logcnt @classmethod def preserve_intragpu_slots( cls, - phy2log: torch.Tensor, - phy_replicas_idx: torch.Tensor, + phy2log: np.ndarray, + phy_replicas_idx: np.ndarray, num_ranks: int, - old_global_expert_indices: torch.Tensor, - ) -> tuple[torch.Tensor, torch.Tensor]: + old_phy2log: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray]: """ Reorder the new mapping per GPU so that experts that remain on the same GPU keep their previous slot positions when possible. Incoming experts to that GPU @@ -222,29 +216,24 @@ class DefaultEplbPolicy(AbstractEplbPolicy): is unchanged and the slots per GPU remain the same between the old and new mappings. 
""" - device = phy2log.device num_phy_experts = phy2log.shape[1] if num_ranks <= 0 or num_phy_experts % num_ranks != 0: return phy2log, phy_replicas_idx # Move to CPU and convert to NumPy for processing - new_phy2log_np = phy2log.cpu().numpy() - replicas_idx_np = phy_replicas_idx.cpu().numpy() - old_phy2log_np = old_global_expert_indices.cpu().numpy() - slots_per_gpu = num_phy_experts // num_ranks - num_layers = new_phy2log_np.shape[0] + num_layers = phy2log.shape[0] - post_phy2log_np = new_phy2log_np.copy() - post_phy_replicas_idx_np = replicas_idx_np.copy() + post_phy2log = phy2log.copy() + post_phy_replicas_idx = phy_replicas_idx.copy() for gpu_idx in range(num_ranks): start = gpu_idx * slots_per_gpu end = start + slots_per_gpu # Experts across all layers for this GPU - old_local = old_phy2log_np[:, start:end] # [layers, slots] - new_local = new_phy2log_np[:, start:end] # [layers, slots] - new_ridx = replicas_idx_np[:, start:end] # [layers, slots] + old_local = old_phy2log[:, start:end] # [layers, slots] + new_local = phy2log[:, start:end] # [layers, slots] + new_ridx = phy_replicas_idx[:, start:end] # [layers, slots] used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool) preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool) @@ -261,12 +250,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy): first_idx = np.argmax(matches, axis=1) layer_indices = np.nonzero(has_any)[0] matched_new_positions = first_idx[layer_indices] - post_phy2log_np[layer_indices, start + slot_idx] = new_local[ + post_phy2log[layer_indices, start + slot_idx] = new_local[ + layer_indices, matched_new_positions + ] + post_phy_replicas_idx[layer_indices, start + slot_idx] = new_ridx[ layer_indices, matched_new_positions ] - post_phy_replicas_idx_np[layer_indices, start + slot_idx] = ( - new_ridx[layer_indices, matched_new_positions] - ) used_new_indices[layer_indices, matched_new_positions] = True preserved_positions[layer_indices, slot_idx] = True @@ -295,16 +284,13 @@ class DefaultEplbPolicy(AbstractEplbPolicy): continue src_pos = remaining_indices[layer_idx, :k] dst_pos = fill_indices[layer_idx, :k] - post_phy2log_np[layer_idx, start + dst_pos] = new_local[ + post_phy2log[layer_idx, start + dst_pos] = new_local[ layer_idx, src_pos ] - post_phy_replicas_idx_np[layer_idx, start + dst_pos] = new_ridx[ + post_phy_replicas_idx[layer_idx, start + dst_pos] = new_ridx[ layer_idx, src_pos ] - # Convert back to torch and move to original device - post_phy2log = torch.from_numpy(post_phy2log_np).to(device) - post_phy_replicas_idx = torch.from_numpy(post_phy_replicas_idx_np).to(device) return post_phy2log, post_phy_replicas_idx @classmethod @@ -340,40 +326,51 @@ class DefaultEplbPolicy(AbstractEplbPolicy): logcnt: [layers, num_logical_experts], number of physical replicas for each logical expert """ + device = weight.device num_layers, num_logical_experts = weight.shape - weight = weight.float() + weight_np = weight.float().cpu().numpy() + old_phy2log_np = ( + old_global_expert_indices.cpu().numpy() + if old_global_expert_indices is not None + else None + ) + if num_groups % num_nodes == 0: # use hierarchical load-balance policy - phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical( - weight, num_replicas, num_groups, num_nodes, num_ranks + phy2log_np, phy_replicas_idx_np, logcnt_np = ( + cls.rebalance_experts_hierarchical( + weight_np, num_replicas, num_groups, num_nodes, num_ranks + ) ) else: # use global load-balance policy - phy2log, phy_replicas_idx, logcnt = 
cls.rebalance_experts_hierarchical(
-                weight, num_replicas, 1, 1, num_ranks
+            phy2log_np, phy_replicas_idx_np, logcnt_np = (
+                cls.rebalance_experts_hierarchical(
+                    weight_np, num_replicas, 1, 1, num_ranks
+                )
             )
+
         # Optional postprocessing to preserve slots for experts moving
         # within the same GPU
         # Only apply when the number of GPUs and slots per GPU remain unchanged.
         # Helps to avoid unnecessary weight copying when experts move
         # within the same GPU.
         if old_global_expert_indices is not None:
-            phy2log, phy_replicas_idx = cls.preserve_intragpu_slots(
-                phy2log, phy_replicas_idx, num_ranks, old_global_expert_indices
+            phy2log_np, phy_replicas_idx_np = cls.preserve_intragpu_slots(
+                phy2log_np, phy_replicas_idx_np, num_ranks, old_phy2log_np
             )
         num_redundant_experts = num_replicas - num_logical_experts
         maxlogcnt = num_redundant_experts + 1
-        log2phy: torch.Tensor = torch.full(
-            (num_layers, num_logical_experts, maxlogcnt),
-            -1,
-            dtype=torch.int64,
-            device=logcnt.device,
+        log2phy_np = np.full(
+            (num_layers, num_logical_experts, maxlogcnt), -1, dtype=np.int64
         )
-        log2phy.view(num_layers, -1).scatter_(
-            -1,
-            phy2log * maxlogcnt + phy_replicas_idx,
-            torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
-                num_layers, -1
-            ),
+        layer_indices = np.arange(num_layers)[:, None]
+        replica_indices = np.tile(
+            np.arange(num_replicas, dtype=np.int64), (num_layers, 1)
         )
+        log2phy_np[layer_indices, phy2log_np, phy_replicas_idx_np] = replica_indices
+
+        phy2log = torch.from_numpy(phy2log_np).to(device)
+        log2phy = torch.from_numpy(log2phy_np).to(device)
+        logcnt = torch.from_numpy(logcnt_np).to(device)
         return phy2log, log2phy, logcnt

From dcf47839670fc78703e4f33f0189e3525cace150 Mon Sep 17 00:00:00 2001
From: ilmarkov
Date: Mon, 15 Dec 2025 13:55:41 +0000
Subject: [PATCH 19/30] balanced_packing into numpy

Signed-off-by: ilmarkov
---
 vllm/distributed/eplb/policy/default.py | 43 ++++++++++++-------------
 1 file changed, 21 insertions(+), 22 deletions(-)

diff --git a/vllm/distributed/eplb/policy/default.py b/vllm/distributed/eplb/policy/default.py
index d6f030a391eea..66fc0372e6f5b 100644
--- a/vllm/distributed/eplb/policy/default.py
+++ b/vllm/distributed/eplb/policy/default.py
@@ -40,37 +40,36 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         groups_per_pack = num_groups // num_packs
 
         if groups_per_pack == 1:
-            pack_index_np = np.tile(
-                np.arange(num_groups, dtype=np.int64), (num_layers, 1)
-            )
-            rank_in_pack_np = np.zeros_like(pack_index_np, dtype=np.int64)
-            return pack_index_np, rank_in_pack_np
+            pack_index = np.tile(np.arange(num_groups, dtype=np.int64), (num_layers, 1))
+            rank_in_pack = np.zeros_like(pack_index, dtype=np.int64)
+            return pack_index, rank_in_pack
 
         # Sort and get indices in descending order
         indices = np.argsort(-weight, axis=-1)
 
-        pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
-        rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
+        pack_index = np.full((num_layers, num_groups), -1, dtype=np.int64)
+        rank_in_pack = np.full((num_layers, num_groups), -1, dtype=np.int64)
+
+        pack_weights = np.zeros((num_layers, num_packs), dtype=np.float64)
+        pack_items = np.zeros((num_layers, num_packs), dtype=np.int64)
 
         # Run the packing algorithm
-        for i in range(num_layers):
-            pack_weights = [0.0] * num_packs
-            pack_items = [0] * num_packs
+        for layer_idx in range(num_layers):
+            weights_row = pack_weights[layer_idx]
+            items_row = pack_items[layer_idx]
 
-            for group in indices[i]:
-                # Find a pack with capacity that has the lowest weight
-                pack = min(
-                    (j for j in range(num_packs) if pack_items[j] < groups_per_pack),
-                    key=pack_weights.__getitem__,
-                )
+            for group in indices[layer_idx]:
+                # Select the lightest pack that still has capacity.
+                available = items_row < groups_per_pack
+                assert np.any(available)
+                pack = int(np.argmin(np.where(available, weights_row, np.inf)))
 
-                assert pack_items[pack] < groups_per_pack
-                pack_index_np[i, group] = pack
-                rank_in_pack_np[i, group] = pack_items[pack]
-                pack_weights[pack] += weight[i, group]
-                pack_items[pack] += 1
+                pack_index[layer_idx, group] = pack
+                rank_in_pack[layer_idx, group] = items_row[pack]
+                weights_row[pack] += weight[layer_idx, group]
+                items_row[pack] += 1
 
-        return pack_index_np, rank_in_pack_np
+        return pack_index, rank_in_pack
 
     @classmethod
     def replicate_experts(

From 699ea9710803d9b9c6ba924a1aacc81ec216720d Mon Sep 17 00:00:00 2001
From: ilmarkov
Date: Mon, 15 Dec 2025 18:08:15 +0000
Subject: [PATCH 20/30] Remove expensive checks

Signed-off-by: ilmarkov
---
 vllm/distributed/eplb/policy/default.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/vllm/distributed/eplb/policy/default.py b/vllm/distributed/eplb/policy/default.py
index 66fc0372e6f5b..b9cfcae014108 100644
--- a/vllm/distributed/eplb/policy/default.py
+++ b/vllm/distributed/eplb/policy/default.py
@@ -59,15 +59,16 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
             items_row = pack_items[layer_idx]
 
             for group in indices[layer_idx]:
-                # Select the lightest pack that still has capacity.
-                available = items_row < groups_per_pack
-                assert np.any(available)
-                pack = int(np.argmin(np.where(available, weights_row, np.inf)))
+                # Pick the lightest pack; full packs are masked out by inf.
+                pack = int(np.argmin(weights_row))
 
                 pack_index[layer_idx, group] = pack
                 rank_in_pack[layer_idx, group] = items_row[pack]
                 weights_row[pack] += weight[layer_idx, group]
                 items_row[pack] += 1
+                if items_row[pack] == groups_per_pack:
+                    # Mark as unavailable for future selections.
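+                    # (np.inf ensures np.argmin never selects a full pack again)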
+ weights_row[pack] = np.inf return pack_index, rank_in_pack From 6014dc26d35db160f4e577d9f239175ff5dfe689 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 16 Dec 2025 15:02:53 +0000 Subject: [PATCH 21/30] Fix test Signed-off-by: ilmarkov --- tests/distributed/test_eplb_execute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py index dfb15311202e5..14ad148da6c45 100644 --- a/tests/distributed/test_eplb_execute.py +++ b/tests/distributed/test_eplb_execute.py @@ -312,7 +312,7 @@ def _test_async_transfer_layer_without_mtp_worker( is_received_locally=is_received_locally, recv_metadata=recv_metadata, new_indices=new_indices_cpu[layer_idx], - ep_group=ep_group, + ep_rank=ep_group.rank(), ) verify_expert_weights_after_shuffle( From 7ebd46fe769480d6f2b5242ad8d50a1c61a441bd Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Fri, 12 Dec 2025 13:18:28 +0000 Subject: [PATCH 22/30] Move rebalance algo to async thread Signed-off-by: ilmarkov --- vllm/distributed/eplb/async_worker.py | 41 ++++++++++- vllm/distributed/eplb/eplb_state.py | 99 +++++++++++++++++---------- 2 files changed, 103 insertions(+), 37 deletions(-) diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py index 9d7366996e3b2..7ea4bf269db23 100644 --- a/vllm/distributed/eplb/async_worker.py +++ b/vllm/distributed/eplb/async_worker.py @@ -17,7 +17,7 @@ from vllm.logger import init_logger from .rebalance_execute import transfer_layer if TYPE_CHECKING: - from .eplb_state import EplbState + from .eplb_state import EplbModelState, EplbState logger = init_logger(__name__) @@ -57,6 +57,42 @@ def start_async_worker( return thread +def run_rebalance_experts( + model_state: "EplbModelState", + eplb_state: "EplbState", +) -> None: + assert model_state.eplb_stats is not None + eplb_stats = model_state.eplb_stats + # Move the global expert load window to CPU for computation. 
+    global_expert_load_window = eplb_stats.global_expert_load_window.cpu()
+    # Compute new expert mappings for the model
+    (
+        new_physical_to_logical_map,
+        new_logical_to_physical_map,
+        new_logical_replica_count,
+    ) = eplb_state.policy.rebalance_experts(
+        global_expert_load_window,
+        eplb_stats.num_replicas,
+        eplb_stats.num_groups,
+        eplb_stats.num_nodes,
+        eplb_stats.num_gpus,
+        model_state.physical_to_logical_map,
+    )
+
+    # Move map to cpu
+    model_state.new_physical_to_logical_map = new_physical_to_logical_map
+
+    max_slots = model_state.logical_to_physical_map.shape[-1]
+    padded_logical = torch.nn.functional.pad(
+        new_logical_to_physical_map,
+        (0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
+        value=-1,
+    ).to(model_state.logical_to_physical_map.device)
+    new_replica = new_logical_replica_count.to(model_state.logical_replica_count.device)
+    model_state.new_logical_to_physical_map = padded_logical
+    model_state.new_logical_replica_count = new_replica
+
+
 async def transfer_run_periodically(
     state: "EplbState",
     ep_group: ProcessGroup,
@@ -71,6 +107,9 @@ async def transfer_run_periodically(
         for model_state in state.model_states.values():
             if not model_state.is_async_enabled:
                 continue
+            if not model_state.new_indices_computed:
+                run_rebalance_experts(model_state, state)
+
             current_num_layers = model_state.model.num_moe_layers
             while (
                 model_state.rebalanced
diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index 9f8798a96a2fc..0dd3999c1a65a 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -55,6 +55,35 @@ from .rebalance_execute import (
 logger = init_logger(__name__)
 
 
+@dataclass
+class EplbStats:
+    """
+    Model stats used by the EPLB rebalancing algorithm.
+    """
+
+    global_expert_load_window: torch.Tensor
+    """
+    Expert load window.
+    Shape: (window_size, num_moe_layers, num_physical_experts)
+    """
+    num_replicas: int
+    """
+    Number of physical experts.
+    """
+    num_groups: int
+    """
+    Number of expert groups.
+    """
+    num_nodes: int
+    """
+    Number of nodes.
+    """
+    num_gpus: int
+    """
+    Number of GPUs.
+    """
+
+
 @dataclass
 class EplbModelState:
     """EPLB metrics."""
@@ -168,6 +197,14 @@ class EplbModelState:
     """
    Whether the async EPLB needs to poll peers for buffer readiness.
     """
+    new_indices_computed: bool
+    """
+    The flag indicates whether the new indices have been computed.
+    """
+    eplb_stats: EplbStats | None
+    """
+    EPLB stats for the model.
+    """
     is_unchanged: np.ndarray
     """
     intermediate variable between `move_to_buffer` and `move_to_workspace`.
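
The EplbStats dataclass above is, in effect, a snapshot taken on the main thread so
that the numpy-based policy can run on the async worker without racing in-place
updates to the live load window. A minimal sketch of that hand-off pattern
(hypothetical class with a single producer and a single consumer; not the vLLM API):

import threading

import torch

class StatsHandoff:
    """Single-producer, single-consumer snapshot of load statistics."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._snapshot: torch.Tensor | None = None

    def publish(self, live_load: torch.Tensor) -> None:
        # Main thread: clone so later in-place updates never race the worker.
        with self._lock:
            self._snapshot = live_load.clone()

    def take(self) -> torch.Tensor | None:
        # Worker thread: consume the frozen copy exactly once.
        with self._lock:
            snapshot, self._snapshot = self._snapshot, None
            return snapshot
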
@@ -510,6 +547,8 @@ class EplbState: layer_to_transfer=0, rebalanced=False, pending_global_ready_check=False, + new_indices_computed=False, + eplb_stats=None, is_unchanged=np.array([]), is_received_locally=np.array([]), recv_metadata=RecvMetadata( @@ -806,21 +845,21 @@ class EplbState: for eplb_model_state, global_expert_load_window in zip( self.model_states.values(), global_expert_load_windows ): - # Get new expert mappings for the model - ( - new_physical_to_logical_map, - new_logical_to_physical_map, - new_logical_replica_count, - ) = self.policy.rebalance_experts( - global_expert_load_window, - num_replicas, - num_groups, - num_nodes, - num_gpus, - eplb_model_state.physical_to_logical_map, - ) - if not eplb_model_state.is_async_enabled or is_profile: + # Get new expert mappings for the model + ( + new_physical_to_logical_map, + new_logical_to_physical_map, + new_logical_replica_count, + ) = self.policy.rebalance_experts( + global_expert_load_window, + num_replicas, + num_groups, + num_nodes, + num_gpus, + eplb_model_state.physical_to_logical_map, + ) + # Update expert weights rearrange_expert_weights_inplace( eplb_model_state.physical_to_logical_map, @@ -877,27 +916,17 @@ class EplbState: gpu_elapsed, ) else: - max_slots = eplb_model_state.logical_to_physical_map.shape[-1] - padded_logical = torch.nn.functional.pad( - new_logical_to_physical_map, - (0, max(0, max_slots - new_logical_to_physical_map.shape[-1])), - value=-1, - ).to(eplb_model_state.logical_to_physical_map.device) - new_replica = new_logical_replica_count.to( - eplb_model_state.logical_replica_count.device + eplb_model_state.eplb_stats = EplbStats( + global_expert_load_window=global_expert_load_window, + num_replicas=num_replicas, + num_groups=num_groups, + num_nodes=num_nodes, + num_gpus=num_gpus, ) - - # Move map to cpu in advance - eplb_model_state.new_physical_to_logical_map = ( - new_physical_to_logical_map.cpu() - ) - eplb_model_state.new_logical_to_physical_map = padded_logical - eplb_model_state.new_logical_replica_count = new_replica - eplb_model_state.rebalanced = True eplb_model_state.layer_to_transfer = 0 eplb_model_state.pending_global_ready_check = True - + eplb_model_state.new_indices_computed = False # Signal async thread to start transferring layers if self.is_async and (not is_profile): self.rearrange_event.set() @@ -993,11 +1022,9 @@ class EplbState: model_state.layer_to_transfer ] expert_weights_buffer = model_state.expert_buffer - new_indices = ( - model_state.new_physical_to_logical_map[model_state.layer_to_transfer] - .cpu() - .numpy() - ) + new_indices = model_state.new_physical_to_logical_map[ + model_state.layer_to_transfer + ].numpy() move_from_buffer( expert_weights=expert_weights, expert_weights_buffers=expert_weights_buffer, From c761ce527a256199e1abdcf20ac17843f8aa60d7 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 15 Dec 2025 17:16:40 +0000 Subject: [PATCH 23/30] Update Signed-off-by: ilmarkov --- vllm/distributed/eplb/async_worker.py | 6 +++++- vllm/distributed/eplb/eplb_state.py | 4 +++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py index 7ea4bf269db23..fcad45b9e2449 100644 --- a/vllm/distributed/eplb/async_worker.py +++ b/vllm/distributed/eplb/async_worker.py @@ -78,8 +78,8 @@ def run_rebalance_experts( eplb_stats.num_gpus, model_state.physical_to_logical_map, ) + assert new_physical_to_logical_map.device == torch.device("cpu") - # Move map to cpu model_state.new_physical_to_logical_map = 
new_physical_to_logical_map max_slots = model_state.logical_to_physical_map.shape[-1] @@ -109,6 +109,10 @@ async def transfer_run_periodically( continue if not model_state.new_indices_computed: run_rebalance_experts(model_state, state) + logger.info( + "Async worker computed new indices for model %s", + model_state.model_name, + ) current_num_layers = model_state.model.num_moe_layers while ( diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index e4a3d9c160453..d169e42df577d 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -917,7 +917,9 @@ class EplbState: ) else: eplb_model_state.eplb_stats = EplbStats( - global_expert_load_window=global_expert_load_window, + # We copy the tensor to snapshot the workload on the main + # thread to be used on the async thread. + global_expert_load_window=global_expert_load_window.clone(), num_replicas=num_replicas, num_groups=num_groups, num_nodes=num_nodes, From 720fe99051bbc428387b3f786223d6bc787ece59 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 16 Dec 2025 14:50:28 +0000 Subject: [PATCH 24/30] Remove unnecessary copies Signed-off-by: ilmarkov --- vllm/distributed/eplb/async_worker.py | 2 ++ vllm/distributed/eplb/eplb_state.py | 43 ++++++++++++++------------- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py index fcad45b9e2449..2648ada89abc9 100644 --- a/vllm/distributed/eplb/async_worker.py +++ b/vllm/distributed/eplb/async_worker.py @@ -64,6 +64,8 @@ def run_rebalance_experts( assert model_state.eplb_stats is not None eplb_stats = model_state.eplb_stats # Move the global expert load window to CPU for computation. + # It has to be done in the main stream to avoid race condition + # with the main thread. 
global_expert_load_window = eplb_stats.global_expert_load_window.cpu() # Compute new expert mappings for the model ( diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index d169e42df577d..ef5e58706722c 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -948,9 +948,7 @@ class EplbState: is_profile=is_profile, ) - def _update_layer_mapping_from_new( - self, model_state: EplbModelState, layer: int - ) -> None: + def _update_layer_mapping_from_new(self, model_state: EplbModelState) -> None: if ( model_state.new_physical_to_logical_map is None or model_state.new_logical_to_physical_map is None @@ -959,27 +957,33 @@ class EplbState: return target_device = model_state.physical_to_logical_map.device - new_physical = model_state.new_physical_to_logical_map + new_physical = model_state.new_physical_to_logical_map.to( + target_device, non_blocking=True + ) if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]: - model_state.physical_to_logical_map = new_physical.to(target_device) + model_state.physical_to_logical_map = new_physical else: - model_state.physical_to_logical_map[layer].copy_( - new_physical[layer].to(target_device) - ) + for layer_idx in range(model_state.physical_to_logical_map.shape[0]): + model_state.physical_to_logical_map[layer_idx].copy_( + new_physical[layer_idx] + ) logical_device = model_state.logical_to_physical_map.device - new_logical = model_state.new_logical_to_physical_map[layer].to(logical_device) - max_slots = model_state.logical_to_physical_map.shape[-1] - slot_delta = max_slots - new_logical.shape[-1] - if slot_delta > 0: - new_logical = torch.nn.functional.pad( - new_logical, (0, slot_delta), value=-1 + for layer_idx in range(model_state.logical_to_physical_map.shape[0]): + new_logical = model_state.new_logical_to_physical_map[layer_idx].to( + logical_device ) - model_state.logical_to_physical_map[layer].copy_(new_logical) + max_slots = model_state.logical_to_physical_map.shape[-1] + slot_delta = max_slots - new_logical.shape[-1] + if slot_delta > 0: + new_logical = torch.nn.functional.pad( + new_logical, (0, slot_delta), value=-1 + ) + model_state.logical_to_physical_map[layer_idx].copy_(new_logical) replica_device = model_state.logical_replica_count.device - model_state.logical_replica_count[layer].copy_( - model_state.new_logical_replica_count[layer].to(replica_device) + model_state.logical_replica_count.copy_( + model_state.new_logical_replica_count.to(replica_device) ) def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool: @@ -1037,7 +1041,7 @@ class EplbState: ep_rank=ep_group.rank(), ) transferred_layer = model_state.layer_to_transfer - self._update_layer_mapping_from_new(model_state, transferred_layer) + # After the main thread consumes, advance layer_to_transfer model_state.layer_to_transfer += 1 model_state.ep_buffer_ready = 0 @@ -1061,8 +1065,7 @@ class EplbState: assert model_state.new_logical_to_physical_map is not None assert model_state.new_logical_replica_count is not None if not is_profile: - for layer_idx in range(model_state.physical_to_logical_map.shape[0]): - self._update_layer_mapping_from_new(model_state, layer_idx) + self._update_layer_mapping_from_new(model_state) model_state.new_physical_to_logical_map = None model_state.new_logical_to_physical_map = None model_state.new_logical_replica_count = None From 777c90e9544c6ccb513597b959f7d8570c93b182 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 16 Dec 2025 14:56:54 +0000 
Subject: [PATCH 25/30] Remove flag Signed-off-by: ilmarkov --- vllm/distributed/eplb/async_worker.py | 12 ++++++------ vllm/distributed/eplb/eplb_state.py | 6 ------ 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py index 2648ada89abc9..b2a58ded19edb 100644 --- a/vllm/distributed/eplb/async_worker.py +++ b/vllm/distributed/eplb/async_worker.py @@ -109,12 +109,12 @@ async def transfer_run_periodically( for model_state in state.model_states.values(): if not model_state.is_async_enabled: continue - if not model_state.new_indices_computed: - run_rebalance_experts(model_state, state) - logger.info( - "Async worker computed new indices for model %s", - model_state.model_name, - ) + # Rebalance experts is done once, only when the async worker wakes up. + run_rebalance_experts(model_state, state) + logger.info( + "Async worker computed new indices for model %s", + model_state.model_name, + ) current_num_layers = model_state.model.num_moe_layers while ( diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index ef5e58706722c..43c0bbdd72e44 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -197,10 +197,6 @@ class EplbModelState: """ Whether the async EPLB needs to poll peers for buffer readiness. """ - new_indices_computed: bool - """ - The flag indicates whether the new indices have been computed. - """ eplb_stats: EplbStats | None """ EPLB stats for the model. @@ -547,7 +543,6 @@ class EplbState: layer_to_transfer=0, rebalanced=False, pending_global_ready_check=False, - new_indices_computed=False, eplb_stats=None, is_unchanged=np.array([]), is_received_locally=np.array([]), @@ -928,7 +923,6 @@ class EplbState: eplb_model_state.rebalanced = True eplb_model_state.layer_to_transfer = 0 eplb_model_state.pending_global_ready_check = True - eplb_model_state.new_indices_computed = False # Signal async thread to start transferring layers if self.is_async and (not is_profile): self.rearrange_event.set() From 67ff54997d383b8c88db652c3df39b4eb9d6a8ea Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Wed, 17 Dec 2025 16:20:54 +0000 Subject: [PATCH 26/30] Fix accuracy issue Signed-off-by: ilmarkov --- vllm/distributed/eplb/eplb_state.py | 45 ++++++++++++++--------------- 1 file changed, 21 insertions(+), 24 deletions(-) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 43c0bbdd72e44..4e3fdad8bb9cc 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -942,7 +942,9 @@ class EplbState: is_profile=is_profile, ) - def _update_layer_mapping_from_new(self, model_state: EplbModelState) -> None: + def _update_layer_mapping_from_new( + self, model_state: EplbModelState, layer: int + ) -> None: if ( model_state.new_physical_to_logical_map is None or model_state.new_logical_to_physical_map is None @@ -951,33 +953,29 @@ class EplbState: return target_device = model_state.physical_to_logical_map.device - new_physical = model_state.new_physical_to_logical_map.to( - target_device, non_blocking=True - ) + new_physical = model_state.new_physical_to_logical_map if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]: - model_state.physical_to_logical_map = new_physical + model_state.physical_to_logical_map = new_physical.to( + target_device, non_blocking=True + ) else: - for layer_idx in range(model_state.physical_to_logical_map.shape[0]): - 
model_state.physical_to_logical_map[layer_idx].copy_( - new_physical[layer_idx] - ) + model_state.physical_to_logical_map[layer].copy_( + new_physical[layer].to(target_device, non_blocking=True) + ) logical_device = model_state.logical_to_physical_map.device - for layer_idx in range(model_state.logical_to_physical_map.shape[0]): - new_logical = model_state.new_logical_to_physical_map[layer_idx].to( - logical_device + new_logical = model_state.new_logical_to_physical_map[layer].to(logical_device) + max_slots = model_state.logical_to_physical_map.shape[-1] + slot_delta = max_slots - new_logical.shape[-1] + if slot_delta > 0: + new_logical = torch.nn.functional.pad( + new_logical, (0, slot_delta), value=-1 ) - max_slots = model_state.logical_to_physical_map.shape[-1] - slot_delta = max_slots - new_logical.shape[-1] - if slot_delta > 0: - new_logical = torch.nn.functional.pad( - new_logical, (0, slot_delta), value=-1 - ) - model_state.logical_to_physical_map[layer_idx].copy_(new_logical) + model_state.logical_to_physical_map[layer].copy_(new_logical) replica_device = model_state.logical_replica_count.device - model_state.logical_replica_count.copy_( - model_state.new_logical_replica_count.to(replica_device) + model_state.logical_replica_count[layer].copy_( + model_state.new_logical_replica_count[layer].to(replica_device) ) def _all_ranks_buffer_ready(self, model_state: EplbModelState) -> bool: @@ -1035,7 +1033,7 @@ class EplbState: ep_rank=ep_group.rank(), ) transferred_layer = model_state.layer_to_transfer - + self._update_layer_mapping_from_new(model_state, transferred_layer) # After the main thread consumes, advance layer_to_transfer model_state.layer_to_transfer += 1 model_state.ep_buffer_ready = 0 @@ -1058,8 +1056,7 @@ class EplbState: assert model_state.new_physical_to_logical_map is not None assert model_state.new_logical_to_physical_map is not None assert model_state.new_logical_replica_count is not None - if not is_profile: - self._update_layer_mapping_from_new(model_state) + model_state.new_physical_to_logical_map = None model_state.new_logical_to_physical_map = None model_state.new_logical_replica_count = None From f7b899202088418ad3bc0b04701662acdee8c5f5 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Thu, 18 Dec 2025 15:10:43 +0000 Subject: [PATCH 27/30] Keep memory copy blocking for elastic EP case Signed-off-by: ilmarkov --- vllm/distributed/eplb/eplb_state.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py index 4e3fdad8bb9cc..d1f0e181c5354 100644 --- a/vllm/distributed/eplb/eplb_state.py +++ b/vllm/distributed/eplb/eplb_state.py @@ -954,10 +954,10 @@ class EplbState: target_device = model_state.physical_to_logical_map.device new_physical = model_state.new_physical_to_logical_map + # In order to avoid race condition with async eplb worker, + # we need to copy blocking in case of updated EP size. 
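+        # (the synchronous .to() only returns once the copy has completed, so
+        # the worker cannot observe a partially updated mapping)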
if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]: - model_state.physical_to_logical_map = new_physical.to( - target_device, non_blocking=True - ) + model_state.physical_to_logical_map = new_physical.to(target_device) else: model_state.physical_to_logical_map[layer].copy_( new_physical[layer].to(target_device, non_blocking=True) From b679f66e586f53c8dc742cc245078b34561fd44a Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Mon, 22 Dec 2025 13:44:29 +0000 Subject: [PATCH 28/30] Make index access granular Signed-off-by: ilmarkov --- tests/distributed/test_eplb_execute.py | 7 ++- vllm/distributed/eplb/async_worker.py | 27 ++++++---- vllm/distributed/eplb/rebalance_execute.py | 59 ++++++++++++---------- 3 files changed, 52 insertions(+), 41 deletions(-) diff --git a/tests/distributed/test_eplb_execute.py b/tests/distributed/test_eplb_execute.py index 14ad148da6c45..b1ba1d1c48d23 100644 --- a/tests/distributed/test_eplb_execute.py +++ b/tests/distributed/test_eplb_execute.py @@ -295,12 +295,11 @@ def _test_async_transfer_layer_without_mtp_worker( for layer_idx in range(num_layers): is_unchanged, is_received_locally, recv_metadata = asyncio.run( transfer_layer( - old_global_expert_indices=old_indices_cpu, - new_global_expert_indices=new_indices_cpu, - expert_weights=expert_weights, + old_layer_indices=old_indices_cpu[layer_idx], + new_layer_indices=new_indices_cpu[layer_idx], + expert_weights=expert_weights[layer_idx], expert_weights_buffer=expert_buffer, ep_group=ep_group, - layer=layer_idx, cuda_stream=cuda_stream, ) ) diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py index b2a58ded19edb..6897015375614 100644 --- a/vllm/distributed/eplb/async_worker.py +++ b/vllm/distributed/eplb/async_worker.py @@ -109,8 +109,7 @@ async def transfer_run_periodically( for model_state in state.model_states.values(): if not model_state.is_async_enabled: continue - # Rebalance experts is done once, only when the async worker wakes up. 
- run_rebalance_experts(model_state, state) + rebalancing_algorithm_executed = False logger.info( "Async worker computed new indices for model %s", model_state.model_name, @@ -121,28 +120,34 @@ async def transfer_run_periodically( model_state.rebalanced and model_state.layer_to_transfer < current_num_layers ): - if ( - not model_state.ep_buffer_ready - and model_state.rebalanced - and model_state.new_physical_to_logical_map is not None - ): + if not model_state.ep_buffer_ready and model_state.rebalanced: await asyncio.to_thread(model_state.buffer_lock.acquire) try: if model_state.layer_to_transfer >= current_num_layers: break + if not rebalancing_algorithm_executed: + run_rebalance_experts(model_state, state) + rebalancing_algorithm_executed = True + + layer_idx = model_state.layer_to_transfer + old_layer_indices = model_state.old_physical_to_logical_map[ + layer_idx + ] + new_layer_indices = model_state.new_physical_to_logical_map[ + layer_idx + ] ( model_state.is_unchanged, model_state.is_received_locally, model_state.recv_metadata, ) = await transfer_layer( - old_global_expert_indices=model_state.physical_to_logical_map, - new_global_expert_indices=model_state.new_physical_to_logical_map, - expert_weights=model_state.model.expert_weights, + old_layer_indices=old_layer_indices, + new_layer_indices=new_layer_indices, + expert_weights=model_state.model.expert_weights[layer_idx], expert_weights_buffer=model_state.expert_buffer, ep_group=ep_group, is_profile=is_profile, - layer=model_state.layer_to_transfer, cuda_stream=cuda_stream, rank_mapping=rank_mapping, ) diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index b7b6c11b239ac..493df18675a10 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -434,13 +434,12 @@ def move_from_buffer( async def transfer_layer( - old_global_expert_indices: torch.Tensor, - new_global_expert_indices: torch.Tensor, - expert_weights: Sequence[Iterable[torch.Tensor]], + old_layer_indices: torch.Tensor, + new_layer_indices: torch.Tensor, + expert_weights: Iterable[torch.Tensor], expert_weights_buffer: Sequence[torch.Tensor], ep_group: ProcessGroup, is_profile: bool = False, - layer: int = 0, cuda_stream: torch.cuda.Stream | None = None, rank_mapping: dict[int, int] | None = None, ) -> MoveToBufferResult: @@ -451,55 +450,63 @@ async def transfer_layer( while keys are physical. Args: - old_global_expert_indices: Shape (num_moe_layers, num_physical_experts). - new_global_expert_indices: Shape (num_moe_layers, num_physical_experts). - expert_weights: A sequence of shape (num_moe_layers)(weight_count) - of tensors of shape (num_local_physical_experts, hidden_size_i). - For example, a linear layer may have up and down projection, - so weight_count = 2. Each weight's hidden size can be different. + old_layer_indices: Shape (num_physical_experts,). + new_layer_indices: Shape (num_physical_experts,). + expert_weights: Iterable of weight tensors for this layer, each with shape + (num_local_physical_experts, hidden_size_i). + For example, a linear layer may have up and down projection. + expert_weights_buffer: Intermediate buffers (one per weight tensor). ep_group: The device process group for expert parallelism. is_profile (bool): If `True`, do not perform any actual weight copy. This is used during profile run, where we only perform dummy communications to reserve enough memory for the buffers. + cuda_stream: CUDA stream for async copies (can be None for sync mode). 
+ rank_mapping: Optional rank mapping for elastic expert parallelism. Returns: - is_unchanged (np.ndarray): (1, num_local_experts), True where expert + is_unchanged (np.ndarray): (num_local_experts,), True where expert is left unchanged. - is_received_locally (np.ndarray): (1, num_local_experts), True where expert + is_received_locally (np.ndarray): (num_local_experts,), True where expert can be received locally. RecvMetadata: Metadata needed for completing remote weight transfers. """ ep_size = ep_group.size() if rank_mapping is not None: + # Add a layer dimension for compatibility with mapping functions + old_layer_indices_2d = old_layer_indices.unsqueeze(0) + new_layer_indices_2d = new_layer_indices.unsqueeze(0) + if len(rank_mapping) == ep_group.size(): # scale down - new_global_expert_indices = _map_new_expert_indices_with_rank_mapping( - new_global_expert_indices, + new_layer_indices_2d = _map_new_expert_indices_with_rank_mapping( + new_layer_indices_2d, rank_mapping, ) else: # scale up - old_global_expert_indices = _map_old_expert_indices_with_rank_mapping( - old_global_expert_indices, + old_layer_indices_2d = _map_old_expert_indices_with_rank_mapping( + old_layer_indices_2d, rank_mapping, ep_group.size(), ) - assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1] - num_moe_layers, num_physical_experts = old_global_expert_indices.shape - assert len(expert_weights) == num_moe_layers - num_local_physical_experts = next(iter(expert_weights[0])).shape[0] - assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts) + # Remove the layer dimension + old_layer_indices = old_layer_indices_2d.squeeze(0) + new_layer_indices = new_layer_indices_2d.squeeze(0) + + assert old_layer_indices.shape == new_layer_indices.shape + num_physical_experts = old_layer_indices.shape[0] + num_local_physical_experts = next(iter(expert_weights)).shape[0] assert num_physical_experts == ep_size * num_local_physical_experts - old_global_expert_indices_np = old_global_expert_indices.cpu().numpy() - new_global_expert_indices_np = new_global_expert_indices.cpu().numpy() + old_layer_indices_np = old_layer_indices.cpu().numpy() + new_layer_indices_np = new_layer_indices.cpu().numpy() is_unchanged, is_received_locally, recv_metadata = move_to_buffer( num_local_experts=num_local_physical_experts, - old_indices=old_global_expert_indices_np[layer], - new_indices=new_global_expert_indices_np[layer], - expert_weights=expert_weights[layer], + old_indices=old_layer_indices_np, + new_indices=new_layer_indices_np, + expert_weights=expert_weights, expert_weights_buffers=expert_weights_buffer, cuda_stream=cuda_stream, ep_group=ep_group, From 154245dc5eca8ab6a7108017e92d27c4b9fe8885 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 23 Dec 2025 12:09:14 +0000 Subject: [PATCH 29/30] Fix Signed-off-by: ilmarkov --- vllm/distributed/eplb/async_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py index 6897015375614..3e4017e86b300 100644 --- a/vllm/distributed/eplb/async_worker.py +++ b/vllm/distributed/eplb/async_worker.py @@ -130,7 +130,7 @@ async def transfer_run_periodically( rebalancing_algorithm_executed = True layer_idx = model_state.layer_to_transfer - old_layer_indices = model_state.old_physical_to_logical_map[ + old_layer_indices = model_state.physical_to_logical_map[ layer_idx ] new_layer_indices = model_state.new_physical_to_logical_map[ From 
5cd45c646f53923fc5b5cc046aa044c1ce94aa08 Mon Sep 17 00:00:00 2001 From: ilmarkov Date: Tue, 23 Dec 2025 15:07:29 +0000 Subject: [PATCH 30/30] Fix pre-commit Signed-off-by: ilmarkov --- vllm/distributed/eplb/async_worker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/distributed/eplb/async_worker.py b/vllm/distributed/eplb/async_worker.py index 3e4017e86b300..b0265262632bf 100644 --- a/vllm/distributed/eplb/async_worker.py +++ b/vllm/distributed/eplb/async_worker.py @@ -128,6 +128,7 @@ async def transfer_run_periodically( if not rebalancing_algorithm_executed: run_rebalance_experts(model_state, state) rebalancing_algorithm_executed = True + assert model_state.new_physical_to_logical_map is not None layer_idx = model_state.layer_to_transfer old_layer_indices = model_state.physical_to_logical_map[