Address review comments

Signed-off-by: ilmarkov <markovilya197@gmail.com>
ilmarkov 2025-12-12 14:45:00 +00:00
parent a5ecdc18c0
commit 040ae89c5e
2 changed files with 51 additions and 55 deletions


@@ -93,7 +93,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         Returns:
             phy2log: [X, num_phy], logical expert id of each physical expert
-            rank: [X, num_phy], the replica rank
+            replica_idx: [X, num_phy], the index of the replica for each logical expert
             logcnt: [X, num_log], number of replicas for each logical expert
         """
         n, num_log = weight.shape
@@ -101,15 +101,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         assert num_redundant >= 0
         device = weight.device
         phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
-        rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
+        replica_idx = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
         logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
         arangen = torch.arange(n, dtype=torch.int64, device=device)
         for i in range(num_log, num_phy):
             redundant_indices = (weight / logcnt).max(dim=-1).indices
             phy2log[:, i] = redundant_indices
-            rank[:, i] = logcnt[arangen, redundant_indices]
+            replica_idx[:, i] = logcnt[arangen, redundant_indices]
             logcnt[arangen, redundant_indices] += 1
-        return phy2log, rank, logcnt
+        return phy2log, replica_idx, logcnt
 
     @classmethod
     def rebalance_experts_hierarchical(
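
The loop above is the entire replication policy: each extra physical slot is handed to the logical expert whose per-replica load (weight / logcnt) is currently highest. Below is a minimal, self-contained sketch of that greedy step with toy numbers; the function wrapper and the example values are illustrative only, not part of the diff.

import torch

def replicate_experts_sketch(weight: torch.Tensor, num_phy: int):
    n, num_log = weight.shape
    phy2log = torch.arange(num_phy, dtype=torch.int64).repeat(n, 1)
    replica_idx = torch.zeros(n, num_phy, dtype=torch.int64)
    logcnt = torch.ones(n, num_log, dtype=torch.int64)
    arangen = torch.arange(n, dtype=torch.int64)
    for i in range(num_log, num_phy):
        # Per layer, pick the logical expert with the largest per-replica load.
        redundant_indices = (weight / logcnt).max(dim=-1).indices
        phy2log[:, i] = redundant_indices
        # The new replica's index equals the expert's current replica count.
        replica_idx[:, i] = logcnt[arangen, redundant_indices]
        logcnt[arangen, redundant_indices] += 1
    return phy2log, replica_idx, logcnt

# One layer, three logical experts with loads [9, 3, 1], five physical slots:
phy2log, replica_idx, logcnt = replicate_experts_sketch(
    torch.tensor([[9.0, 3.0, 1.0]]), num_phy=5
)
print(phy2log)      # tensor([[0, 1, 2, 0, 0]]) -- expert 0 wins both extra slots
print(replica_idx)  # tensor([[0, 0, 0, 1, 2]])
print(logcnt)       # tensor([[3, 1, 1]])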
@@ -132,7 +132,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         Returns:
             phy2log: [layers, num_replicas], the expert
                 index of each replica
-            log2phy: [layers, num_logical_experts, X],
+            pphy_replicas_idx: [layers, num_replicas],
                 the replica indices for each expert
             logcnt: [layers, num_logical_experts], number of
                 physical replicas for each logical expert
@@ -177,7 +177,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         tokens_per_mlog = weight.gather(-1, mlog2log).view(
             -1, num_logical_experts // num_nodes
         )
-        phy2mlog, phyrank, mlogcnt = cls.replicate_experts(
+        phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts(
             tokens_per_mlog, num_physical_experts // num_nodes
         )
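
The gather + view pair above is the core of the hierarchical policy: it permutes each layer's per-expert loads into node order, then folds the node dimension into the batch dimension so replicate_experts balances every node independently. A toy illustration, where the two-node permutation mlog2log is a made-up value for the example:

import torch

num_nodes, num_log = 2, 4
weight = torch.tensor([[10.0, 1.0, 7.0, 2.0]])  # [num_layers = 1, num_log]
# Hypothetical node-major permutation: node 0 holds experts (0, 2), node 1 holds (1, 3).
mlog2log = torch.tensor([[0, 2, 1, 3]])
tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, num_log // num_nodes)
print(tokens_per_mlog)  # tensor([[10., 7.], [1., 2.]]) -- one row per (layer, node) pair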
@@ -203,15 +203,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
             ).view(1, -1, 1)
         ).flatten(-2)
         pphy2log = mlog2log.gather(-1, pphy2mlog)
-        pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
+        pphy_replicas_idx = replicas_idx.gather(-1, pphy2phy).view(num_layers, -1)
         logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
-        return pphy2log, pphyrank, logcnt
+        return pphy2log, pphy_replicas_idx, logcnt
 
     @classmethod
     def preserve_intragpu_slots(
         cls,
         phy2log: torch.Tensor,
-        phyrank: torch.Tensor,
+        phy_replicas_idx: torch.Tensor,
         num_ranks: int,
         old_global_expert_indices: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -223,56 +223,52 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         the old and new mappings.
         """
         device = phy2log.device
-        new_num_phy = phy2log.shape[1]
-        old_num_phy = old_global_expert_indices.shape[1]
-        if (
-            num_ranks <= 0
-            or new_num_phy % num_ranks != 0
-            or old_num_phy % num_ranks != 0
-            or (new_num_phy // num_ranks) != (old_num_phy // num_ranks)
-        ):
-            return phy2log, phyrank
+        num_phy_experts = phy2log.shape[1]
+        if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
+            return phy2log, phy_replicas_idx
 
         # Move to CPU and convert to NumPy for processing
-        phy2log_np = phy2log.cpu().numpy()
-        phyrank_np = phyrank.cpu().numpy()
-        old_np = old_global_expert_indices.cpu().numpy()
+        new_phy2log_np = phy2log.cpu().numpy()
+        replicas_idx_np = phy_replicas_idx.cpu().numpy()
+        old_phy2log_np = old_global_expert_indices.cpu().numpy()
 
-        slots_per_gpu = new_num_phy // num_ranks
-        num_layers = phy2log_np.shape[0]
+        slots_per_gpu = num_phy_experts // num_ranks
+        num_layers = new_phy2log_np.shape[0]
 
-        post_phy2log_np = phy2log_np.copy()
-        post_phyrank_np = phyrank_np.copy()
+        post_phy2log_np = new_phy2log_np.copy()
+        post_phy_replicas_idx_np = replicas_idx_np.copy()
 
         for gpu_idx in range(num_ranks):
             start = gpu_idx * slots_per_gpu
             end = start + slots_per_gpu
 
-            # Segments across all layers for this GPU
-            old_seg = old_np[:, start:end]  # [L, S]
-            new_seg = phy2log_np[:, start:end]  # [L, S]
-            new_rnk = phyrank_np[:, start:end]  # [L, S]
+            # Experts across all layers for this GPU
+            old_local = old_phy2log_np[:, start:end]  # [layers, slots]
+            new_local = new_phy2log_np[:, start:end]  # [layers, slots]
+            new_ridx = replicas_idx_np[:, start:end]  # [layers, slots]
 
             used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
             preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)
 
             # First pass: preserve same-logical experts in their previous slots
-            for pos in range(slots_per_gpu):
-                # matches: [L, S], True where new_seg has the same logical value
-                # as the old slot 'pos' and not used
-                matches = (new_seg == old_seg[:, pos][:, None]) & (~used_new_indices)
+            for slot_idx in range(slots_per_gpu):
+                # matches: [layers, slots], True where a new local expert has the
+                # same logical id as the old expert in 'slot_idx' and is not yet used
+                matches = (new_local == old_local[:, slot_idx][:, None]) & (
+                    ~used_new_indices
+                )
                 has_any = matches.any(axis=1)
                 if np.any(has_any):
                     first_idx = np.argmax(matches, axis=1)
                     layer_indices = np.nonzero(has_any)[0]
                     matched_new_positions = first_idx[layer_indices]
-                    post_phy2log_np[layer_indices, start + pos] = new_seg[
-                        layer_indices, matched_new_positions
-                    ]
-                    post_phyrank_np[layer_indices, start + pos] = new_rnk[
+                    post_phy2log_np[layer_indices, start + slot_idx] = new_local[
                         layer_indices, matched_new_positions
                     ]
+                    post_phy_replicas_idx_np[layer_indices, start + slot_idx] = (
+                        new_ridx[layer_indices, matched_new_positions]
+                    )
                     used_new_indices[layer_indices, matched_new_positions] = True
-                    preserved_positions[layer_indices, pos] = True
+                    preserved_positions[layer_indices, slot_idx] = True
 
             # Second pass: fill remaining slots with remaining new experts
             remaining_mask = ~used_new_indices  # [L, S]
@@ -299,17 +295,17 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
                     continue
                 src_pos = remaining_indices[layer_idx, :k]
                 dst_pos = fill_indices[layer_idx, :k]
-                post_phy2log_np[layer_idx, start + dst_pos] = new_seg[
+                post_phy2log_np[layer_idx, start + dst_pos] = new_local[
                     layer_idx, src_pos
                 ]
-                post_phyrank_np[layer_idx, start + dst_pos] = new_rnk[
+                post_phy_replicas_idx_np[layer_idx, start + dst_pos] = new_ridx[
                     layer_idx, src_pos
                 ]
 
         # Convert back to torch and move to original device
         post_phy2log = torch.from_numpy(post_phy2log_np).to(device)
-        post_phyrank = torch.from_numpy(post_phyrank_np).to(device)
-        return post_phy2log, post_phyrank
+        post_phy_replicas_idx = torch.from_numpy(post_phy_replicas_idx_np).to(device)
+        return post_phy2log, post_phy_replicas_idx
 
     @classmethod
     def rebalance_experts(
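
The two passes of preserve_intragpu_slots are easier to see on a single GPU's slot range: pass one keeps every logical expert that appears in both the old and the new mapping in its old slot, and pass two fills the slots that are still empty with the new experts that are still unplaced. A toy NumPy sketch of that idea for one layer and one GPU (not the vectorized multi-layer implementation above):

import numpy as np

old_local = np.array([3, 1, 7])   # experts this GPU held before the rebalance
new_local = np.array([7, 9, 3])   # experts assigned to it by the fresh rebalance
post = np.empty_like(new_local)
used = np.zeros(len(new_local), dtype=bool)
filled = np.zeros(len(new_local), dtype=bool)
# First pass: experts that survive the rebalance stay in their old slots.
for slot, expert in enumerate(old_local):
    hits = np.nonzero((new_local == expert) & ~used)[0]
    if hits.size:
        post[slot] = expert
        used[hits[0]] = True
        filled[slot] = True
# Second pass: place the remaining new experts into the still-empty slots.
post[~filled] = new_local[~used]
print(post)  # [3 9 7] -- experts 3 and 7 keep their slots, so their weights need no copy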
@@ -348,12 +344,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         weight = weight.float()
         if num_groups % num_nodes == 0:
             # use hierarchical load-balance policy
-            phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
+            phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical(
                 weight, num_replicas, num_groups, num_nodes, num_ranks
             )
         else:
             # use global load-balance policy
-            phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
+            phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical(
                 weight, num_replicas, 1, 1, num_ranks
             )
         # Optional postprocessing to preserve slots for experts moving
@@ -362,8 +358,8 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         # Helps to avoid unnecessary weight copying when experts move
         # within the same GPU.
         if old_global_expert_indices is not None:
-            phy2log, phyrank = cls.preserve_intragpu_slots(
-                phy2log, phyrank, num_ranks, old_global_expert_indices
+            phy2log, phy_replicas_idx = cls.preserve_intragpu_slots(
+                phy2log, phy_replicas_idx, num_ranks, old_global_expert_indices
             )
         num_redundant_experts = num_replicas - num_logical_experts
         maxlogcnt = num_redundant_experts + 1
@@ -375,7 +371,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         )
         log2phy.view(num_layers, -1).scatter_(
             -1,
-            phy2log * maxlogcnt + phyrank,
+            phy2log * maxlogcnt + phy_replicas_idx,
             torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
                 num_layers, -1
             ),
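
The scatter above inverts phy2log: flattening log2phy to [num_layers, num_log * maxlogcnt] lets phy2log * maxlogcnt + phy_replicas_idx address the (logical expert, replica) cell that each physical slot fills, which is why phy_replicas_idx must be dense per expert. A worked toy example, reusing the numbers from the replication sketch earlier:

import torch

num_layers, num_log, num_replicas = 1, 3, 5
maxlogcnt = num_replicas - num_log + 1  # 3: one original plus up to two redundant replicas
phy2log = torch.tensor([[0, 1, 2, 0, 0]])
phy_replicas_idx = torch.tensor([[0, 0, 0, 1, 2]])
log2phy = torch.full((num_layers, num_log, maxlogcnt), -1, dtype=torch.int64)
log2phy.view(num_layers, -1).scatter_(
    -1,
    phy2log * maxlogcnt + phy_replicas_idx,
    torch.arange(num_replicas, dtype=torch.int64).expand(num_layers, -1),
)
print(log2phy)
# tensor([[[ 0,  3,  4],    <- logical expert 0 lives in physical slots 0, 3 and 4
#          [ 1, -1, -1],    <- expert 1 has its single replica in slot 1
#          [ 2, -1, -1]]])  <- expert 2 has its single replica in slot 2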


@@ -361,24 +361,23 @@ def move_from_buffer(
     is_received_locally: np.ndarray,
     recv_metadata: RecvMetadata,
     new_indices: np.ndarray,
-    ep_group: ProcessGroup,
+    ep_rank: int,
 ) -> None:
     """
     Copies expert weights from communication buffers back to the target weight tensors
     after EPLB rebalancing.
 
     Args:
-        expert_weights: List of weight tensors for a MoE layer.
-        expert_weights_buffers: Intermediate buffers matching ``expert_weights``.
+        expert_weights: List of the actual MoE layer weight tensors used during execution.
+        expert_weights_buffers: Intermediate buffers containing the expert weights
+            after the transfer is completed.
         is_unchanged: (num_local_experts,), True where an expert row is unchanged.
         is_received_locally: (num_local_experts,), True where a row is updated locally.
         recv_metadata: RecvMetadata containing remote receive metadata.
         new_indices: (num_experts_total,) mapping from local rows to desired
             (possibly global) expert id, after rebalance.
-        ep_group: torch.distributed.ProcessGroup for expert parallel communication
-            domain.
+        ep_rank: Rank of the process in the expert parallel group.
     """
-    ep_rank = ep_group.rank()
     recv_primary_mask = recv_metadata.recv_primary_mask
     recv_count = recv_metadata.recv_count
     recv_expert_ids = recv_metadata.recv_expert_ids
@@ -395,16 +394,17 @@ def move_from_buffer(
     for w, b in zip(expert_weights, expert_weights_buffers):
         w[dst].copy_(b[dst], non_blocking=True)
 
-    # Duplicate remote received rows to non-primary duplicate dsts
     if recv_count == 0:
         return
 
+    # Duplicate remote received rows to non-primary duplicate dsts
     base = ep_rank * num_local_experts
     local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)]
     duplicate_mask = np.logical_and(
         np.logical_and(~is_unchanged, ~is_received_locally),
         np.logical_and(~recv_primary_mask, local_experts != -1),
     )
+    # All received experts are unique in the destination, so no need to copy duplicates
     if not bool(duplicate_mask.any()):
         return
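
The mask above selects exactly the local rows that must be filled by duplicating an already-received primary copy: the row changed, was not sourced locally, is not itself the primary destination of a remote receive, and maps to a real expert. A small NumPy check with made-up inputs:

import numpy as np

is_unchanged        = np.array([ True, False, False, False])
is_received_locally = np.array([False,  True, False, False])
recv_primary_mask   = np.array([False, False,  True, False])
local_experts       = np.array([    5,     2,     7,     7])  # -1 would mean an empty slot
duplicate_mask = np.logical_and(
    np.logical_and(~is_unchanged, ~is_received_locally),
    np.logical_and(~recv_primary_mask, local_experts != -1),
)
print(duplicate_mask)  # [False False False  True] -- only the last row copies expert 7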
@@ -607,7 +607,7 @@ def rearrange_expert_weights_inplace(
             is_received_locally=is_received_locally,
             recv_metadata=recv_metadata,
             new_indices=new_global_expert_indices_cpu[layer_idx],
-            ep_group=ep_group,
+            ep_rank=ep_group.rank(),
         )