Convert eplb rebalance to numpy

Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
ilmarkov 2025-12-11 16:41:33 +00:00
parent 11c492ae81
commit 389f86e0c5
2 changed files with 113 additions and 115 deletions

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import pytest
import torch
@ -312,9 +313,9 @@ if __name__ == "__main__":
test_basic_rebalance()
def _make_phy_replicas_idx_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor:
"""Create replicas indices mapping from phy2log"""
pr = torch.zeros_like(phy2log)
def _make_phy_replicas_idx_from_phy2log(phy2log: np.ndarray) -> np.ndarray:
"""Create replicas indices mapping from phy2log."""
pr = np.zeros_like(phy2log, dtype=np.int64)
for layer in range(phy2log.shape[0]):
seen: dict[int, int] = {}
row = phy2log[layer].tolist()
@ -326,11 +327,11 @@ def _make_phy_replicas_idx_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor:
def _validate_intragpu_rearrangement(
old_global_expert_indices: torch.Tensor,
new_phy2log: torch.Tensor,
new_phy_replicas_idx: torch.Tensor,
post_phy2log: torch.Tensor,
post_phy_replicas_idx: torch.Tensor,
old_global_expert_indices: np.ndarray,
new_phy2log: np.ndarray,
new_phy_replicas_idx: np.ndarray,
post_phy2log: np.ndarray,
post_phy_replicas_idx: np.ndarray,
num_ranks: int,
slots_per_gpu: int,
):
@ -345,7 +346,7 @@ def _validate_intragpu_rearrangement(
post_rnk = post_phy_replicas_idx[0, start:end]
# Pairwise equality for (expert, rank) pairs to ensure nothing is lost
def sorted_pairs(seg: torch.Tensor, rnk: torch.Tensor):
def sorted_pairs(seg, rnk):
pairs = list(zip(seg.tolist(), rnk.tolist()))
pairs.sort()
return pairs
@ -386,8 +387,8 @@ def _validate_intragpu_rearrangement(
# GPU0 new -> [1,5,0,4]; GPU1 new -> [6,2,7,3]
2,
4,
torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]),
torch.tensor([[1, 5, 0, 4, 6, 2, 7, 3]]),
np.array([[0, 1, 2, 3, 4, 5, 6, 7]]),
np.array([[1, 5, 0, 4, 6, 2, 7, 3]]),
id="simple",
),
pytest.param(
@ -401,8 +402,8 @@ def _validate_intragpu_rearrangement(
# GPU1 new -> [6, 2, 3, 2, 1] (expert 2 duplicated)
2,
5,
torch.tensor([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]]),
torch.tensor([[0, 5, 4, 0, 1, 6, 2, 3, 2, 1]]),
np.array([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]]),
np.array([[0, 5, 4, 0, 1, 6, 2, 3, 2, 1]]),
id="duplicates",
),
pytest.param(
@ -418,8 +419,8 @@ def _validate_intragpu_rearrangement(
# GPU2 new -> [1, 2, 3, 0]
3,
4,
torch.tensor([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]),
torch.tensor([[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0]]),
np.array([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]),
np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0]]),
id="skewed_expert",
),
],

View File

@ -21,8 +21,8 @@ from .abstract import AbstractEplbPolicy
class DefaultEplbPolicy(AbstractEplbPolicy):
@classmethod
def balanced_packing(
cls, weight: torch.Tensor, num_packs: int
) -> tuple[torch.Tensor, torch.Tensor]:
cls, weight: np.ndarray, num_packs: int
) -> tuple[np.ndarray, np.ndarray]:
"""
Pack n weighted objects to m packs, such that each bin contains exactly
n/m objects and the weights of all packs are as balanced as possible.
@ -39,19 +39,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
assert num_groups % num_packs == 0
groups_per_pack = num_groups // num_packs
device = weight.device
if groups_per_pack == 1:
pack_index = torch.arange(
weight.size(-1), dtype=torch.int64, device=device
).expand(weight.shape)
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
return pack_index, rank_in_pack
weight_np = weight.cpu().numpy()
pack_index_np = np.tile(
np.arange(num_groups, dtype=np.int64), (num_layers, 1)
)
rank_in_pack_np = np.zeros_like(pack_index_np, dtype=np.int64)
return pack_index_np, rank_in_pack_np
# Sort and get indices in descending order
indices_np = np.argsort(-weight_np, axis=-1)
indices = np.argsort(-weight, axis=-1)
pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
@ -61,7 +57,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
pack_weights = [0.0] * num_packs
pack_items = [0] * num_packs
for group in indices_np[i]:
for group in indices[i]:
# Find a pack with capacity that has the lowest weight
pack = min(
(j for j in range(num_packs) if pack_items[j] < groups_per_pack),
@ -71,18 +67,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
assert pack_items[pack] < groups_per_pack
pack_index_np[i, group] = pack
rank_in_pack_np[i, group] = pack_items[pack]
pack_weights[pack] += weight_np[i, group]
pack_weights[pack] += weight[i, group]
pack_items[pack] += 1
pack_index = torch.from_numpy(pack_index_np).to(device)
rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
return pack_index, rank_in_pack
return pack_index_np, rank_in_pack_np
@classmethod
def replicate_experts(
cls, weight: torch.Tensor, num_phy: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
cls, weight: np.ndarray, num_phy: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
load of all replicas is minimized.
@ -99,13 +92,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
n, num_log = weight.shape
num_redundant = num_phy - num_log
assert num_redundant >= 0
device = weight.device
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
replica_idx = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
arangen = torch.arange(n, dtype=torch.int64, device=device)
phy2log = np.tile(np.arange(num_phy, dtype=np.int64), (n, 1))
replica_idx = np.zeros((n, num_phy), dtype=np.int64)
logcnt = np.ones((n, num_log), dtype=np.int64)
arangen = np.arange(n, dtype=np.int64)
for i in range(num_log, num_phy):
redundant_indices = (weight / logcnt).max(dim=-1).indices
redundant_indices = np.argmax(weight / logcnt, axis=-1)
phy2log[:, i] = redundant_indices
replica_idx[:, i] = logcnt[arangen, redundant_indices]
logcnt[arangen, redundant_indices] += 1
@ -114,12 +106,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
@classmethod
def rebalance_experts_hierarchical(
cls,
weight: torch.Tensor,
weight: np.ndarray,
num_physical_experts: int,
num_groups: int,
num_nodes: int,
num_gpus: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Parameters:
weight: [num_moe_layers, num_logical_experts]
@ -146,35 +138,33 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
assert num_physical_experts % num_gpus == 0
phy_experts_per_gpu = num_physical_experts // num_gpus
def inverse(perm: torch.Tensor) -> torch.Tensor:
inv = torch.empty_like(perm)
inv.scatter_(
1,
perm,
torch.arange(
perm.size(1), dtype=torch.int64, device=perm.device
).expand(perm.shape),
)
def inverse(perm: np.ndarray) -> np.ndarray:
inv = np.empty_like(perm)
row_idx = np.arange(perm.shape[0])[:, None]
col_idx = np.arange(perm.shape[1], dtype=np.int64)
inv[row_idx, perm] = col_idx
return inv
# Step 1: pack groups to nodes
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
tokens_per_group = weight.reshape(num_layers, num_groups, group_size).sum(
axis=-1
)
group_pack_index, group_rank_in_pack = cls.balanced_packing(
tokens_per_group, num_nodes
)
# Map each logical expert into a node-local ordering based on packed groups.
log2mlog = (
(
(group_pack_index * groups_per_node + group_rank_in_pack) * group_size
).unsqueeze(-1)
+ torch.arange(
group_size, dtype=torch.int64, device=group_pack_index.device
(group_pack_index * groups_per_node + group_rank_in_pack)[..., None]
* group_size
)
).flatten(-2)
+ np.arange(group_size, dtype=np.int64)
).reshape(num_layers, num_logical_experts)
mlog2log = inverse(log2mlog)
# Step 2: construct redundant experts within nodes
# [num_layers * num_nodes, num_logical_experts // num_nodes]
tokens_per_mlog = weight.gather(-1, mlog2log).view(
# Reorder weights into the node-local layout so replication is done per node.
tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=1).reshape(
-1, num_logical_experts // num_nodes
)
phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts(
@ -182,39 +172,43 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
)
# Step 3: pack physical_experts to GPUs
# [num_layers * num_nodes, num_physical_experts // num_nodes]
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
# Effective per-physical load = logical load divided by replica count.
tokens_per_phy = np.take_along_axis(tokens_per_mlog / mlogcnt, phy2mlog, axis=1)
pack_index, rank_in_pack = cls.balanced_packing(
tokens_per_phy, num_gpus // num_nodes
)
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
pphy2phy = inverse(phy2pphy)
pphy2mlog = phy2mlog.gather(
-1, pphy2phy
) # [num_layers * num_nodes, num_log_per_nodes]
# Reorder node-local logical indices into the post-packing physical order.
pphy2mlog = np.take_along_axis(phy2mlog, pphy2phy, axis=1)
pphy2mlog = (
pphy2mlog.view(num_layers, num_nodes, -1)
+ torch.arange(
pphy2mlog.reshape(num_layers, num_nodes, -1)
+ np.arange(
0,
num_logical_experts,
num_logical_experts // num_nodes,
device=group_pack_index.device,
).view(1, -1, 1)
).flatten(-2)
pphy2log = mlog2log.gather(-1, pphy2mlog)
pphy_replicas_idx = replicas_idx.gather(-1, pphy2phy).view(num_layers, -1)
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
dtype=np.int64,
)[None, :, None]
).reshape(num_layers, -1)
# Map node-local logical indices back to global logical expert ids.
pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=1)
# Reorder replica ranks to the post-packing physical ordering.
pphy_replicas_idx = np.take_along_axis(replicas_idx, pphy2phy, axis=1).reshape(
num_layers, -1
)
# Convert replica counts back to the original logical ordering.
logcnt = np.take_along_axis(mlogcnt.reshape(num_layers, -1), log2mlog, axis=1)
return pphy2log, pphy_replicas_idx, logcnt
@classmethod
def preserve_intragpu_slots(
cls,
phy2log: torch.Tensor,
phy_replicas_idx: torch.Tensor,
phy2log: np.ndarray,
phy_replicas_idx: np.ndarray,
num_ranks: int,
old_global_expert_indices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
old_phy2log: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
"""
Reorder the new mapping per GPU so that experts that remain on the same GPU
keep their previous slot positions when possible. Incoming experts to that GPU
@ -222,29 +216,24 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
is unchanged and the slots per GPU remain the same between
the old and new mappings.
"""
device = phy2log.device
num_phy_experts = phy2log.shape[1]
if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
return phy2log, phy_replicas_idx
# Move to CPU and convert to NumPy for processing
new_phy2log_np = phy2log.cpu().numpy()
replicas_idx_np = phy_replicas_idx.cpu().numpy()
old_phy2log_np = old_global_expert_indices.cpu().numpy()
slots_per_gpu = num_phy_experts // num_ranks
num_layers = new_phy2log_np.shape[0]
num_layers = phy2log.shape[0]
post_phy2log_np = new_phy2log_np.copy()
post_phy_replicas_idx_np = replicas_idx_np.copy()
post_phy2log = phy2log.copy()
post_phy_replicas_idx = phy_replicas_idx.copy()
for gpu_idx in range(num_ranks):
start = gpu_idx * slots_per_gpu
end = start + slots_per_gpu
# Experts across all layers for this GPU
old_local = old_phy2log_np[:, start:end] # [layers, slots]
new_local = new_phy2log_np[:, start:end] # [layers, slots]
new_ridx = replicas_idx_np[:, start:end] # [layers, slots]
old_local = old_phy2log[:, start:end] # [layers, slots]
new_local = phy2log[:, start:end] # [layers, slots]
new_ridx = phy_replicas_idx[:, start:end] # [layers, slots]
used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)
@ -261,12 +250,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
first_idx = np.argmax(matches, axis=1)
layer_indices = np.nonzero(has_any)[0]
matched_new_positions = first_idx[layer_indices]
post_phy2log_np[layer_indices, start + slot_idx] = new_local[
post_phy2log[layer_indices, start + slot_idx] = new_local[
layer_indices, matched_new_positions
]
post_phy_replicas_idx[layer_indices, start + slot_idx] = new_ridx[
layer_indices, matched_new_positions
]
post_phy_replicas_idx_np[layer_indices, start + slot_idx] = (
new_ridx[layer_indices, matched_new_positions]
)
used_new_indices[layer_indices, matched_new_positions] = True
preserved_positions[layer_indices, slot_idx] = True
@ -295,16 +284,13 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
continue
src_pos = remaining_indices[layer_idx, :k]
dst_pos = fill_indices[layer_idx, :k]
post_phy2log_np[layer_idx, start + dst_pos] = new_local[
post_phy2log[layer_idx, start + dst_pos] = new_local[
layer_idx, src_pos
]
post_phy_replicas_idx_np[layer_idx, start + dst_pos] = new_ridx[
post_phy_replicas_idx[layer_idx, start + dst_pos] = new_ridx[
layer_idx, src_pos
]
# Convert back to torch and move to original device
post_phy2log = torch.from_numpy(post_phy2log_np).to(device)
post_phy_replicas_idx = torch.from_numpy(post_phy_replicas_idx_np).to(device)
return post_phy2log, post_phy_replicas_idx
@classmethod
@ -340,40 +326,51 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert
"""
device = weight.device
num_layers, num_logical_experts = weight.shape
weight = weight.float()
weight_np = weight.float().cpu().numpy()
old_phy2log_np = (
old_global_expert_indices.cpu().numpy()
if old_global_expert_indices is not None
else None
)
if num_groups % num_nodes == 0:
# use hierarchical load-balance policy
phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical(
weight, num_replicas, num_groups, num_nodes, num_ranks
phy2log_np, phy_replicas_idx_np, logcnt_np = (
cls.rebalance_experts_hierarchical(
weight_np, num_replicas, num_groups, num_nodes, num_ranks
)
)
else:
# use global load-balance policy
phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical(
weight, num_replicas, 1, 1, num_ranks
phy2log_np, phy_replicas_idx_np, logcnt_np = (
cls.rebalance_experts_hierarchical(
weight_np, num_replicas, 1, 1, num_ranks
)
)
# Optional postprocessing to preserve slots for experts moving
# within the same GPU
# Only apply when the number of GPUs and slots per GPU remain unchanged.
# Helps to avoid unnecessary weight copying when experts move
# within the same GPU.
if old_global_expert_indices is not None:
phy2log, phy_replicas_idx = cls.preserve_intragpu_slots(
phy2log, phy_replicas_idx, num_ranks, old_global_expert_indices
phy2log_np, phy_replicas_idx_np = cls.preserve_intragpu_slots(
phy2log_np, phy_replicas_idx_np, num_ranks, old_phy2log_np
)
num_redundant_experts = num_replicas - num_logical_experts
maxlogcnt = num_redundant_experts + 1
log2phy: torch.Tensor = torch.full(
(num_layers, num_logical_experts, maxlogcnt),
-1,
dtype=torch.int64,
device=logcnt.device,
log2phy_np = np.full(
(num_layers, num_logical_experts, maxlogcnt), -1, dtype=np.int64
)
log2phy.view(num_layers, -1).scatter_(
-1,
phy2log * maxlogcnt + phy_replicas_idx,
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
num_layers, -1
),
layer_indices = np.arange(num_layers)[:, None]
replica_indices = np.tile(
np.arange(num_replicas, dtype=np.int64), (num_layers, 1)
)
log2phy_np[layer_indices, phy2log_np, phy_replicas_idx_np] = replica_indices
phy2log = torch.from_numpy(phy2log_np).to(device)
log2phy = torch.from_numpy(log2phy_np).to(device)
logcnt = torch.from_numpy(logcnt_np).to(device)
return phy2log, log2phy, logcnt