mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-23 09:07:03 +08:00
Add optimization to preserve experts in the same slot within a GPU
Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
parent
a46c72ac71
commit
561b427299
@ -4,7 +4,10 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.distributed.eplb.rebalance_algo import rebalance_experts
|
||||
from vllm.distributed.eplb.rebalance_algo import (
|
||||
preserve_intragpu_slots,
|
||||
rebalance_experts,
|
||||
)
|
||||
|
||||
|
||||
def test_basic_rebalance():
|
||||
@ -306,3 +309,136 @@ if __name__ == "__main__":
|
||||
print(phy2log)
|
||||
|
||||
test_basic_rebalance()
|
||||
|
||||
|
||||
def _make_phyrank_from_phy2log(phy2log: torch.Tensor) -> torch.Tensor:
|
||||
"""Create phyrank from phy2log"""
|
||||
pr = torch.zeros_like(phy2log)
|
||||
for layer in range(phy2log.shape[0]):
|
||||
seen: dict[int, int] = {}
|
||||
row = phy2log[layer].tolist()
|
||||
for i, expert in enumerate(row):
|
||||
r = seen.get(expert, 0)
|
||||
pr[layer, i] = r
|
||||
seen[expert] = r + 1
|
||||
return pr
|
||||
|
||||
|
||||
def _validate_intragpu_rearrangement(
|
||||
old_global_expert_indices: torch.Tensor,
|
||||
new_phy2log: torch.Tensor,
|
||||
new_phyrank: torch.Tensor,
|
||||
post_phy2log: torch.Tensor,
|
||||
post_phyrank: torch.Tensor,
|
||||
num_gpus: int,
|
||||
slots_per_gpu: int,
|
||||
):
|
||||
# Per-GPU checks
|
||||
for gpu_idx in range(num_gpus):
|
||||
start = gpu_idx * slots_per_gpu
|
||||
end = start + slots_per_gpu
|
||||
old_seg = old_global_expert_indices[0, start:end]
|
||||
new_seg = new_phy2log[0, start:end]
|
||||
new_rnk = new_phyrank[0, start:end]
|
||||
post_seg = post_phy2log[0, start:end]
|
||||
post_rnk = post_phyrank[0, start:end]
|
||||
|
||||
# Pairwise equality for (expert, rank) pairs to ensure nothing is lost
|
||||
def sorted_pairs(seg: torch.Tensor, rnk: torch.Tensor):
|
||||
pairs = list(zip(seg.tolist(), rnk.tolist()))
|
||||
pairs.sort()
|
||||
return pairs
|
||||
|
||||
assert sorted_pairs(post_seg, post_rnk) == sorted_pairs(new_seg, new_rnk), (
|
||||
f"Per-GPU pairs of (expert,rank) must match new mapping for GPU {gpu_idx}"
|
||||
)
|
||||
|
||||
# For experts that remain on the same GPU, the old slot is preserved
|
||||
# for at least one occurrence; rank at that slot must be valid for that expert
|
||||
old_list = old_seg.tolist()
|
||||
new_list = new_seg.tolist()
|
||||
post_list = post_seg.tolist()
|
||||
remained = set(old_list) & set(new_list)
|
||||
new_ranks_for_expert: dict[int, list[int]] = {}
|
||||
for v, r in zip(new_list, new_rnk.tolist()):
|
||||
new_ranks_for_expert.setdefault(v, []).append(r)
|
||||
for expert in remained:
|
||||
old_pos = old_list.index(expert)
|
||||
assert post_list[old_pos] == expert, (
|
||||
f"Expert {expert} on GPU {gpu_idx} should stay at old slot {old_pos}"
|
||||
)
|
||||
# Rank at preserved slot must be one of the ranks
|
||||
# the expert has in new mapping
|
||||
assert post_rnk.tolist()[old_pos] in new_ranks_for_expert[expert], (
|
||||
f"Rank for expert {expert} at preserved slot on GPU {gpu_idx} "
|
||||
"must come from new mapping"
|
||||
)
|
||||
|
||||
|
||||
def test_preserve_intragpu_slots_simple():
    """Experts that stay on a GPU keep their old slots; incoming not lost."""
    # Layout: 1 layer, 2 GPUs with 4 physical slots each.
    num_gpus = 2
    slots_per_gpu = 4

    # Previous placement: GPU0 holds experts 0-3, GPU1 holds experts 4-7.
    old_global_expert_indices = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]])

    # Rebalanced placement:
    #   GPU0 -> [1, 5, 0, 4]: experts 0 and 1 stay (shuffled); 4 and 5 arrive.
    #   GPU1 -> [6, 2, 7, 3]: experts 6 and 7 stay; 2 and 3 arrive.
    phy2log = torch.tensor([[1, 5, 0, 4, 6, 2, 7, 3]])
    # Ranks follow per-expert occurrence order within the new mapping.
    phyrank = _make_phyrank_from_phy2log(phy2log)

    post_phy2log, post_phyrank = preserve_intragpu_slots(
        phy2log, phyrank, num_gpus, old_global_expert_indices
    )

    # The slot-preserving pass must not alter tensor shapes.
    assert post_phy2log.shape == phy2log.shape
    assert post_phyrank.shape == phyrank.shape

    # Per-GPU invariants: nothing lost, stayers keep their slots.
    _validate_intragpu_rearrangement(
        old_global_expert_indices,
        phy2log,
        phyrank,
        post_phy2log,
        post_phyrank,
        num_gpus,
        slots_per_gpu,
    )
|
||||
|
||||
|
||||
def test_preserve_intragpu_slots_with_duplicates():
    """Test preserve intragpu slots with duplicates"""
    # Layout: 1 layer, 2 GPUs with 5 physical slots each (10 total).
    num_gpus = 2
    slots_per_gpu = 5

    # Previous placement (expert 0 duplicated on GPU0):
    #   GPU0 -> [0, 1, 0, 2, 3]
    #   GPU1 -> [4, 5, 6, 1, 2]
    old_global_expert_indices = torch.tensor([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]])

    # Rebalanced placement, still containing duplicates:
    #   GPU0 -> [0, 5, 4, 0, 1]: expert 0 duplicated, 4 and 5 incoming.
    #   GPU1 -> [6, 2, 3, 1, 2]: expert 2 duplicated.
    phy2log = torch.tensor([[0, 5, 4, 0, 1, 6, 2, 3, 1, 2]])
    # Duplicates get ranks [0, 1, ...] in occurrence order.
    phyrank = _make_phyrank_from_phy2log(phy2log)

    post_phy2log, post_phyrank = preserve_intragpu_slots(
        phy2log, phyrank, num_gpus, old_global_expert_indices
    )

    # The slot-preserving pass must not alter tensor shapes.
    assert post_phy2log.shape == phy2log.shape
    assert post_phyrank.shape == phyrank.shape

    # Per-GPU invariants: nothing lost, stayers keep their slots.
    _validate_intragpu_rearrangement(
        old_global_expert_indices,
        phy2log,
        phyrank,
        post_phy2log,
        post_phyrank,
        num_gpus,
        slots_per_gpu,
    )
|
||||
|
||||
@ -795,6 +795,7 @@ class EplbState:
|
||||
num_groups,
|
||||
num_nodes,
|
||||
num_gpus,
|
||||
eplb_model_state.physical_to_logical_map,
|
||||
)
|
||||
|
||||
if not eplb_model_state.is_async_enabled or is_profile:
|
||||
|
||||
@ -197,12 +197,110 @@ def rebalance_experts_hierarchical(
|
||||
return pphy2log, pphyrank, logcnt
|
||||
|
||||
|
||||
def preserve_intragpu_slots(
    phy2log: torch.Tensor,
    phyrank: torch.Tensor,
    num_gpus: int,
    old_global_expert_indices: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Reorder the new mapping per GPU so that experts that remain on the same GPU
    keep their previous slot positions when possible. Incoming experts to that
    GPU fill any remaining open slots. Applied only when both mappings tile
    evenly onto the same number of slots per GPU; otherwise the inputs are
    returned unchanged.
    """
    device = phy2log.device
    num_phy_new = phy2log.shape[1]
    num_phy_old = old_global_expert_indices.shape[1]

    # Guard: both mappings must divide evenly across GPUs with matching
    # per-GPU slot counts, or the rearrangement is not meaningful.
    if num_gpus <= 0:
        return phy2log, phyrank
    if num_phy_new % num_gpus != 0 or num_phy_old % num_gpus != 0:
        return phy2log, phyrank
    if num_phy_new // num_gpus != num_phy_old // num_gpus:
        return phy2log, phyrank

    slots_per_gpu = num_phy_new // num_gpus
    num_layers = phy2log.shape[0]

    # Work on CPU copies; results are moved back to the source device at the end.
    new_p2l_cpu = phy2log.cpu()
    new_rnk_cpu = phyrank.cpu()
    old_cpu = old_global_expert_indices.cpu()
    post_phy2log = new_p2l_cpu.clone()
    post_phyrank = new_rnk_cpu.clone()

    for layer in range(num_layers):
        for gpu_idx in range(num_gpus):
            base = gpu_idx * slots_per_gpu
            old_seg = old_cpu[layer, base : base + slots_per_gpu].tolist()
            new_seg = new_p2l_cpu[layer, base : base + slots_per_gpu].tolist()
            new_rnk = new_rnk_cpu[layer, base : base + slots_per_gpu].tolist()

            # Replicas of the new mapping that have been placed already.
            consumed = [False] * slots_per_gpu
            # Destination slots whose contents have been decided already.
            pinned = [False] * slots_per_gpu

            # Pass 1: an expert present in both mappings keeps its old slot,
            # consuming the first not-yet-placed replica of that expert.
            for pos, expert in enumerate(old_seg):
                for j in range(slots_per_gpu):
                    if not consumed[j] and new_seg[j] == expert:
                        post_phy2log[layer, base + pos] = new_seg[j]
                        post_phyrank[layer, base + pos] = new_rnk[j]
                        consumed[j] = True
                        pinned[pos] = True
                        break

            # Pass 2: remaining replicas fill the still-open slots, both taken
            # in ascending index order.
            leftovers = (j for j in range(slots_per_gpu) if not consumed[j])
            open_slots = (p for p in range(slots_per_gpu) if not pinned[p])
            for j, pos in zip(leftovers, open_slots):
                post_phy2log[layer, base + pos] = new_seg[j]
                post_phyrank[layer, base + pos] = new_rnk[j]

    return post_phy2log.to(device), post_phyrank.to(device)
|
||||
|
||||
|
||||
def rebalance_experts(
|
||||
weight: torch.Tensor,
|
||||
num_replicas: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_gpus: int,
|
||||
old_global_expert_indices: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Entry point for expert-parallelism load balancer.
|
||||
@ -239,6 +337,14 @@ def rebalance_experts(
|
||||
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
|
||||
weight, num_replicas, 1, 1, num_gpus
|
||||
)
|
||||
|
||||
# Optional postprocessing to preserve slots for experts moving within the same GPU
|
||||
# Only apply when the number of GPUs and slots per GPU remain unchanged.
|
||||
# Helps to avoid unnecessary weight copying when experts move within the same GPU.
|
||||
if old_global_expert_indices is not None:
|
||||
phy2log, phyrank = preserve_intragpu_slots(
|
||||
phy2log, phyrank, num_gpus, old_global_expert_indices
|
||||
)
|
||||
num_redundant_experts = num_replicas - num_logical_experts
|
||||
maxlogcnt = num_redundant_experts + 1
|
||||
log2phy: torch.Tensor = torch.full(
|
||||
@ -257,4 +363,4 @@ def rebalance_experts(
|
||||
return phy2log, log2phy, logcnt
|
||||
|
||||
|
||||
__all__ = ["rebalance_experts"]
|
||||
__all__ = ["rebalance_experts", "preserve_intragpu_slots"]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user