[EPLB] Refactor balance_packing to use numpy and optimize GPU-CPU transfers in EPLB (#28369)

Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
Sage Moore 2025-11-11 00:19:51 -08:00 committed by GitHub
parent 4fd4b743a2
commit 798c7bebca
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 37 additions and 17 deletions

View File

@ -12,6 +12,7 @@ Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
on how the EPLB algorithm works.
"""
import numpy as np
import torch
@ -34,29 +35,44 @@ def balanced_packing(
assert num_groups % num_packs == 0
groups_per_pack = num_groups // num_packs
device = weight.device
if groups_per_pack == 1:
pack_index = torch.arange(
weight.size(-1), dtype=torch.int64, device=weight.device
weight.size(-1), dtype=torch.int64, device=device
).expand(weight.shape)
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
return pack_index, rank_in_pack
indices = weight.float().sort(-1, descending=True).indices.cpu()
pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu")
rank_in_pack = torch.full_like(pack_index, fill_value=-1)
weight_np = weight.cpu().numpy()
# Sort and get indices in decending order
indices_np = np.argsort(-weight_np, axis=-1)
pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
# Run the packing algorithm
for i in range(num_layers):
pack_weights = [0] * num_packs
pack_weights = [0.0] * num_packs
pack_items = [0] * num_packs
for group in indices[i]:
for group in indices_np[i]:
# Find a pack with capacity that has the lowest weight
pack = min(
(i for i in range(num_packs) if pack_items[i] < groups_per_pack),
(j for j in range(num_packs) if pack_items[j] < groups_per_pack),
key=pack_weights.__getitem__,
)
assert pack_items[pack] < groups_per_pack
pack_index[i, group] = pack
rank_in_pack[i, group] = pack_items[pack]
pack_weights[pack] += weight[i, group]
pack_index_np[i, group] = pack
rank_in_pack_np[i, group] = pack_items[pack]
pack_weights[pack] += weight_np[i, group]
pack_items[pack] += 1
pack_index = torch.from_numpy(pack_index_np).to(device)
rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
return pack_index, rank_in_pack
@ -212,7 +228,7 @@ def rebalance_experts(
replicas for each logical expert
"""
num_layers, num_logical_experts = weight.shape
weight = weight.float().cpu()
weight = weight.float()
if num_groups % num_nodes == 0:
# use hierarchical load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(

View File

@ -321,15 +321,19 @@ def rearrange_expert_weights_inplace(
)
return
old_global_expert_indices_cpu = old_global_expert_indices.cpu()
new_global_expert_indices_cpu = new_global_expert_indices.cpu()
# NOTE(bowen): We need this synchronize to run, but I don't know why.
# If you figure out the reason, please let me know -- thank you!
torch.cuda.synchronize()
for layer in range(num_moe_layers):
# NOTE(bowen): We need this synchronize to run, but I don't know why.
# If you figure out the reason, please let me know -- thank you!
torch.cuda.synchronize()
shuffle_layer(
num_local_physical_experts,
ep_rank,
old_global_expert_indices[layer].tolist(),
new_global_expert_indices[layer].tolist(),
old_global_expert_indices_cpu[layer].tolist(),
new_global_expert_indices_cpu[layer].tolist(),
expert_weights[layer],
expert_weights_buffer,
ep_group,