mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 09:01:40 +08:00
[EPLB] Refactor balance_packing to use numpy and optimize GPU-CPU transfers in EPLB (#28369)
Signed-off-by: Sage Moore <sage@neuralmagic.com>
This commit is contained in:
parent
4fd4b743a2
commit
798c7bebca
@ -12,6 +12,7 @@ Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
|
||||
on how the EPLB algorithm works.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
@ -34,29 +35,44 @@ def balanced_packing(
|
||||
assert num_groups % num_packs == 0
|
||||
groups_per_pack = num_groups // num_packs
|
||||
|
||||
device = weight.device
|
||||
|
||||
if groups_per_pack == 1:
|
||||
pack_index = torch.arange(
|
||||
weight.size(-1), dtype=torch.int64, device=weight.device
|
||||
weight.size(-1), dtype=torch.int64, device=device
|
||||
).expand(weight.shape)
|
||||
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64)
|
||||
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
indices = weight.float().sort(-1, descending=True).indices.cpu()
|
||||
pack_index = torch.full_like(weight, fill_value=-1, dtype=torch.int64, device="cpu")
|
||||
rank_in_pack = torch.full_like(pack_index, fill_value=-1)
|
||||
weight_np = weight.cpu().numpy()
|
||||
|
||||
# Sort and get indices in decending order
|
||||
indices_np = np.argsort(-weight_np, axis=-1)
|
||||
|
||||
pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
|
||||
rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
|
||||
|
||||
# Run the packing algorithm
|
||||
for i in range(num_layers):
|
||||
pack_weights = [0] * num_packs
|
||||
pack_weights = [0.0] * num_packs
|
||||
pack_items = [0] * num_packs
|
||||
for group in indices[i]:
|
||||
|
||||
for group in indices_np[i]:
|
||||
# Find a pack with capacity that has the lowest weight
|
||||
pack = min(
|
||||
(i for i in range(num_packs) if pack_items[i] < groups_per_pack),
|
||||
(j for j in range(num_packs) if pack_items[j] < groups_per_pack),
|
||||
key=pack_weights.__getitem__,
|
||||
)
|
||||
|
||||
assert pack_items[pack] < groups_per_pack
|
||||
pack_index[i, group] = pack
|
||||
rank_in_pack[i, group] = pack_items[pack]
|
||||
pack_weights[pack] += weight[i, group]
|
||||
pack_index_np[i, group] = pack
|
||||
rank_in_pack_np[i, group] = pack_items[pack]
|
||||
pack_weights[pack] += weight_np[i, group]
|
||||
pack_items[pack] += 1
|
||||
|
||||
pack_index = torch.from_numpy(pack_index_np).to(device)
|
||||
rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
|
||||
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
|
||||
@ -212,7 +228,7 @@ def rebalance_experts(
|
||||
replicas for each logical expert
|
||||
"""
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
weight = weight.float().cpu()
|
||||
weight = weight.float()
|
||||
if num_groups % num_nodes == 0:
|
||||
# use hierarchical load-balance policy
|
||||
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
|
||||
|
||||
@ -321,15 +321,19 @@ def rearrange_expert_weights_inplace(
|
||||
)
|
||||
return
|
||||
|
||||
old_global_expert_indices_cpu = old_global_expert_indices.cpu()
|
||||
new_global_expert_indices_cpu = new_global_expert_indices.cpu()
|
||||
|
||||
# NOTE(bowen): We need this synchronize to run, but I don't know why.
|
||||
# If you figure out the reason, please let me know -- thank you!
|
||||
torch.cuda.synchronize()
|
||||
|
||||
for layer in range(num_moe_layers):
|
||||
# NOTE(bowen): We need this synchronize to run, but I don't know why.
|
||||
# If you figure out the reason, please let me know -- thank you!
|
||||
torch.cuda.synchronize()
|
||||
shuffle_layer(
|
||||
num_local_physical_experts,
|
||||
ep_rank,
|
||||
old_global_expert_indices[layer].tolist(),
|
||||
new_global_expert_indices[layer].tolist(),
|
||||
old_global_expert_indices_cpu[layer].tolist(),
|
||||
new_global_expert_indices_cpu[layer].tolist(),
|
||||
expert_weights[layer],
|
||||
expert_weights_buffer,
|
||||
ep_group,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user