diff --git a/vllm/distributed/eplb/policy/default.py b/vllm/distributed/eplb/policy/default.py index 172f6f044f89e..d4f3b65cdc0c5 100644 --- a/vllm/distributed/eplb/policy/default.py +++ b/vllm/distributed/eplb/policy/default.py @@ -93,7 +93,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): Returns: phy2log: [X, num_phy], logical expert id of each physical expert - rank: [X, num_phy], the replica rank + replica_idx: [X, num_phy], the index of the replica for each logical expert logcnt: [X, num_log], number of replicas for each logical expert """ n, num_log = weight.shape @@ -101,15 +101,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy): assert num_redundant >= 0 device = weight.device phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) - rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) + replica_idx = torch.zeros(n, num_phy, dtype=torch.int64, device=device) logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) arangen = torch.arange(n, dtype=torch.int64, device=device) for i in range(num_log, num_phy): redundant_indices = (weight / logcnt).max(dim=-1).indices phy2log[:, i] = redundant_indices - rank[:, i] = logcnt[arangen, redundant_indices] + replica_idx[:, i] = logcnt[arangen, redundant_indices] logcnt[arangen, redundant_indices] += 1 - return phy2log, rank, logcnt + return phy2log, replica_idx, logcnt @classmethod def rebalance_experts_hierarchical( @@ -132,7 +132,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): Returns: phy2log: [layers, num_replicas], the expert index of each replica - log2phy: [layers, num_logical_experts, X], + pphy_replicas_idx: [layers, num_logical_experts, X], the replica indices for each expert logcnt: [layers, num_logical_experts], number of physical replicas for each logical expert @@ -177,7 +177,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): tokens_per_mlog = weight.gather(-1, mlog2log).view( -1, num_logical_experts // num_nodes ) - phy2mlog, phyrank, mlogcnt = cls.replicate_experts( + phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts( tokens_per_mlog, num_physical_experts // num_nodes ) @@ -203,15 +203,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ).view(1, -1, 1) ).flatten(-2) pphy2log = mlog2log.gather(-1, pphy2mlog) - pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) + pphy_replicas_idx = replicas_idx.gather(-1, pphy2phy).view(num_layers, -1) logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) - return pphy2log, pphyrank, logcnt + return pphy2log, pphy_replicas_idx, logcnt @classmethod def preserve_intragpu_slots( cls, phy2log: torch.Tensor, - phyrank: torch.Tensor, + phy_replicas_idx: torch.Tensor, num_ranks: int, old_global_expert_indices: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -223,56 +223,52 @@ class DefaultEplbPolicy(AbstractEplbPolicy): the old and new mappings. """ device = phy2log.device - new_num_phy = phy2log.shape[1] - old_num_phy = old_global_expert_indices.shape[1] - if ( - num_ranks <= 0 - or new_num_phy % num_ranks != 0 - or old_num_phy % num_ranks != 0 - or (new_num_phy // num_ranks) != (old_num_phy // num_ranks) - ): - return phy2log, phyrank + num_phy_experts = phy2log.shape[1] + if num_ranks <= 0 or num_phy_experts % num_ranks != 0: + return phy2log, phy_replicas_idx # Move to CPU and convert to NumPy for processing - phy2log_np = phy2log.cpu().numpy() - phyrank_np = phyrank.cpu().numpy() - old_np = old_global_expert_indices.cpu().numpy() + new_phy2log_np = phy2log.cpu().numpy() + replicas_idx_np = phy_replicas_idx.cpu().numpy() + old_phy2log_np = old_global_expert_indices.cpu().numpy() - slots_per_gpu = new_num_phy // num_ranks - num_layers = phy2log_np.shape[0] + slots_per_gpu = num_phy_experts // num_ranks + num_layers = new_phy2log_np.shape[0] - post_phy2log_np = phy2log_np.copy() - post_phyrank_np = phyrank_np.copy() + post_phy2log_np = new_phy2log_np.copy() + post_phy_replicas_idx_np = replicas_idx_np.copy() for gpu_idx in range(num_ranks): start = gpu_idx * slots_per_gpu end = start + slots_per_gpu - # Segments across all layers for this GPU - old_seg = old_np[:, start:end] # [L, S] - new_seg = phy2log_np[:, start:end] # [L, S] - new_rnk = phyrank_np[:, start:end] # [L, S] + # Experts across all layers for this GPU + old_local = old_phy2log_np[:, start:end] # [layers, slots] + new_local = new_phy2log_np[:, start:end] # [layers, slots] + new_ridx = replicas_idx_np[:, start:end] # [layers, slots] used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool) preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool) # First pass: preserve same-logical experts in their previous slots - for pos in range(slots_per_gpu): - # matches: [L, S], True where new_seg has the same logical value - # as the old slot 'pos' and not used - matches = (new_seg == old_seg[:, pos][:, None]) & (~used_new_indices) + for slot_idx in range(slots_per_gpu): + # matches: [layers, slots], True where new local experts have + # the same logical value as the old from 'slot_idx' and not checked yet + matches = (new_local == old_local[:, slot_idx][:, None]) & ( + ~used_new_indices + ) has_any = matches.any(axis=1) if np.any(has_any): first_idx = np.argmax(matches, axis=1) layer_indices = np.nonzero(has_any)[0] matched_new_positions = first_idx[layer_indices] - post_phy2log_np[layer_indices, start + pos] = new_seg[ - layer_indices, matched_new_positions - ] - post_phyrank_np[layer_indices, start + pos] = new_rnk[ + post_phy2log_np[layer_indices, start + slot_idx] = new_local[ layer_indices, matched_new_positions ] + post_phy_replicas_idx_np[layer_indices, start + slot_idx] = ( + new_ridx[layer_indices, matched_new_positions] + ) used_new_indices[layer_indices, matched_new_positions] = True - preserved_positions[layer_indices, pos] = True + preserved_positions[layer_indices, slot_idx] = True # Second pass: fill remaining slots with remaining new experts remaining_mask = ~used_new_indices # [L, S] @@ -299,17 +295,17 @@ class DefaultEplbPolicy(AbstractEplbPolicy): continue src_pos = remaining_indices[layer_idx, :k] dst_pos = fill_indices[layer_idx, :k] - post_phy2log_np[layer_idx, start + dst_pos] = new_seg[ + post_phy2log_np[layer_idx, start + dst_pos] = new_local[ layer_idx, src_pos ] - post_phyrank_np[layer_idx, start + dst_pos] = new_rnk[ + post_phy_replicas_idx_np[layer_idx, start + dst_pos] = new_ridx[ layer_idx, src_pos ] # Convert back to torch and move to original device post_phy2log = torch.from_numpy(post_phy2log_np).to(device) - post_phyrank = torch.from_numpy(post_phyrank_np).to(device) - return post_phy2log, post_phyrank + post_phy_replicas_idx = torch.from_numpy(post_phy_replicas_idx_np).to(device) + return post_phy2log, post_phy_replicas_idx @classmethod def rebalance_experts( @@ -348,12 +344,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy): weight = weight.float() if num_groups % num_nodes == 0: # use hierarchical load-balance policy - phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical( + phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical( weight, num_replicas, num_groups, num_nodes, num_ranks ) else: # use global load-balance policy - phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical( + phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical( weight, num_replicas, 1, 1, num_ranks ) # Optional postprocessing to preserve slots for experts moving @@ -362,8 +358,8 @@ class DefaultEplbPolicy(AbstractEplbPolicy): # Helps to avoid unnecessary weight copying when experts move # within the same GPU. if old_global_expert_indices is not None: - phy2log, phyrank = cls.preserve_intragpu_slots( - phy2log, phyrank, num_ranks, old_global_expert_indices + phy2log, phy_replicas_idx = cls.preserve_intragpu_slots( + phy2log, phy_replicas_idx, num_ranks, old_global_expert_indices ) num_redundant_experts = num_replicas - num_logical_experts maxlogcnt = num_redundant_experts + 1 @@ -375,7 +371,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy): ) log2phy.view(num_layers, -1).scatter_( -1, - phy2log * maxlogcnt + phyrank, + phy2log * maxlogcnt + phy_replicas_idx, torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand( num_layers, -1 ), diff --git a/vllm/distributed/eplb/rebalance_execute.py b/vllm/distributed/eplb/rebalance_execute.py index 641fad9f4e788..717d9f6d88793 100644 --- a/vllm/distributed/eplb/rebalance_execute.py +++ b/vllm/distributed/eplb/rebalance_execute.py @@ -361,24 +361,23 @@ def move_from_buffer( is_received_locally: np.ndarray, recv_metadata: RecvMetadata, new_indices: np.ndarray, - ep_group: ProcessGroup, + ep_rank: int, ) -> None: """ Copies expert weights from communication buffers back to the target weight tensors after EPLB rebalancing. Args: - expert_weights: List of weight tensors for a MoE layer. - expert_weights_buffers: Intermediate buffers matching ``expert_weights``. + expert_weights: List of the actual MoE layer weights used in the execution. + expert_weights_buffers: Intermediate buffers containing the experts weights + after the transfer is completed. is_unchanged: (num_local_experts,), True where an expert row is unchanged. is_received_locally: (num_local_experts,), True where a row is updated locally. recv_metadata: RecvMetadata containing remote receive metadata. new_indices: (num_experts_total,) mapping from local rows to desired (possibly global) expert id, after rebalance. - ep_group: torch.distributed.ProcessGroup for expert parallel communication - domain. + ep_rank: Rank of the process in the expert parallel group. """ - ep_rank = ep_group.rank() recv_primary_mask = recv_metadata.recv_primary_mask recv_count = recv_metadata.recv_count recv_expert_ids = recv_metadata.recv_expert_ids @@ -395,16 +394,17 @@ def move_from_buffer( for w, b in zip(expert_weights, expert_weights_buffers): w[dst].copy_(b[dst], non_blocking=True) - # Duplicate remote received rows to non-primary duplicate dsts if recv_count == 0: return + # Duplicate remote received rows to non-primary duplicate dsts base = ep_rank * num_local_experts local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)] duplicate_mask = np.logical_and( np.logical_and(~is_unchanged, ~is_received_locally), np.logical_and(~recv_primary_mask, local_experts != -1), ) + # All received experts are unique in the destination, so no need to copy duplicates if not bool(duplicate_mask.any()): return @@ -607,7 +607,7 @@ def rearrange_expert_weights_inplace( is_received_locally=is_received_locally, recv_metadata=recv_metadata, new_indices=new_global_expert_indices_cpu[layer_idx], - ep_group=ep_group, + ep_rank=ep_group.rank(), )