Convert balanced_packing to numpy

Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
ilmarkov 2025-12-15 13:55:41 +00:00
parent 389f86e0c5
commit dcf4783967

View File

@ -40,37 +40,36 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
groups_per_pack = num_groups // num_packs
if groups_per_pack == 1:
pack_index_np = np.tile(
np.arange(num_groups, dtype=np.int64), (num_layers, 1)
)
rank_in_pack_np = np.zeros_like(pack_index_np, dtype=np.int64)
return pack_index_np, rank_in_pack_np
pack_index = np.tile(np.arange(num_groups, dtype=np.int64), (num_layers, 1))
rank_in_pack = np.zeros_like(pack_index, dtype=np.int64)
return pack_index, rank_in_pack
# Sort and get indices in descending order
indices = np.argsort(-weight, axis=-1)
pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
pack_index = np.full((num_layers, num_groups), -1, dtype=np.int64)
rank_in_pack = np.full((num_layers, num_groups), -1, dtype=np.int64)
pack_weights = np.zeros((num_layers, num_packs), dtype=np.float64)
pack_items = np.zeros((num_layers, num_packs), dtype=np.int64)
# Run the packing algorithm
for i in range(num_layers):
pack_weights = [0.0] * num_packs
pack_items = [0] * num_packs
for layer_idx in range(num_layers):
weights_row = pack_weights[layer_idx]
items_row = pack_items[layer_idx]
for group in indices[i]:
# Find a pack with capacity that has the lowest weight
pack = min(
(j for j in range(num_packs) if pack_items[j] < groups_per_pack),
key=pack_weights.__getitem__,
)
for group in indices[layer_idx]:
# Select the lightest pack that still has capacity.
available = items_row < groups_per_pack
assert np.any(available)
pack = int(np.argmin(np.where(available, weights_row, np.inf)))
assert pack_items[pack] < groups_per_pack
pack_index_np[i, group] = pack
rank_in_pack_np[i, group] = pack_items[pack]
pack_weights[pack] += weight[i, group]
pack_items[pack] += 1
pack_index[layer_idx, group] = pack
rank_in_pack[layer_idx, group] = items_row[pack]
weights_row[pack] += weight[layer_idx, group]
items_row[pack] += 1
return pack_index_np, rank_in_pack_np
return pack_index, rank_in_pack
@classmethod
def replicate_experts(