Address review comments

Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
ilmarkov 2025-12-12 14:45:00 +00:00
parent a5ecdc18c0
commit 040ae89c5e
2 changed files with 51 additions and 55 deletions

View File

@ -93,7 +93,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
Returns: Returns:
phy2log: [X, num_phy], logical expert id of each physical expert phy2log: [X, num_phy], logical expert id of each physical expert
rank: [X, num_phy], the replica rank replica_idx: [X, num_phy], the index of the replica for each logical expert
logcnt: [X, num_log], number of replicas for each logical expert logcnt: [X, num_log], number of replicas for each logical expert
""" """
n, num_log = weight.shape n, num_log = weight.shape
@ -101,15 +101,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
assert num_redundant >= 0 assert num_redundant >= 0
device = weight.device device = weight.device
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1) phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device) replica_idx = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device) logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
arangen = torch.arange(n, dtype=torch.int64, device=device) arangen = torch.arange(n, dtype=torch.int64, device=device)
for i in range(num_log, num_phy): for i in range(num_log, num_phy):
redundant_indices = (weight / logcnt).max(dim=-1).indices redundant_indices = (weight / logcnt).max(dim=-1).indices
phy2log[:, i] = redundant_indices phy2log[:, i] = redundant_indices
rank[:, i] = logcnt[arangen, redundant_indices] replica_idx[:, i] = logcnt[arangen, redundant_indices]
logcnt[arangen, redundant_indices] += 1 logcnt[arangen, redundant_indices] += 1
return phy2log, rank, logcnt return phy2log, replica_idx, logcnt
@classmethod @classmethod
def rebalance_experts_hierarchical( def rebalance_experts_hierarchical(
@ -132,7 +132,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
Returns: Returns:
phy2log: [layers, num_replicas], the expert phy2log: [layers, num_replicas], the expert
index of each replica index of each replica
log2phy: [layers, num_logical_experts, X], pphy_replicas_idx: [layers, num_logical_experts, X],
the replica indices for each expert the replica indices for each expert
logcnt: [layers, num_logical_experts], number of logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert physical replicas for each logical expert
@ -177,7 +177,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
tokens_per_mlog = weight.gather(-1, mlog2log).view( tokens_per_mlog = weight.gather(-1, mlog2log).view(
-1, num_logical_experts // num_nodes -1, num_logical_experts // num_nodes
) )
phy2mlog, phyrank, mlogcnt = cls.replicate_experts( phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts(
tokens_per_mlog, num_physical_experts // num_nodes tokens_per_mlog, num_physical_experts // num_nodes
) )
@ -203,15 +203,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
).view(1, -1, 1) ).view(1, -1, 1)
).flatten(-2) ).flatten(-2)
pphy2log = mlog2log.gather(-1, pphy2mlog) pphy2log = mlog2log.gather(-1, pphy2mlog)
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1) pphy_replicas_idx = replicas_idx.gather(-1, pphy2phy).view(num_layers, -1)
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog) logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
return pphy2log, pphyrank, logcnt return pphy2log, pphy_replicas_idx, logcnt
@classmethod @classmethod
def preserve_intragpu_slots( def preserve_intragpu_slots(
cls, cls,
phy2log: torch.Tensor, phy2log: torch.Tensor,
phyrank: torch.Tensor, phy_replicas_idx: torch.Tensor,
num_ranks: int, num_ranks: int,
old_global_expert_indices: torch.Tensor, old_global_expert_indices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
@ -223,56 +223,52 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
the old and new mappings. the old and new mappings.
""" """
device = phy2log.device device = phy2log.device
new_num_phy = phy2log.shape[1] num_phy_experts = phy2log.shape[1]
old_num_phy = old_global_expert_indices.shape[1] if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
if ( return phy2log, phy_replicas_idx
num_ranks <= 0
or new_num_phy % num_ranks != 0
or old_num_phy % num_ranks != 0
or (new_num_phy // num_ranks) != (old_num_phy // num_ranks)
):
return phy2log, phyrank
# Move to CPU and convert to NumPy for processing # Move to CPU and convert to NumPy for processing
phy2log_np = phy2log.cpu().numpy() new_phy2log_np = phy2log.cpu().numpy()
phyrank_np = phyrank.cpu().numpy() replicas_idx_np = phy_replicas_idx.cpu().numpy()
old_np = old_global_expert_indices.cpu().numpy() old_phy2log_np = old_global_expert_indices.cpu().numpy()
slots_per_gpu = new_num_phy // num_ranks slots_per_gpu = num_phy_experts // num_ranks
num_layers = phy2log_np.shape[0] num_layers = new_phy2log_np.shape[0]
post_phy2log_np = phy2log_np.copy() post_phy2log_np = new_phy2log_np.copy()
post_phyrank_np = phyrank_np.copy() post_phy_replicas_idx_np = replicas_idx_np.copy()
for gpu_idx in range(num_ranks): for gpu_idx in range(num_ranks):
start = gpu_idx * slots_per_gpu start = gpu_idx * slots_per_gpu
end = start + slots_per_gpu end = start + slots_per_gpu
# Segments across all layers for this GPU # Experts across all layers for this GPU
old_seg = old_np[:, start:end] # [L, S] old_local = old_phy2log_np[:, start:end] # [layers, slots]
new_seg = phy2log_np[:, start:end] # [L, S] new_local = new_phy2log_np[:, start:end] # [layers, slots]
new_rnk = phyrank_np[:, start:end] # [L, S] new_ridx = replicas_idx_np[:, start:end] # [layers, slots]
used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool) used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool) preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)
# First pass: preserve same-logical experts in their previous slots # First pass: preserve same-logical experts in their previous slots
for pos in range(slots_per_gpu): for slot_idx in range(slots_per_gpu):
# matches: [L, S], True where new_seg has the same logical value # matches: [layers, slots], True where new local experts have
# as the old slot 'pos' and not used # the same logical value as the old from 'slot_idx' and not checked yet
matches = (new_seg == old_seg[:, pos][:, None]) & (~used_new_indices) matches = (new_local == old_local[:, slot_idx][:, None]) & (
~used_new_indices
)
has_any = matches.any(axis=1) has_any = matches.any(axis=1)
if np.any(has_any): if np.any(has_any):
first_idx = np.argmax(matches, axis=1) first_idx = np.argmax(matches, axis=1)
layer_indices = np.nonzero(has_any)[0] layer_indices = np.nonzero(has_any)[0]
matched_new_positions = first_idx[layer_indices] matched_new_positions = first_idx[layer_indices]
post_phy2log_np[layer_indices, start + pos] = new_seg[ post_phy2log_np[layer_indices, start + slot_idx] = new_local[
layer_indices, matched_new_positions
]
post_phyrank_np[layer_indices, start + pos] = new_rnk[
layer_indices, matched_new_positions layer_indices, matched_new_positions
] ]
post_phy_replicas_idx_np[layer_indices, start + slot_idx] = (
new_ridx[layer_indices, matched_new_positions]
)
used_new_indices[layer_indices, matched_new_positions] = True used_new_indices[layer_indices, matched_new_positions] = True
preserved_positions[layer_indices, pos] = True preserved_positions[layer_indices, slot_idx] = True
# Second pass: fill remaining slots with remaining new experts # Second pass: fill remaining slots with remaining new experts
remaining_mask = ~used_new_indices # [L, S] remaining_mask = ~used_new_indices # [L, S]
@ -299,17 +295,17 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
continue continue
src_pos = remaining_indices[layer_idx, :k] src_pos = remaining_indices[layer_idx, :k]
dst_pos = fill_indices[layer_idx, :k] dst_pos = fill_indices[layer_idx, :k]
post_phy2log_np[layer_idx, start + dst_pos] = new_seg[ post_phy2log_np[layer_idx, start + dst_pos] = new_local[
layer_idx, src_pos layer_idx, src_pos
] ]
post_phyrank_np[layer_idx, start + dst_pos] = new_rnk[ post_phy_replicas_idx_np[layer_idx, start + dst_pos] = new_ridx[
layer_idx, src_pos layer_idx, src_pos
] ]
# Convert back to torch and move to original device # Convert back to torch and move to original device
post_phy2log = torch.from_numpy(post_phy2log_np).to(device) post_phy2log = torch.from_numpy(post_phy2log_np).to(device)
post_phyrank = torch.from_numpy(post_phyrank_np).to(device) post_phy_replicas_idx = torch.from_numpy(post_phy_replicas_idx_np).to(device)
return post_phy2log, post_phyrank return post_phy2log, post_phy_replicas_idx
@classmethod @classmethod
def rebalance_experts( def rebalance_experts(
@ -348,12 +344,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
weight = weight.float() weight = weight.float()
if num_groups % num_nodes == 0: if num_groups % num_nodes == 0:
# use hierarchical load-balance policy # use hierarchical load-balance policy
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical( phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical(
weight, num_replicas, num_groups, num_nodes, num_ranks weight, num_replicas, num_groups, num_nodes, num_ranks
) )
else: else:
# use global load-balance policy # use global load-balance policy
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical( phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical(
weight, num_replicas, 1, 1, num_ranks weight, num_replicas, 1, 1, num_ranks
) )
# Optional postprocessing to preserve slots for experts moving # Optional postprocessing to preserve slots for experts moving
@ -362,8 +358,8 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
# Helps to avoid unnecessary weight copying when experts move # Helps to avoid unnecessary weight copying when experts move
# within the same GPU. # within the same GPU.
if old_global_expert_indices is not None: if old_global_expert_indices is not None:
phy2log, phyrank = cls.preserve_intragpu_slots( phy2log, phy_replicas_idx = cls.preserve_intragpu_slots(
phy2log, phyrank, num_ranks, old_global_expert_indices phy2log, phy_replicas_idx, num_ranks, old_global_expert_indices
) )
num_redundant_experts = num_replicas - num_logical_experts num_redundant_experts = num_replicas - num_logical_experts
maxlogcnt = num_redundant_experts + 1 maxlogcnt = num_redundant_experts + 1
@ -375,7 +371,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
) )
log2phy.view(num_layers, -1).scatter_( log2phy.view(num_layers, -1).scatter_(
-1, -1,
phy2log * maxlogcnt + phyrank, phy2log * maxlogcnt + phy_replicas_idx,
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand( torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
num_layers, -1 num_layers, -1
), ),

View File

@ -361,24 +361,23 @@ def move_from_buffer(
is_received_locally: np.ndarray, is_received_locally: np.ndarray,
recv_metadata: RecvMetadata, recv_metadata: RecvMetadata,
new_indices: np.ndarray, new_indices: np.ndarray,
ep_group: ProcessGroup, ep_rank: int,
) -> None: ) -> None:
""" """
Copies expert weights from communication buffers back to the target weight tensors Copies expert weights from communication buffers back to the target weight tensors
after EPLB rebalancing. after EPLB rebalancing.
Args: Args:
expert_weights: List of weight tensors for a MoE layer. expert_weights: List of the actual MoE layer weights used in the execution.
expert_weights_buffers: Intermediate buffers matching ``expert_weights``. expert_weights_buffers: Intermediate buffers containing the experts weights
after the transfer is completed.
is_unchanged: (num_local_experts,), True where an expert row is unchanged. is_unchanged: (num_local_experts,), True where an expert row is unchanged.
is_received_locally: (num_local_experts,), True where a row is updated locally. is_received_locally: (num_local_experts,), True where a row is updated locally.
recv_metadata: RecvMetadata containing remote receive metadata. recv_metadata: RecvMetadata containing remote receive metadata.
new_indices: (num_experts_total,) mapping from local rows to desired new_indices: (num_experts_total,) mapping from local rows to desired
(possibly global) expert id, after rebalance. (possibly global) expert id, after rebalance.
ep_group: torch.distributed.ProcessGroup for expert parallel communication ep_rank: Rank of the process in the expert parallel group.
domain.
""" """
ep_rank = ep_group.rank()
recv_primary_mask = recv_metadata.recv_primary_mask recv_primary_mask = recv_metadata.recv_primary_mask
recv_count = recv_metadata.recv_count recv_count = recv_metadata.recv_count
recv_expert_ids = recv_metadata.recv_expert_ids recv_expert_ids = recv_metadata.recv_expert_ids
@ -395,16 +394,17 @@ def move_from_buffer(
for w, b in zip(expert_weights, expert_weights_buffers): for w, b in zip(expert_weights, expert_weights_buffers):
w[dst].copy_(b[dst], non_blocking=True) w[dst].copy_(b[dst], non_blocking=True)
# Duplicate remote received rows to non-primary duplicate dsts
if recv_count == 0: if recv_count == 0:
return return
# Duplicate remote received rows to non-primary duplicate dsts
base = ep_rank * num_local_experts base = ep_rank * num_local_experts
local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)] local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)]
duplicate_mask = np.logical_and( duplicate_mask = np.logical_and(
np.logical_and(~is_unchanged, ~is_received_locally), np.logical_and(~is_unchanged, ~is_received_locally),
np.logical_and(~recv_primary_mask, local_experts != -1), np.logical_and(~recv_primary_mask, local_experts != -1),
) )
# All received experts are unique in the destination, so no need to copy duplicates
if not bool(duplicate_mask.any()): if not bool(duplicate_mask.any()):
return return
@ -607,7 +607,7 @@ def rearrange_expert_weights_inplace(
is_received_locally=is_received_locally, is_received_locally=is_received_locally,
recv_metadata=recv_metadata, recv_metadata=recv_metadata,
new_indices=new_global_expert_indices_cpu[layer_idx], new_indices=new_global_expert_indices_cpu[layer_idx],
ep_group=ep_group, ep_rank=ep_group.rank(),
) )