Add comments and rename index variables in DefaultEplbPolicy

Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
ilmarkov 2025-12-11 12:56:54 +00:00
parent 208b51bcd8
commit a5ecdc18c0

View File

@ -263,31 +263,36 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
has_any = matches.any(axis=1)
if np.any(has_any):
first_idx = np.argmax(matches, axis=1)
rows = np.nonzero(has_any)[0]
cols = first_idx[rows]
post_phy2log_np[rows, start + pos] = new_seg[rows, cols]
post_phyrank_np[rows, start + pos] = new_rnk[rows, cols]
used_new_indices[rows, cols] = True
preserved_positions[rows, pos] = True
layer_indices = np.nonzero(has_any)[0]
matched_new_positions = first_idx[layer_indices]
post_phy2log_np[layer_indices, start + pos] = new_seg[
layer_indices, matched_new_positions
]
post_phyrank_np[layer_indices, start + pos] = new_rnk[
layer_indices, matched_new_positions
]
used_new_indices[layer_indices, matched_new_positions] = True
preserved_positions[layer_indices, pos] = True
# Second pass: fill remaining slots with remaining new experts
remaining_mask = ~used_new_indices # [L, S]
fill_mask = ~preserved_positions # [L, S]
if remaining_mask.any() and fill_mask.any():
idx_base = np.broadcast_to(
np.arange(slots_per_gpu), (num_layers, slots_per_gpu)
)
idx_base = np.tile(np.arange(slots_per_gpu), (num_layers, 1))
# Sentinel value for unavailable positions.
large = slots_per_gpu + 1
# Priorities: keep original index for available spots, set sentinel
# for unavailable; lower is earlier.
remaining_priority = np.where(remaining_mask, idx_base, large)
fill_priority = np.where(fill_mask, idx_base, large)
# Sort to get per-row ordered indices of True positions
# Sort to get ordered indices of available src/dst positions per layer.
remaining_indices = np.argsort(remaining_priority, axis=1)
fill_indices = np.argsort(fill_priority, axis=1)
# How many to fill per row
# Fill count per layer (cannot exceed either side).
remaining_counts = remaining_mask.sum(axis=1)
fill_counts = fill_mask.sum(axis=1)
take_counts = np.minimum(remaining_counts, fill_counts)
# Assign per row
# Assign remaining new experts to remaining slots per layer.
for layer_idx in range(num_layers):
k = int(take_counts[layer_idx])
if k <= 0: