Refactor tests and address nits

Signed-off-by: ilmarkov <markovilya197@gmail.com>
Author: ilmarkov, 2025-12-15 10:46:20 +00:00
Parent: 9a41b9130e
Commit: 7d0ab7d4b9
2 changed files with 71 additions and 64 deletions


@@ -312,8 +312,8 @@ if __name__ == "__main__":
     test_basic_rebalance()


-def _make_phyrank_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor:
-    """Create phyrank from phy2log"""
+def _make_phy_replicas_idx_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor:
+    """Create replicas indices mapping from phy2log"""
     pr = torch.zeros_like(phy2log)
     for layer in range(phy2log.shape[0]):
         seen: dict[int, int] = {}
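The hunk above truncates the helper mid-body. For orientation, a minimal standalone sketch of the occurrence-counting idea it implements (the free-standing name make_phy_replicas_idx and the exact loop are assumptions for illustration, not the commit's verbatim code):

import torch

def make_phy_replicas_idx(phy2log: torch.Tensor) -> torch.Tensor:
    # Hypothetical reconstruction: walk each layer's physical slots in order
    # and give every slot the occurrence index of its logical expert, so the
    # first replica of an expert gets 0, the second gets 1, and so on.
    pr = torch.zeros_like(phy2log)
    for layer in range(phy2log.shape[0]):
        seen: dict[int, int] = {}
        for slot, expert in enumerate(phy2log[layer].tolist()):
            pr[layer, slot] = seen.get(expert, 0)
            seen[expert] = seen.get(expert, 0) + 1
    return pr

For example, phy2log = [[0, 1, 0, 2]] maps to [[0, 0, 1, 0]]: the second copy of expert 0 is replica 1.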
@@ -328,9 +328,9 @@ def _make_phyrank_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor:
 def _validate_intragpu_rearrangement(
     old_global_expert_indices: torch.Tensor,
     new_phy2log: torch.Tensor,
-    new_phyrank: torch.Tensor,
+    new_phy_replicas_idx: torch.Tensor,
     post_phy2log: torch.Tensor,
-    post_phyrank: torch.Tensor,
+    post_phy_replicas_idx: torch.Tensor,
     num_ranks: int,
     slots_per_gpu: int,
 ):
@@ -340,9 +340,9 @@ def _validate_intragpu_rearrangement(
         end = start + slots_per_gpu
         old_seg = old_global_expert_indices[0, start:end]
         new_seg = new_phy2log[0, start:end]
-        new_rnk = new_phyrank[0, start:end]
+        new_rnk = new_phy_replicas_idx[0, start:end]
         post_seg = post_phy2log[0, start:end]
-        post_rnk = post_phyrank[0, start:end]
+        post_rnk = post_phy_replicas_idx[0, start:end]

         # Pairwise equality for (expert, rank) pairs to ensure nothing is lost
         def sorted_pairs(seg: torch.Tensor, rnk: torch.Tensor):
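The hunk ends right at the nested sorted_pairs definition. Assuming it compares segments as multisets of (expert, replica index) pairs, a plausible one-line body looks like this (an illustrative sketch, not the commit's code):

def sorted_pairs(seg: torch.Tensor, rnk: torch.Tensor) -> list[tuple[int, int]]:
    # Pair each slot's expert with its replica index and sort, so two segments
    # compare equal exactly when they hold the same (expert, replica) pairs.
    return sorted(zip(seg.tolist(), rnk.tolist()))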
@@ -376,70 +376,77 @@ def _validate_intragpu_rearrangement(
     )


-def test_preserve_intragpu_slots_simple():
+@pytest.mark.parametrize(
+    "num_ranks, slots_per_gpu, old_phy2log, new_phy2log",
+    [
+        pytest.param(
+            # Setup: 2 GPUs, 4 slots each, 1 layer
+            # Old mapping: GPU0 -> [0,1,2,3], GPU1 -> [4,5,6,7]
+            # New mapping shuffles within GPU0 and brings 4,5 into GPU0.
+            # GPU0 new -> [1,5,0,4]; GPU1 new -> [6,2,7,3]
+            2,
+            4,
+            torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]),
+            torch.tensor([[1, 5, 0, 4, 6, 2, 7, 3]]),
+            id="simple",
+        ),
+        pytest.param(
+            # Setup: 2 GPUs, 5 slots each (total 10 physical experts), 1 layer
+            # Old mapping:
+            # GPU0 -> [0, 1, 0, 2, 3] (expert 0 duplicated)
+            # GPU1 -> [4, 5, 6, 1, 2]
+            # New mapping reorders within GPUs and moves some experts across
+            # GPUs, while still including duplicates:
+            # GPU0 new -> [0, 5, 4, 0, 1] (expert 0 duplicated, 4/5 incoming)
+            # GPU1 new -> [6, 2, 3, 2, 1] (expert 2 duplicated)
+            2,
+            5,
+            torch.tensor([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]]),
+            torch.tensor([[0, 5, 4, 0, 1, 6, 2, 3, 2, 1]]),
+            id="duplicates",
+        ),
+        pytest.param(
+            # Setup: 3 GPUs, 4 slots each (total 12 physical experts), 1 layer
+            # Old mapping:
+            # GPU0 -> [0, 1, 2, 3]
+            # GPU1 -> [0, 1, 2, 3]
+            # GPU2 -> [0, 1, 2, 3]
+            # New mapping routes everything to a single expert on two GPUs and
+            # shuffles experts on the third GPU:
+            # GPU0 new -> [0, 0, 0, 0]
+            # GPU1 new -> [0, 0, 0, 0]
+            # GPU2 new -> [1, 2, 3, 0]
+            3,
+            4,
+            torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]),
+            torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0]]),
+            id="skewed_expert",
+        ),
+    ],
+)
+def test_preserve_intragpu_slots(
+    num_ranks: int,
+    slots_per_gpu: int,
+    old_phy2log: torch.Tensor,
+    new_phy2log: torch.Tensor,
+):
     """Experts that stay on a GPU keep their old slots; incoming not lost."""
-    # Setup: 2 GPUs, 4 slots each, 1 layer
-    num_ranks = 2
-    slots_per_gpu = 4
-    # Old mapping: GPU0 -> [0,1,2,3], GPU1 -> [4,5,6,7]
-    old_global_expert_indices = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]])
-    # New mapping shuffles within GPU0 and brings 4,5 into GPU0.
-    # GPU0 new -> [1,5,0,4] (0 and 1 remain on GPU0 but at different slots)
-    # GPU1 new -> [6,2,7,3] (6 and 7 remain on GPU1, 2 and 3 move in)
-    phy2log = torch.tensor([[1, 5, 0, 4, 6, 2, 7, 3]])
-    # Derive phyrank from replica occurrence order per expert
-    phyrank = _make_phyrank_from_phy2log(phy2log)
+    phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(new_phy2log)
-    post_phy2log, post_phyrank = DefaultEplbPolicy.preserve_intragpu_slots(
-        phy2log, phyrank, num_ranks, old_global_expert_indices
+    post_phy2log, post_phy_replicas_idx = DefaultEplbPolicy.preserve_intragpu_slots(
+        new_phy2log, phy_replicas_idx, num_ranks, old_phy2log
     )
     # Shapes preserved
-    assert post_phy2log.shape == phy2log.shape
-    assert post_phyrank.shape == phyrank.shape
+    assert post_phy2log.shape == new_phy2log.shape
+    assert post_phy_replicas_idx.shape == phy_replicas_idx.shape
     _validate_intragpu_rearrangement(
-        old_global_expert_indices,
-        phy2log,
-        phyrank,
+        old_phy2log,
+        new_phy2log,
+        phy_replicas_idx,
         post_phy2log,
-        post_phyrank,
-        num_ranks,
-        slots_per_gpu,
-    )
-
-
-def test_preserve_intragpu_slots_with_duplicates():
-    """Test preserve intragpu slots with duplicates"""
-    # Setup: 2 GPUs, 5 slots each (total 10 physical experts), 1 layer
-    num_ranks = 2
-    slots_per_gpu = 5
-    # Old mapping:
-    # GPU0 -> [0, 1, 0, 2, 3] (expert 0 duplicated)
-    # GPU1 -> [4, 5, 6, 1, 2]
-    old_global_expert_indices = torch.tensor([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]])
-    # New mapping reorders within GPUs and moves some experts across GPUs,
-    # while still including duplicates:
-    # GPU0 new -> [0, 5, 4, 0, 1] (expert 0 duplicated, 4/5 incoming)
-    # GPU1 new -> [6, 2, 3, 1, 2] (expert 2 duplicated)
-    phy2log = torch.tensor([[0, 5, 4, 0, 1, 6, 2, 3, 1, 2]])
-    # Derive ranks so duplicates have ranks [0,1,...] by occurrence
-    phyrank = _make_phyrank_from_phy2log(phy2log)
-    post_phy2log, post_phyrank = DefaultEplbPolicy.preserve_intragpu_slots(
-        phy2log, phyrank, num_ranks, old_global_expert_indices
-    )
-    # Shapes preserved
-    assert post_phy2log.shape == phy2log.shape
-    assert post_phyrank.shape == phyrank.shape
-    _validate_intragpu_rearrangement(
-        old_global_expert_indices,
-        phy2log,
-        phyrank,
-        post_phy2log,
-        post_phyrank,
+        post_phy_replicas_idx,
         num_ranks,
         slots_per_gpu,
     )
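To make the preserved-slot invariant concrete, here is a self-contained worked check of the "simple" case (the post mapping shown is one valid outcome, assuming incoming experts fill freed slots left to right; the actual fill order is up to preserve_intragpu_slots):

import torch

old = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]])  # GPU0: slots 0-3, GPU1: slots 4-7
new = torch.tensor([[1, 5, 0, 4, 6, 2, 7, 3]])
# Experts 0 and 1 stay on GPU0 and keep local slots 0 and 1; incoming 5 and 4
# take the freed slots. Experts 6 and 7 stay on GPU1 at local slots 2 and 3;
# incoming 2 and 3 take the rest. One valid post mapping:
post = torch.tensor([[0, 1, 5, 4, 2, 3, 6, 7]])
for start, end in [(0, 4), (4, 8)]:
    # The same multiset of experts sits on each GPU before and after.
    assert sorted(post[0, start:end].tolist()) == sorted(new[0, start:end].tolist())
    # Every expert that stays on its GPU keeps its old slot.
    for slot in range(start, end):
        if old[0, slot].item() in new[0, start:end].tolist():
            assert post[0, slot] == old[0, slot]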


@@ -271,8 +271,8 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
             preserved_positions[layer_indices, slot_idx] = True

         # Second pass: fill remaining slots with remaining new experts
-        remaining_mask = ~used_new_indices  # [L, S]
-        fill_mask = ~preserved_positions  # [L, S]
+        remaining_mask = ~used_new_indices  # [layers, slots]
+        fill_mask = ~preserved_positions  # [layers, slots]
         if remaining_mask.any() and fill_mask.any():
             idx_base = np.tile(np.arange(slots_per_gpu), (num_layers, 1))
             # Sentinel value for unavailable positions.
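The "sentinel value" comment in the hunk refers to a common vectorized pattern for listing each layer's free slots first. A standalone sketch of that pattern under assumed shapes and names (not the repo's exact code):

import numpy as np

num_layers, slots_per_gpu = 2, 4
# Suppose the first pass preserved these positions:
preserved_positions = np.array([[True, False, True, False],
                                [False, False, True, True]])
fill_mask = ~preserved_positions
idx_base = np.tile(np.arange(slots_per_gpu), (num_layers, 1))
# Free slots keep their own index; occupied slots get a sentinel larger than
# any real slot index, so argsort lists each layer's free slots first.
SENTINEL = slots_per_gpu
slot_order = np.where(fill_mask, idx_base, SENTINEL).argsort(axis=1)
free_counts = fill_mask.sum(axis=1)
# slot_order[layer, :free_counts[layer]] are that layer's free slots in
# ascending order, ready to receive the remaining new experts.
print(slot_order)   # [[1 3 0 2] [0 1 2 3]]
print(free_counts)  # [2 2]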