Merge 5cd45c646f53923fc5b5cc046aa044c1ce94aa08 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29

This commit is contained in:
Ilya Markov 2025-12-25 00:06:34 +00:00 committed by GitHub
commit 74ee657332
8 changed files with 943 additions and 409 deletions

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import pytest
import torch
@ -310,3 +311,143 @@ if __name__ == "__main__":
print(phy2log)
test_basic_rebalance()
def _make_phy_replicas_idx_from_phy2log(phy2log: np.ndarray) -> np.ndarray:
"""Create replicas indices mapping from phy2log."""
pr = np.zeros_like(phy2log, dtype=np.int64)
for layer in range(phy2log.shape[0]):
seen: dict[int, int] = {}
row = phy2log[layer].tolist()
for i, expert in enumerate(row):
r = seen.get(expert, 0)
pr[layer, i] = r
seen[expert] = r + 1
return pr
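A minimal standalone sketch of what this helper produces (values made up for illustration):

import numpy as np

phy2log = np.array([[0, 1, 0, 2, 1]])  # physical slots -> logical expert ids
pr = np.zeros_like(phy2log)
seen: dict[int, int] = {}
for i, expert in enumerate(phy2log[0].tolist()):
    pr[0, i] = seen.get(expert, 0)       # k-th occurrence gets replica index k
    seen[expert] = int(pr[0, i]) + 1
print(pr)  # [[0 0 1 0 1]]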
def _validate_intragpu_rearrangement(
old_global_expert_indices: np.ndarray,
new_phy2log: np.ndarray,
new_phy_replicas_idx: np.ndarray,
post_phy2log: np.ndarray,
post_phy_replicas_idx: np.ndarray,
num_ranks: int,
slots_per_gpu: int,
):
# Per-GPU checks
for gpu_idx in range(num_ranks):
start = gpu_idx * slots_per_gpu
end = start + slots_per_gpu
old_seg = old_global_expert_indices[0, start:end]
new_seg = new_phy2log[0, start:end]
new_rnk = new_phy_replicas_idx[0, start:end]
post_seg = post_phy2log[0, start:end]
post_rnk = post_phy_replicas_idx[0, start:end]
# Pairwise equality for (expert, rank) pairs to ensure nothing is lost
def sorted_pairs(seg, rnk):
pairs = list(zip(seg.tolist(), rnk.tolist()))
pairs.sort()
return pairs
assert sorted_pairs(post_seg, post_rnk) == sorted_pairs(new_seg, new_rnk), (
f"Per-GPU pairs of (expert,rank) must match new mapping for GPU {gpu_idx}"
)
# For experts that remain on the same GPU, the old slot is preserved
# for at least one occurrence; rank at that slot must be valid for that expert
old_list = old_seg.tolist()
new_list = new_seg.tolist()
post_list = post_seg.tolist()
remained = set(old_list) & set(new_list)
new_ranks_for_expert: dict[int, list[int]] = {}
for v, r in zip(new_list, new_rnk.tolist()):
new_ranks_for_expert.setdefault(v, []).append(r)
for expert in remained:
old_pos = old_list.index(expert)
assert post_list[old_pos] == expert, (
f"Expert {expert} on GPU {gpu_idx} should stay at old slot {old_pos}"
)
# Rank at preserved slot must be one of the ranks
# the expert has in new mapping
assert post_rnk.tolist()[old_pos] in new_ranks_for_expert[expert], (
f"Rank for expert {expert} at preserved slot on GPU {gpu_idx} "
"must come from new mapping"
)
@pytest.mark.parametrize(
"num_ranks, slots_per_gpu, old_phy2log, new_phy2log",
[
pytest.param(
# Setup: 2 GPUs, 4 slots each, 1 layer
# Old mapping: GPU0 -> [0,1,2,3], GPU1 -> [4,5,6,7]
# New mapping shuffles within GPU0 and brings 4,5 into GPU0.
# GPU0 new -> [1,5,0,4]; GPU1 new -> [6,2,7,3]
2,
4,
np.array([[0, 1, 2, 3, 4, 5, 6, 7]]),
np.array([[1, 5, 0, 4, 6, 2, 7, 3]]),
id="simple",
),
pytest.param(
# Setup: 2 GPUs, 5 slots each (total 10 physical experts), 1 layer
# Old mapping:
# GPU0 -> [0, 1, 0, 2, 3] (expert 0 duplicated)
# GPU1 -> [4, 5, 6, 1, 2]
# New mapping reorders within GPUs and moves some experts across GPUs,
# while still including duplicates:
# GPU0 new -> [0, 5, 4, 0, 1] (expert 0 duplicated, 4/5 incoming)
# GPU1 new -> [6, 2, 3, 2, 1] (expert 2 duplicated)
2,
5,
np.array([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]]),
np.array([[0, 5, 4, 0, 1, 6, 2, 3, 2, 1]]),
id="duplicates",
),
pytest.param(
# Setup: 3 GPUs, 4 slots each (total 12 physical experts), 1 layer
# Old mapping:
# GPU0 -> [0, 1, 2, 3]
# GPU1 -> [0, 1, 2, 3]
# GPU2 -> [0, 1, 2, 3]
# New mapping decides to use one expert on 2 GPUs and shuffles
# experts on the third GPU,
# GPU0 new -> [0, 0, 0, 0]
# GPU1 new -> [0, 0, 0, 0]
# GPU2 new -> [1, 2, 3, 0]
3,
4,
np.array([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]),
np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0]]),
id="skewed_expert",
),
],
)
def test_preserve_intragpu_slots(
num_ranks: int,
slots_per_gpu: int,
old_phy2log: torch.Tensor,
new_phy2log: torch.Tensor,
):
"""Experts that stay on a GPU keep their old slots; incoming not lost."""
phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(new_phy2log)
post_phy2log, post_phy_replicas_idx = DefaultEplbPolicy.preserve_intragpu_slots(
new_phy2log, phy_replicas_idx, num_ranks, old_phy2log
)
# Shapes preserved
assert post_phy2log.shape == new_phy2log.shape
assert post_phy_replicas_idx.shape == phy_replicas_idx.shape
_validate_intragpu_rearrangement(
old_phy2log,
new_phy2log,
phy_replicas_idx,
post_phy2log,
post_phy_replicas_idx,
num_ranks,
slots_per_gpu,
)

View File

@ -286,32 +286,32 @@ def _test_async_transfer_layer_without_mtp_worker(
device,
old_indices,
)
old_indices_cpu = old_indices.cpu()
new_indices_cpu = new_indices.cpu()
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
cuda_stream = torch.cuda.Stream(device=device)
for layer_idx in range(num_layers):
is_unchanged, is_received_locally, experts_recv_loc = asyncio.run(
is_unchanged, is_received_locally, recv_metadata = asyncio.run(
transfer_layer(
old_global_expert_indices=old_indices,
new_global_expert_indices=new_indices,
expert_weights=expert_weights,
old_layer_indices=old_indices_cpu[layer_idx],
new_layer_indices=new_indices_cpu[layer_idx],
expert_weights=expert_weights[layer_idx],
expert_weights_buffer=expert_buffer,
ep_group=ep_group,
layer=layer_idx,
cuda_stream=cuda_stream,
)
)
cuda_stream.synchronize()
move_from_buffer(
expert_weights=expert_weights[layer_idx],
expert_weights_buffer=expert_buffer,
expert_weights_buffers=expert_buffer,
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
experts_recv_loc=experts_recv_loc,
new_indices=new_indices[layer_idx].tolist(),
ep_group=ep_group,
recv_metadata=recv_metadata,
new_indices=new_indices_cpu[layer_idx],
ep_rank=ep_group.rank(),
)
verify_expert_weights_after_shuffle(

View File

@ -69,6 +69,10 @@ class EPLBConfig:
Log the balancedness each step of expert parallelism.
This is turned off by default since it will cause communication overhead.
"""
log_balancedness_interval: int = 1
"""
Interval for logging the balancedness.
"""
use_async: bool = False
"""
Whether to use non-blocking EPLB.
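Note: a minimal sketch of how the new config knobs could be set when constructing the EPLB config directly (import path and construction style assumed for illustration; in practice these are usually passed via engine arguments):

from vllm.config import EPLBConfig  # import path assumed

eplb_config = EPLBConfig(
    log_balancedness_interval=10,  # log balancedness every 10 EPLB steps
    use_async=True,                # enable non-blocking (async) EPLB
)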

View File

@ -17,7 +17,7 @@ from vllm.logger import init_logger
from .rebalance_execute import transfer_layer
if TYPE_CHECKING:
from .eplb_state import EplbState
from .eplb_state import EplbModelState, EplbState
logger = init_logger(__name__)
@ -57,6 +57,44 @@ def start_async_worker(
return thread
def run_rebalance_experts(
model_state: "EplbModelState",
eplb_state: "EplbState",
) -> None:
assert model_state.eplb_stats is not None
eplb_stats = model_state.eplb_stats
# Move the global expert load window to CPU for computation.
# It has to be done in the main stream to avoid race condition
# with the main thread.
global_expert_load_window = eplb_stats.global_expert_load_window.cpu()
# Compute new expert mappings for the model
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = eplb_state.policy.rebalance_experts(
global_expert_load_window,
eplb_stats.num_replicas,
eplb_stats.num_groups,
eplb_stats.num_nodes,
eplb_stats.num_gpus,
model_state.physical_to_logical_map,
)
assert new_physical_to_logical_map.device == torch.device("cpu")
model_state.new_physical_to_logical_map = new_physical_to_logical_map
max_slots = model_state.logical_to_physical_map.shape[-1]
padded_logical = torch.nn.functional.pad(
new_logical_to_physical_map,
(0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
value=-1,
).to(model_state.logical_to_physical_map.device)
new_replica = new_logical_replica_count.to(model_state.logical_replica_count.device)
model_state.new_logical_to_physical_map = padded_logical
model_state.new_logical_replica_count = new_replica
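A small sketch of the padding step above: the freshly computed log2phy map may have fewer replica slots than the preallocated buffer, so its last dimension is right-padded with -1 (tensor values here are illustrative only):

import torch

new_log2phy = torch.tensor([[[0, 3], [1, -1], [2, -1]]])  # [layers, logical, 2]
max_slots = 4  # slot count of the existing logical_to_physical_map buffer

padded = torch.nn.functional.pad(
    new_log2phy, (0, max(0, max_slots - new_log2phy.shape[-1])), value=-1
)
print(padded.shape)  # torch.Size([1, 3, 4])
print(padded[0, 0])  # tensor([ 0,  3, -1, -1])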
async def transfer_run_periodically(
state: "EplbState",
ep_group: ProcessGroup,
@ -71,33 +109,46 @@ async def transfer_run_periodically(
for model_state in state.model_states.values():
if not model_state.is_async_enabled:
continue
rebalancing_algorithm_executed = False
logger.info(
"Async worker computed new indices for model %s",
model_state.model_name,
)
current_num_layers = model_state.model.num_moe_layers
while (
model_state.rebalanced
and model_state.layer_to_transfer < current_num_layers
):
if (
not model_state.ep_buffer_ready
and model_state.rebalanced
and model_state.new_physical_to_logical_map is not None
):
if not model_state.ep_buffer_ready and model_state.rebalanced:
await asyncio.to_thread(model_state.buffer_lock.acquire)
try:
if model_state.layer_to_transfer >= current_num_layers:
break
if not rebalancing_algorithm_executed:
run_rebalance_experts(model_state, state)
rebalancing_algorithm_executed = True
assert model_state.new_physical_to_logical_map is not None
layer_idx = model_state.layer_to_transfer
old_layer_indices = model_state.physical_to_logical_map[
layer_idx
]
new_layer_indices = model_state.new_physical_to_logical_map[
layer_idx
]
(
model_state.is_unchanged,
model_state.is_received_locally,
model_state.experts_recv_loc,
model_state.recv_metadata,
) = await transfer_layer(
old_global_expert_indices=model_state.physical_to_logical_map,
new_global_expert_indices=model_state.new_physical_to_logical_map,
expert_weights=model_state.model.expert_weights,
old_layer_indices=old_layer_indices,
new_layer_indices=new_layer_indices,
expert_weights=model_state.model.expert_weights[layer_idx],
expert_weights_buffer=model_state.expert_buffer,
ep_group=ep_group,
is_profile=is_profile,
layer=model_state.layer_to_transfer,
cuda_stream=cuda_stream,
rank_mapping=rank_mapping,
)

View File

@ -27,10 +27,10 @@ physical experts.
"""
import threading
import time
from collections.abc import Sequence
from dataclasses import dataclass
import numpy as np
import torch
from torch.distributed import ProcessGroup, all_reduce
@ -46,11 +46,44 @@ from vllm.model_executor.models.interfaces import MixtureOfExperts
from .async_worker import start_async_worker
from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy
from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace
from .rebalance_execute import (
RecvMetadata,
move_from_buffer,
rearrange_expert_weights_inplace,
)
logger = init_logger(__name__)
@dataclass
class EplbStats:
"""
Model stats used by the EPLB rebalancing algorithm.
"""
global_expert_load_window: torch.Tensor
"""
Expert load window.
Shape: (window_size, num_moe_layers, num_physical_experts)
"""
num_replicas: int
"""
Number of physical experts.
"""
num_groups: int
"""
Number of expert groups.
"""
num_nodes: int
"""
Number of nodes.
"""
num_gpus: int
"""
Number of GPUs.
"""
@dataclass
class EplbModelState:
"""EPLB metrics."""
@ -164,20 +197,23 @@ class EplbModelState:
"""
Whether the async EPLB needs to poll peers for buffer readiness.
"""
is_unchanged: list[bool]
eplb_stats: EplbStats | None
"""
EPLB stats for the model.
"""
is_unchanged: np.ndarray
"""
Intermediate variable between `move_to_buffer` and `move_to_workspace`.
Its size equals the number of local physical experts in the current layer.
"""
is_received_locally: list[bool]
is_received_locally: np.ndarray
"""
Intermediate variable between `move_to_buffer` and `move_to_workspace`.
Its size equals the number of local physical experts in the current layer.
"""
experts_recv_loc: dict[int, int]
recv_metadata: RecvMetadata
"""
Intermediate metadata between `move_to_buffer` and `move_to_workspace`,
describing remote receives for the current layer.
"""
is_async_enabled: bool
"""
@ -507,9 +543,15 @@ class EplbState:
layer_to_transfer=0,
rebalanced=False,
pending_global_ready_check=False,
is_unchanged=[],
is_received_locally=[],
experts_recv_loc={},
eplb_stats=None,
is_unchanged=np.array([]),
is_received_locally=np.array([]),
recv_metadata=RecvMetadata(
recv_primary_mask=np.array([]),
recv_count=0,
recv_expert_ids=np.array([]),
recv_dst_rows=np.array([]),
),
is_async_enabled=self.is_async,
cuda_device_index=self.cuda_device_index,
new_physical_to_logical_map=new_physical_to_logical_map,
@ -553,7 +595,12 @@ class EplbState:
for eplb_model_state in self.model_states.values():
eplb_model_state.expert_load_pass.zero_()
if log_stats:
if (
log_stats
and self.expert_rearrangement_step
% self.parallel_config.eplb_config.log_balancedness_interval
== 0
):
# Sync the expert load pass for each model (main and drafter).
# expert_load_pass: (num_moe_layers, num_physical_experts)
expert_load_pass_list = self._sync_load_pass()
@ -585,9 +632,10 @@ class EplbState:
if ep_group.rank() == 0:
logger.info(
"EPLB step: %d for model %s: avg_tokens=%.2f, "
"EPLB step: %d/%d for model %s: avg_tokens=%.2f, "
"max_tokens=%d, balancedness=%.4f",
self.expert_rearrangement_step,
self.expert_rearrangement_step_interval,
eplb_model_state.model_name,
avg_tokens,
max_tokens,
@ -684,11 +732,14 @@ class EplbState:
ep_group = get_ep_group().device_group
ep_rank = ep_group.rank()
time_start = None
start_event = None
end_event = None
is_main_rank = ep_rank == 0
if is_main_rank:
torch.cuda.synchronize()
time_start = time.perf_counter()
if not self.is_async or is_profile:
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
start_event.record()
logger.info(
"Rearranging experts %s %s...",
"(async mode)" if self.is_async else "sync mode",
@ -789,20 +840,21 @@ class EplbState:
for eplb_model_state, global_expert_load_window in zip(
self.model_states.values(), global_expert_load_windows
):
# Get new expert mappings for the model
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = self.policy.rebalance_experts(
global_expert_load_window,
num_replicas,
num_groups,
num_nodes,
num_gpus,
)
if not eplb_model_state.is_async_enabled or is_profile:
# Get new expert mappings for the model
(
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = self.policy.rebalance_experts(
global_expert_load_window,
num_replicas,
num_groups,
num_nodes,
num_gpus,
eplb_model_state.physical_to_logical_map,
)
# Update expert weights
rearrange_expert_weights_inplace(
eplb_model_state.physical_to_logical_map,
@ -848,35 +900,29 @@ class EplbState:
new_logical_replica_count
)
if is_main_rank:
assert time_start is not None
torch.cuda.synchronize()
time_end = time.perf_counter()
assert start_event is not None
assert end_event is not None
end_event.record()
end_event.synchronize()
gpu_elapsed = start_event.elapsed_time(end_event) / 1000.0
logger.info(
"Rearranged experts%sin %.2f seconds.",
"Rearranged experts %s in %.2f s.",
" (profile) " if is_profile else " ",
time_end - time_start,
gpu_elapsed,
)
else:
device = eplb_model_state.physical_to_logical_map.device
new_physical = new_physical_to_logical_map.to(device)
max_slots = eplb_model_state.logical_to_physical_map.shape[-1]
padded_logical = torch.nn.functional.pad(
new_logical_to_physical_map,
(0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
value=-1,
).to(eplb_model_state.logical_to_physical_map.device)
new_replica = new_logical_replica_count.to(
eplb_model_state.logical_replica_count.device
eplb_model_state.eplb_stats = EplbStats(
# We copy the tensor to snapshot the workload on the main
# thread to be used on the async thread.
global_expert_load_window=global_expert_load_window.clone(),
num_replicas=num_replicas,
num_groups=num_groups,
num_nodes=num_nodes,
num_gpus=num_gpus,
)
eplb_model_state.new_physical_to_logical_map = new_physical
eplb_model_state.new_logical_to_physical_map = padded_logical
eplb_model_state.new_logical_replica_count = new_replica
eplb_model_state.rebalanced = True
eplb_model_state.layer_to_transfer = 0
eplb_model_state.pending_global_ready_check = True
# Signal async thread to start transferring layers
if self.is_async and (not is_profile):
self.rearrange_event.set()
@ -908,11 +954,13 @@ class EplbState:
target_device = model_state.physical_to_logical_map.device
new_physical = model_state.new_physical_to_logical_map
# To avoid a race condition with the async EPLB worker,
# the copy must be blocking when the EP size has changed.
if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]:
model_state.physical_to_logical_map = new_physical.to(target_device)
else:
model_state.physical_to_logical_map[layer].copy_(
new_physical[layer].to(target_device)
new_physical[layer].to(target_device, non_blocking=True)
)
logical_device = model_state.logical_to_physical_map.device
@ -968,25 +1016,28 @@ class EplbState:
stream = torch.cuda.current_stream(device=device_index)
stream.wait_event(model_state.buffer_ready_event)
model_state.buffer_ready_event = None
expert_weights = model_state.model.expert_weights[
model_state.layer_to_transfer
]
expert_weights_buffer = model_state.expert_buffer
new_indices = model_state.new_physical_to_logical_map[
model_state.layer_to_transfer
].numpy()
move_from_buffer(
expert_weights=model_state.model.expert_weights[
model_state.layer_to_transfer
],
expert_weights_buffer=model_state.expert_buffer,
expert_weights=expert_weights,
expert_weights_buffers=expert_weights_buffer,
is_unchanged=model_state.is_unchanged,
is_received_locally=model_state.is_received_locally,
experts_recv_loc=model_state.experts_recv_loc,
new_indices=model_state.new_physical_to_logical_map[
model_state.layer_to_transfer
].tolist(),
ep_group=ep_group,
recv_metadata=model_state.recv_metadata,
new_indices=new_indices,
ep_rank=ep_group.rank(),
)
transferred_layer = model_state.layer_to_transfer
self._update_layer_mapping_from_new(model_state, transferred_layer)
# After the main thread consumes, advance layer_to_transfer
model_state.layer_to_transfer += 1
model_state.ep_buffer_ready = 0
logger.info(
logger.debug(
"model %s successfully move_to_workspace layer %d",
model_state.model_name,
transferred_layer,
@ -1005,9 +1056,7 @@ class EplbState:
assert model_state.new_physical_to_logical_map is not None
assert model_state.new_logical_to_physical_map is not None
assert model_state.new_logical_replica_count is not None
if not is_profile:
for layer_idx in range(model_state.physical_to_logical_map.shape[0]):
self._update_layer_mapping_from_new(model_state, layer_idx)
model_state.new_physical_to_logical_map = None
model_state.new_logical_to_physical_map = None
model_state.new_logical_replica_count = None

View File

@ -16,6 +16,7 @@ class AbstractEplbPolicy(ABC):
num_groups: int,
num_nodes: int,
num_ranks: int,
old_global_expert_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Entry point for expert-parallelism load balancer.
@ -28,7 +29,9 @@ class AbstractEplbPolicy(ABC):
num_groups: number of expert groups
num_nodes: number of server nodes
num_ranks: number of ranks, must be a multiple of `num_nodes`
old_global_expert_indices: [layers, num_physical_experts], the old global
expert indices. Used to avoid unnecessary weight copying
for experts moving within one rank.
Returns:
physical_to_logical_map: [layers, num_replicas], the expert
index of each replica

View File

@ -21,8 +21,8 @@ from .abstract import AbstractEplbPolicy
class DefaultEplbPolicy(AbstractEplbPolicy):
@classmethod
def balanced_packing(
cls, weight: torch.Tensor, num_packs: int
) -> tuple[torch.Tensor, torch.Tensor]:
cls, weight: np.ndarray, num_packs: int
) -> tuple[np.ndarray, np.ndarray]:
"""
Pack n weighted objects to m packs, such that each bin contains exactly
n/m objects and the weights of all packs are as balanced as possible.
@ -39,50 +39,43 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
assert num_groups % num_packs == 0
groups_per_pack = num_groups // num_packs
device = weight.device
if groups_per_pack == 1:
pack_index = torch.arange(
weight.size(-1), dtype=torch.int64, device=device
).expand(weight.shape)
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
pack_index = np.tile(np.arange(num_groups, dtype=np.int64), (num_layers, 1))
rank_in_pack = np.zeros_like(pack_index, dtype=np.int64)
return pack_index, rank_in_pack
weight_np = weight.cpu().numpy()
# Sort and get indices in descending order
indices_np = np.argsort(-weight_np, axis=-1)
indices = np.argsort(-weight, axis=-1)
pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
pack_index = np.full((num_layers, num_groups), -1, dtype=np.int64)
rank_in_pack = np.full((num_layers, num_groups), -1, dtype=np.int64)
pack_weights = np.zeros((num_layers, num_packs), dtype=np.float64)
pack_items = np.zeros((num_layers, num_packs), dtype=np.int64)
# Run the packing algorithm
for i in range(num_layers):
pack_weights = [0.0] * num_packs
pack_items = [0] * num_packs
for layer_idx in range(num_layers):
weights_row = pack_weights[layer_idx]
items_row = pack_items[layer_idx]
for group in indices_np[i]:
# Find a pack with capacity that has the lowest weight
pack = min(
(j for j in range(num_packs) if pack_items[j] < groups_per_pack),
key=pack_weights.__getitem__,
)
for group in indices[layer_idx]:
# Pick the lightest pack; full packs are masked out by inf.
pack = int(np.argmin(weights_row))
assert items_row[pack] < groups_per_pack
pack_index_np[i, group] = pack
rank_in_pack_np[i, group] = pack_items[pack]
pack_weights[pack] += weight_np[i, group]
pack_items[pack] += 1
pack_index = torch.from_numpy(pack_index_np).to(device)
rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
pack_index[layer_idx, group] = pack
rank_in_pack[layer_idx, group] = items_row[pack]
weights_row[pack] += weight[layer_idx, group]
items_row[pack] += 1
if items_row[pack] == groups_per_pack:
# Mark as unavailable for future selections.
weights_row[pack] = np.inf
return pack_index, rank_in_pack
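For reference, a tiny NumPy-only sketch of the greedy loop above on one layer (weights are made up):

import numpy as np

weight = np.array([[5.0, 1.0, 4.0, 2.0]])      # [layers=1, groups=4]
num_packs, groups_per_pack = 2, 2

pack_index = np.full((1, 4), -1, dtype=np.int64)
rank_in_pack = np.full((1, 4), -1, dtype=np.int64)
weights_row = np.zeros(num_packs)
items_row = np.zeros(num_packs, dtype=np.int64)

for group in np.argsort(-weight[0]):            # heaviest group first
    pack = int(np.argmin(weights_row))          # lightest non-full pack
    pack_index[0, group] = pack
    rank_in_pack[0, group] = items_row[pack]
    weights_row[pack] += weight[0, group]
    items_row[pack] += 1
    if items_row[pack] == groups_per_pack:      # mask full packs out
        weights_row[pack] = np.inf

print(pack_index[0], rank_in_pack[0])  # [0 0 1 1] [0 1 0 1]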
@classmethod
def replicate_experts(
cls, weight: torch.Tensor, num_phy: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
cls, weight: np.ndarray, num_phy: int
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
load of all replicas is minimized.
@ -93,33 +86,32 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
Returns:
phy2log: [X, num_phy], logical expert id of each physical expert
rank: [X, num_phy], the replica rank
replica_idx: [X, num_phy], the index of the replica for each logical expert
logcnt: [X, num_log], number of replicas for each logical expert
"""
n, num_log = weight.shape
num_redundant = num_phy - num_log
assert num_redundant >= 0
device = weight.device
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
arangen = torch.arange(n, dtype=torch.int64, device=device)
phy2log = np.tile(np.arange(num_phy, dtype=np.int64), (n, 1))
replica_idx = np.zeros((n, num_phy), dtype=np.int64)
logcnt = np.ones((n, num_log), dtype=np.int64)
arangen = np.arange(n, dtype=np.int64)
for i in range(num_log, num_phy):
redundant_indices = (weight / logcnt).max(dim=-1).indices
redundant_indices = np.argmax(weight / logcnt, axis=-1)
phy2log[:, i] = redundant_indices
rank[:, i] = logcnt[arangen, redundant_indices]
replica_idx[:, i] = logcnt[arangen, redundant_indices]
logcnt[arangen, redundant_indices] += 1
return phy2log, rank, logcnt
return phy2log, replica_idx, logcnt
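A small worked example of the replication loop (loads made up): each extra physical slot goes to the logical expert with the highest per-replica load.

import numpy as np

weight = np.array([[10.0, 2.0, 4.0]])  # loads of 3 logical experts, 1 layer
num_phy = 5                             # two redundant physical slots

n, num_log = weight.shape
phy2log = np.tile(np.arange(num_phy, dtype=np.int64), (n, 1))
replica_idx = np.zeros((n, num_phy), dtype=np.int64)
logcnt = np.ones((n, num_log), dtype=np.int64)
rows = np.arange(n)

for i in range(num_log, num_phy):
    heavy = np.argmax(weight / logcnt, axis=-1)  # highest per-replica load
    phy2log[:, i] = heavy
    replica_idx[:, i] = logcnt[rows, heavy]
    logcnt[rows, heavy] += 1

print(phy2log[0], replica_idx[0], logcnt[0])  # [0 1 2 0 0] [0 0 0 1 2] [3 1 1]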
@classmethod
def rebalance_experts_hierarchical(
cls,
weight: torch.Tensor,
weight: np.ndarray,
num_physical_experts: int,
num_groups: int,
num_nodes: int,
num_gpus: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Parameters:
weight: [num_moe_layers, num_logical_experts]
@ -132,7 +124,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
Returns:
phy2log: [layers, num_replicas], the expert
index of each replica
log2phy: [layers, num_logical_experts, X],
pphy_replicas_idx: [layers, num_logical_experts, X],
the replica indices for each expert
logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert
@ -146,66 +138,160 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
assert num_physical_experts % num_gpus == 0
phy_experts_per_gpu = num_physical_experts // num_gpus
def inverse(perm: torch.Tensor) -> torch.Tensor:
inv = torch.empty_like(perm)
inv.scatter_(
1,
perm,
torch.arange(
perm.size(1), dtype=torch.int64, device=perm.device
).expand(perm.shape),
)
def inverse(perm: np.ndarray) -> np.ndarray:
inv = np.empty_like(perm)
row_idx = np.arange(perm.shape[0])[:, None]
col_idx = np.arange(perm.shape[1], dtype=np.int64)
inv[row_idx, perm] = col_idx
return inv
# Step 1: pack groups to nodes
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
tokens_per_group = weight.reshape(num_layers, num_groups, group_size).sum(
axis=-1
)
group_pack_index, group_rank_in_pack = cls.balanced_packing(
tokens_per_group, num_nodes
)
# Map each logical expert into a node-local ordering based on packed groups.
log2mlog = (
(
(group_pack_index * groups_per_node + group_rank_in_pack) * group_size
).unsqueeze(-1)
+ torch.arange(
group_size, dtype=torch.int64, device=group_pack_index.device
(group_pack_index * groups_per_node + group_rank_in_pack)[..., None]
* group_size
)
).flatten(-2)
+ np.arange(group_size, dtype=np.int64)
).reshape(num_layers, num_logical_experts)
mlog2log = inverse(log2mlog)
# Step 2: construct redundant experts within nodes
# [num_layers * num_nodes, num_logical_experts // num_nodes]
tokens_per_mlog = weight.gather(-1, mlog2log).view(
# Reorder weights into the node-local layout so replication is done per node.
tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=1).reshape(
-1, num_logical_experts // num_nodes
)
phy2mlog, phyrank, mlogcnt = cls.replicate_experts(
phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts(
tokens_per_mlog, num_physical_experts // num_nodes
)
# Step 3: pack physical_experts to GPUs
# [num_layers * num_nodes, num_physical_experts // num_nodes]
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
# Effective per-physical load = logical load divided by replica count.
tokens_per_phy = np.take_along_axis(tokens_per_mlog / mlogcnt, phy2mlog, axis=1)
pack_index, rank_in_pack = cls.balanced_packing(
tokens_per_phy, num_gpus // num_nodes
)
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
pphy2phy = inverse(phy2pphy)
pphy2mlog = phy2mlog.gather(
-1, pphy2phy
) # [num_layers * num_nodes, num_log_per_nodes]
# Reorder node-local logical indices into the post-packing physical order.
pphy2mlog = np.take_along_axis(phy2mlog, pphy2phy, axis=1)
pphy2mlog = (
pphy2mlog.view(num_layers, num_nodes, -1)
+ torch.arange(
pphy2mlog.reshape(num_layers, num_nodes, -1)
+ np.arange(
0,
num_logical_experts,
num_logical_experts // num_nodes,
device=group_pack_index.device,
).view(1, -1, 1)
).flatten(-2)
pphy2log = mlog2log.gather(-1, pphy2mlog)
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
return pphy2log, pphyrank, logcnt
dtype=np.int64,
)[None, :, None]
).reshape(num_layers, -1)
# Map node-local logical indices back to global logical expert ids.
pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=1)
# Reorder replica ranks to the post-packing physical ordering.
pphy_replicas_idx = np.take_along_axis(replicas_idx, pphy2phy, axis=1).reshape(
num_layers, -1
)
# Convert replica counts back to the original logical ordering.
logcnt = np.take_along_axis(mlogcnt.reshape(num_layers, -1), log2mlog, axis=1)
return pphy2log, pphy_replicas_idx, logcnt
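The NumPy inverse() used above relies on row-wise fancy indexing; a quick check of the identity it guarantees (permutation values made up):

import numpy as np

perm = np.array([[2, 0, 1],
                 [1, 2, 0]])
inv = np.empty_like(perm)
inv[np.arange(perm.shape[0])[:, None], perm] = np.arange(perm.shape[1])

print(inv)                                      # [[1 2 0] [2 0 1]]
print(np.take_along_axis(perm, inv, axis=1))    # every row becomes [0 1 2]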
@classmethod
def preserve_intragpu_slots(
cls,
phy2log: np.ndarray,
phy_replicas_idx: np.ndarray,
num_ranks: int,
old_phy2log: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
"""
Reorder the new mapping per GPU so that experts that remain on the same GPU
keep their previous slot positions when possible. Incoming experts to that GPU
fill any remaining available slots. This is applied only when the number of GPUs
is unchanged and the slots per GPU remain the same between
the old and new mappings.
"""
num_phy_experts = phy2log.shape[1]
if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
return phy2log, phy_replicas_idx
slots_per_gpu = num_phy_experts // num_ranks
num_layers = phy2log.shape[0]
# Work on copies so the caller's arrays are left untouched.
post_phy2log = phy2log.copy()
post_phy_replicas_idx = phy_replicas_idx.copy()
for gpu_idx in range(num_ranks):
start = gpu_idx * slots_per_gpu
end = start + slots_per_gpu
# Experts across all layers for this GPU
old_local = old_phy2log[:, start:end] # [layers, slots]
new_local = phy2log[:, start:end] # [layers, slots]
new_ridx = phy_replicas_idx[:, start:end] # [layers, slots]
used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)
# First pass: preserve same-logical experts in their previous slots
for slot_idx in range(slots_per_gpu):
# matches: [layers, slots], True where new local experts have
# the same logical value as the old from 'slot_idx' and not checked yet
matches = (new_local == old_local[:, slot_idx][:, None]) & (
~used_new_indices
)
has_any = matches.any(axis=1)
if np.any(has_any):
first_idx = np.argmax(matches, axis=1)
layer_indices = np.nonzero(has_any)[0]
matched_new_positions = first_idx[layer_indices]
post_phy2log[layer_indices, start + slot_idx] = new_local[
layer_indices, matched_new_positions
]
post_phy_replicas_idx[layer_indices, start + slot_idx] = new_ridx[
layer_indices, matched_new_positions
]
used_new_indices[layer_indices, matched_new_positions] = True
preserved_positions[layer_indices, slot_idx] = True
# Second pass: fill remaining slots with remaining new experts
remaining_mask = ~used_new_indices # [layers, slots]
fill_mask = ~preserved_positions # [layers, slots]
if remaining_mask.any() and fill_mask.any():
idx_base = np.tile(np.arange(slots_per_gpu), (num_layers, 1))
# Sentinel value for unavailable positions.
large = slots_per_gpu + 1
# Priorities: keep original index for available spots, set sentinel
# for unavailable; lower is earlier.
remaining_priority = np.where(remaining_mask, idx_base, large)
fill_priority = np.where(fill_mask, idx_base, large)
# Sort to get ordered indices of available src/dst positions per layer.
remaining_indices = np.argsort(remaining_priority, axis=1)
fill_indices = np.argsort(fill_priority, axis=1)
# Fill count per layer (cannot exceed either side).
remaining_counts = remaining_mask.sum(axis=1)
fill_counts = fill_mask.sum(axis=1)
take_counts = np.minimum(remaining_counts, fill_counts)
# Assign remaining new experts to remaining slots per layer.
for layer_idx in range(num_layers):
k = int(take_counts[layer_idx])
if k <= 0:
continue
src_pos = remaining_indices[layer_idx, :k]
dst_pos = fill_indices[layer_idx, :k]
post_phy2log[layer_idx, start + dst_pos] = new_local[
layer_idx, src_pos
]
post_phy_replicas_idx[layer_idx, start + dst_pos] = new_ridx[
layer_idx, src_pos
]
return post_phy2log, post_phy_replicas_idx
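A usage sketch of this postprocessing on the "simple" case from the new test (module path abbreviated; adjust to wherever DefaultEplbPolicy lives):

import numpy as np
from vllm.distributed.eplb.policy import DefaultEplbPolicy  # path assumed

old = np.array([[0, 1, 2, 3, 4, 5, 6, 7]])   # GPU0: slots 0-3, GPU1: slots 4-7
new = np.array([[1, 5, 0, 4, 6, 2, 7, 3]])   # rebalanced mapping
ridx = np.zeros_like(new)                     # all experts unique -> replica 0

post, post_ridx = DefaultEplbPolicy.preserve_intragpu_slots(new, ridx, 2, old)
print(post)  # [[0 1 5 4 2 3 6 7]]  (0,1 keep slots on GPU0; 6,7 on GPU1)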
@classmethod
def rebalance_experts(
@ -215,6 +301,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
num_groups: int,
num_nodes: int,
num_ranks: int,
old_global_expert_indices: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Entry point for expert-parallelism load balancer.
@ -228,7 +315,9 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
num_nodes: number of server nodes, where the intra-node network
(e.g, NVLink) is faster
num_ranks: number of ranks, must be a multiple of `num_nodes`
old_global_expert_indices: [layers, num_physical_experts], the old global
expert indices. Used to avoid unnecessary weight copying
for experts moving within one rank.
Returns:
phy2log: [layers, num_replicas], the expert
index of each replica
@ -237,31 +326,51 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert
"""
device = weight.device
num_layers, num_logical_experts = weight.shape
weight = weight.float()
weight_np = weight.float().cpu().numpy()
old_phy2log_np = (
old_global_expert_indices.cpu().numpy()
if old_global_expert_indices is not None
else None
)
if num_groups % num_nodes == 0:
# use hierarchical load-balance policy
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
weight, num_replicas, num_groups, num_nodes, num_ranks
phy2log_np, phy_replicas_idx_np, logcnt_np = (
cls.rebalance_experts_hierarchical(
weight_np, num_replicas, num_groups, num_nodes, num_ranks
)
)
else:
# use global load-balance policy
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
weight, num_replicas, 1, 1, num_ranks
phy2log_np, phy_replicas_idx_np, logcnt_np = (
cls.rebalance_experts_hierarchical(
weight_np, num_replicas, 1, 1, num_ranks
)
)
# Optional postprocessing: keep slots for experts that stay on the
# same GPU, avoiding unnecessary weight copies. Only applied when the
# number of GPUs and slots per GPU remain unchanged.
if old_global_expert_indices is not None:
phy2log_np, phy_replicas_idx_np = cls.preserve_intragpu_slots(
phy2log_np, phy_replicas_idx_np, num_ranks, old_phy2log_np
)
num_redundant_experts = num_replicas - num_logical_experts
maxlogcnt = num_redundant_experts + 1
log2phy: torch.Tensor = torch.full(
(num_layers, num_logical_experts, maxlogcnt),
-1,
dtype=torch.int64,
device=logcnt.device,
log2phy_np = np.full(
(num_layers, num_logical_experts, maxlogcnt), -1, dtype=np.int64
)
log2phy.view(num_layers, -1).scatter_(
-1,
phy2log * maxlogcnt + phyrank,
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
num_layers, -1
),
layer_indices = np.arange(num_layers)[:, None]
replica_indices = np.tile(
np.arange(num_replicas, dtype=np.int64), (num_layers, 1)
)
log2phy_np[layer_indices, phy2log_np, phy_replicas_idx_np] = replica_indices
phy2log = torch.from_numpy(phy2log_np).to(device)
log2phy = torch.from_numpy(log2phy_np).to(device)
logcnt = torch.from_numpy(logcnt_np).to(device)
return phy2log, log2phy, logcnt
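The scatter of physical indices into log2phy now uses plain NumPy fancy indexing; a tiny example (values made up) of what it builds:

import numpy as np

phy2log = np.array([[0, 1, 2, 0]])       # 4 physical slots, expert 0 duplicated
replica_idx = np.array([[0, 0, 0, 1]])
num_layers, num_replicas, max_replicas = 1, 4, 2

log2phy = np.full((num_layers, 3, max_replicas), -1, dtype=np.int64)
layer_idx = np.arange(num_layers)[:, None]
phys = np.tile(np.arange(num_replicas, dtype=np.int64), (num_layers, 1))
log2phy[layer_idx, phy2log, replica_idx] = phys   # (layer, expert, replica) -> slot

print(log2phy[0])  # [[ 0  3] [ 1 -1] [ 2 -1]]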

View File

@ -6,9 +6,10 @@ The actual execution of the rearrangement.
This involves the exchange of expert weights between GPUs.
"""
from collections.abc import Iterable, MutableSequence, Sequence
from functools import partial
from collections.abc import Iterable, Sequence
from dataclasses import dataclass
import numpy as np
import torch
from torch.distributed import (
P2POp,
@ -18,214 +19,318 @@ from torch.distributed import (
get_global_rank,
)
from vllm.logger import init_logger
def idx_local_to_global(
local_idx: int,
local_cnt: int,
ep_rank: int,
) -> int:
"""
Convert a local expert index to a global expert index.
"""
return ep_rank * local_cnt + local_idx
logger = init_logger(__name__)
def idx_global_to_local(
global_idx: int,
local_cnt: int,
ep_rank: int,
) -> int:
"""
Convert a global expert index to a local expert index.
"""
return global_idx - ep_rank * local_cnt
@dataclass
class RecvMetadata:
"""Metadata describing remote receives during EPLB rebalancing."""
recv_primary_mask: np.ndarray
"""Mask of (num_local_experts,) indicating primary experts received."""
recv_count: int
"""Number of received experts for the layer."""
recv_expert_ids: np.ndarray
"""Expert ids (num_local_experts,) of remote primary experts."""
recv_dst_rows: np.ndarray
"""Target expert indices (num_local_experts,) in local tensors to send."""
def global_idx_to_rank(
global_idx: int,
local_cnt: int,
) -> int:
"""
Convert a global expert index to a rank index.
"""
return global_idx // local_cnt
# Type alias for the result of move_to_buffer or transfer_layer
MoveToBufferResult = tuple[np.ndarray, np.ndarray, RecvMetadata]
def get_ep_ranks_with_expert(
idx: int,
def get_ep_ranks_with_experts_batch(
expert_ids: np.ndarray,
num_local_experts: int,
old_indices: Sequence[int],
new_indices: Sequence[int],
) -> tuple[MutableSequence[int], MutableSequence[int]]:
old_indices: np.ndarray,
new_indices: np.ndarray,
) -> tuple[dict[int, list[int]], dict[int, list[int]]]:
"""
Get the ranks of the experts that need to be exchanged.
Args:
idx: The index of the expert.
expert_ids: 1D array of expert indices to query.
num_local_experts: The number of local experts.
old_indices: The old indices of the experts.
new_indices: The new indices of the experts.
Returns:
A tuple of two lists:
- The ranks of the experts that need to be sent.
- The ranks of the experts that need to be received.
A tuple of two dictionaries mapping expert_id to:
- ranks_to_send: The ranks that have this expert and need to send.
- ranks_to_recv: The ranks that need to receive this expert.
"""
global2rank = partial(
global_idx_to_rank,
local_cnt=num_local_experts,
)
ranks_to_send_map: dict[int, list[int]] = {}
ranks_to_recv_map: dict[int, list[int]] = {}
ranks_to_send: list[int] = []
ranks_to_recv: list[int] = []
# Fast path: if no experts, return empty dicts
if expert_ids.size == 0:
return ranks_to_send_map, ranks_to_recv_map
for i, e in enumerate(old_indices):
if e == idx:
rank = global2rank(i)
if not ranks_to_send or ranks_to_send[-1] != rank:
ranks_to_send.append(rank)
unique_experts = np.unique(expert_ids)
num_positions = len(old_indices)
position_indices = np.arange(num_positions, dtype=np.int32)
for i, e in enumerate(new_indices):
if e == idx:
rank = global2rank(i)
if not ranks_to_recv or ranks_to_recv[-1] != rank:
ranks_to_recv.append(rank)
# Vectorized approach: find all positions matching any query expert in one pass
# Use np.isin to get boolean masks for all relevant positions at once
old_relevant_mask = np.isin(old_indices, unique_experts)
new_relevant_mask = np.isin(new_indices, unique_experts)
# Remove those ranks that can get this expert locally.
ranks_to_send_set = set(ranks_to_send)
ranks_to_recv_actual = [
rank for rank in ranks_to_recv if rank not in ranks_to_send_set
]
# Process old_indices (send ranks)
if np.any(old_relevant_mask):
old_relevant_positions = position_indices[old_relevant_mask]
old_relevant_experts = old_indices[old_relevant_mask]
old_relevant_ranks = old_relevant_positions // num_local_experts
return ranks_to_send, ranks_to_recv_actual
# Sort by expert first, then by position (to maintain first-appearance order)
sort_order = np.lexsort((old_relevant_positions, old_relevant_experts))
sorted_experts = old_relevant_experts[sort_order]
sorted_ranks = old_relevant_ranks[sort_order]
# Find boundaries where expert changes
expert_boundaries = np.concatenate(
[[0], np.where(np.diff(sorted_experts) != 0)[0] + 1, [len(sorted_experts)]]
)
# For each expert, extract unique ranks in order of first appearance
for i in range(len(expert_boundaries) - 1):
start, end = expert_boundaries[i], expert_boundaries[i + 1]
expert = int(sorted_experts[start])
expert_ranks = sorted_ranks[start:end]
# Get unique ranks preserving order
_, unique_idx = np.unique(expert_ranks, return_index=True)
unique_ranks = expert_ranks[np.sort(unique_idx)]
ranks_to_send_map[expert] = unique_ranks.tolist()
# Process new_indices (recv ranks)
if np.any(new_relevant_mask):
new_relevant_positions = position_indices[new_relevant_mask]
new_relevant_experts = new_indices[new_relevant_mask]
new_relevant_ranks = new_relevant_positions // num_local_experts
# Sort by expert first, then by position
sort_order = np.lexsort((new_relevant_positions, new_relevant_experts))
sorted_experts = new_relevant_experts[sort_order]
sorted_ranks = new_relevant_ranks[sort_order]
# Find boundaries where expert changes
expert_boundaries = np.concatenate(
[[0], np.where(np.diff(sorted_experts) != 0)[0] + 1, [len(sorted_experts)]]
)
# For each expert, extract unique ranks and exclude local copies
for i in range(len(expert_boundaries) - 1):
start, end = expert_boundaries[i], expert_boundaries[i + 1]
expert = int(sorted_experts[start])
expert_ranks = sorted_ranks[start:end]
# Get unique ranks preserving order
_, unique_idx = np.unique(expert_ranks, return_index=True)
unique_ranks = expert_ranks[np.sort(unique_idx)]
# Remove ranks that have local copies (in send map)
send_ranks_set = set(ranks_to_send_map.get(expert, []))
recv_ranks_actual = [
int(r) for r in unique_ranks if r not in send_ranks_set
]
ranks_to_recv_map[expert] = recv_ranks_actual
# Handle experts that only appear in old (send only) or new (recv only)
for expert in unique_experts:
expert = int(expert)
if expert not in ranks_to_send_map:
ranks_to_send_map[expert] = []
if expert not in ranks_to_recv_map:
ranks_to_recv_map[expert] = []
return ranks_to_send_map, ranks_to_recv_map
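A tiny sketch of what the batched helper resolves for a made-up layout (2 experts per rank, 2 ranks); module path assumed:

import numpy as np
from vllm.distributed.eplb.rebalance_execute import (  # path assumed
    get_ep_ranks_with_experts_batch,
)

num_local = 2
old = np.array([0, 1, 2, 0])   # rank 0 holds experts 0,1; rank 1 holds 2,0
new = np.array([0, 2, 0, 1])   # desired mapping after rebalance

send, recv = get_ep_ranks_with_experts_batch(np.array([0, 2]), num_local, old, new)
print(send)  # {0: [0, 1], 2: [1]}
print(recv)  # {0: [], 2: [0]}   (ranks that already hold an expert are excluded)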
def move_to_buffer(
num_local_experts: int,
old_indices: Sequence[int],
new_indices: Sequence[int],
old_indices: np.ndarray,
new_indices: np.ndarray,
expert_weights: Iterable[torch.Tensor],
expert_weights_buffer: Sequence[torch.Tensor],
expert_weights_buffers: Sequence[torch.Tensor],
cuda_stream: torch.cuda.Stream | None,
ep_group: ProcessGroup,
) -> tuple[list[bool], list[bool], dict[int, int]]:
) -> MoveToBufferResult:
"""
Perform expert weights rearrangement of one layer.
Rearranges expert weights during EPLB rebalancing.
Args:
num_local_experts: Number of local experts.
old_indices: (num_experts_total,) ndarray of current (old)
global-to-local expert assignments.
new_indices: (num_experts_total,) ndarray of desired (new)
global-to-local assignments after rebalance.
expert_weights: Original expert weights for the layer.
expert_weights_buffers: Intermediate buffers (one per tensor).
cuda_stream: CUDA stream for async copies (can be None for sync mode).
ep_group: Distributed process group for expert parallel comms.
Returns:
is_unchanged (np.ndarray): (num_local_experts,), True where an expert row
is unchanged after rebalance.
is_received_locally (np.ndarray): (num_local_experts,), True where a row
can be updated from local data.
RecvMetadata: Metadata needed for completing remote weight transfers.
"""
assert old_indices.shape == new_indices.shape
ep_rank = ep_group.rank()
local2global = partial(
idx_local_to_global,
local_cnt=num_local_experts,
ep_rank=ep_rank,
recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_)
send_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
send_src_rows = np.full((num_local_experts,), -1, dtype=np.int32)
recv_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
recv_dst_rows = np.full((num_local_experts,), -1, dtype=np.int32)
base = ep_rank * num_local_experts
local_rows = np.arange(num_local_experts, dtype=np.int32)
local_global = base + local_rows
old_local_expert_ids = old_indices[local_global]
new_local_expert_ids = new_indices[local_global]
# Unchanged mask
is_unchanged = old_local_expert_ids == new_local_expert_ids
# Local receive eligibility
new_valid = new_local_expert_ids != -1
can_recv_local = np.isin(
new_local_expert_ids, old_local_expert_ids, assume_unique=False
)
is_received_locally = np.logical_or(
is_unchanged, np.logical_and(new_valid, can_recv_local)
)
# 0. Do nothing for experts that did not change.
is_unchanged = [
old_indices[local2global(i)] == new_indices[local2global(i)]
for i in range(num_local_experts)
]
# Send map: first src row per unique expert present locally in old mapping
send_count = 0
valid_old = old_local_expert_ids != -1
if np.any(valid_old):
uniq_experts, first_idx = np.unique(
old_local_expert_ids[valid_old], return_index=True
)
filtered_rows = local_rows[valid_old]
src_rows = filtered_rows[first_idx]
send_count = int(uniq_experts.shape[0])
send_expert_ids[:send_count] = uniq_experts
send_src_rows[:send_count] = src_rows
# 1. Perform weight copy inside the local rank.
is_received_locally = is_unchanged[:]
for src in range(num_local_experts):
src_global = local2global(src)
for dst in range(num_local_experts):
dst_global = local2global(dst)
if is_received_locally[dst]:
continue
if old_indices[src_global] == -1 or new_indices[dst_global] == -1:
continue
if old_indices[src_global] == new_indices[dst_global]:
is_received_locally[dst] = True
for weight, buffer in zip(expert_weights, expert_weights_buffer):
with torch.cuda.stream(cuda_stream):
buffer[dst].copy_(weight[src], non_blocking=True)
# Recv map: primary dst per unique expert needed remotely
recv_count = 0
need_recv_mask = np.logical_and(~is_received_locally, new_valid)
if np.any(need_recv_mask):
desired_experts = new_local_expert_ids[need_recv_mask]
desired_dsts = local_rows[need_recv_mask]
uniq_recv_experts, uniq_indices = np.unique(desired_experts, return_index=True)
dst_rows = desired_dsts[uniq_indices]
recv_count = int(uniq_recv_experts.shape[0])
recv_expert_ids[:recv_count] = uniq_recv_experts
recv_dst_rows[:recv_count] = dst_rows
recv_primary_mask[dst_rows] = True
eligible_local_buffer_mask = np.logical_and(~is_unchanged, is_received_locally)
# 1. Local moves into tmp buffers
if bool(eligible_local_buffer_mask.any()) and send_count > 0:
dest_indices = np.nonzero(eligible_local_buffer_mask)[0].tolist()
expert_to_src_map = dict(
zip(send_expert_ids[:send_count], send_src_rows[:send_count])
)
for dst in dest_indices:
expert = new_local_expert_ids[dst]
src_local = expert_to_src_map.get(expert, -1)
if src_local != -1:
for w, b in zip(expert_weights, expert_weights_buffers):
b[dst].copy_(w[src_local], non_blocking=True)
p2p_ops: list[P2POp] = []
# 2. Initiate sending of weights.
experts_send_loc: dict[int, int] = {}
for src in range(num_local_experts):
expert = old_indices[local2global(src)]
if expert == -1:
continue
if expert in experts_send_loc:
continue
experts_send_loc[expert] = src
# Pre-compute global ranks mapping
ep_size = ep_group.size()
rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
# We need to sort here to match send/recv
for expert, src in sorted(experts_send_loc.items()):
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
expert,
# 2. Post sends
if send_count > 0:
experts = send_expert_ids[:send_count]
srcs = send_src_rows[:send_count]
order = np.argsort(experts, kind="stable")
experts = experts[order]
srcs = srcs[order]
send_map, recv_map = get_ep_ranks_with_experts_batch(
experts,
num_local_experts,
old_indices,
new_indices,
)
# Calculate the ranks to send by this rank
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
sender_pos = ranks_to_send.index(ep_rank)
recv_begin = sender_pos * num_dst_per_sender
recv_end = recv_begin + num_dst_per_sender
recv_ranks = ranks_to_recv[recv_begin:recv_end]
for expert, src in zip(experts.tolist(), srcs.tolist()):
ranks_to_send = send_map[expert]
ranks_to_recv = recv_map[expert]
if not ranks_to_send or not ranks_to_recv:
continue
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
sender_pos = ranks_to_send.index(ep_rank)
recv_begin = sender_pos * num_dst_per_sender
recv_end = recv_begin + num_dst_per_sender
recv_ranks = ranks_to_recv[recv_begin:recv_end]
remainder_start = len(ranks_to_send) * num_dst_per_sender
recver_pos = remainder_start + sender_pos
if recver_pos < len(ranks_to_recv):
recv_ranks.append(ranks_to_recv[recver_pos])
for dst in recv_ranks:
dst_global = rank_to_global[dst]
p2p_ops += [
P2POp(
torch.distributed.isend,
w[src],
dst_global,
)
for w in expert_weights
]
# Tackle remainders
remainder_start = len(ranks_to_send) * num_dst_per_sender
recver_pos = remainder_start + sender_pos
if recver_pos < len(ranks_to_recv):
recv_ranks.append(ranks_to_recv[recver_pos])
# 3. Post recvs
if recv_count > 0:
experts = recv_expert_ids[:recv_count]
dsts = recv_dst_rows[:recv_count]
order = np.argsort(experts, kind="stable")
experts = experts[order]
dsts = dsts[order]
for dst in recv_ranks:
dst_global = get_global_rank(ep_group, dst)
send_map, recv_map = get_ep_ranks_with_experts_batch(
experts,
num_local_experts,
old_indices,
new_indices,
)
for expert, dst in zip(experts.tolist(), dsts.tolist()):
ranks_to_send = send_map[expert]
ranks_to_recv = recv_map[expert]
if not ranks_to_send or not ranks_to_recv:
continue
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
recver_pos = ranks_to_recv.index(ep_rank)
remainder_start = len(ranks_to_send) * num_dst_per_sender
if recver_pos < remainder_start:
src = ranks_to_send[recver_pos // num_dst_per_sender]
else:
src = ranks_to_send[recver_pos - remainder_start]
src_global = rank_to_global[src]
p2p_ops += [
P2POp(
torch.distributed.isend,
weight[src],
dst_global,
torch.distributed.irecv,
b[dst],
src_global,
)
for weight in expert_weights
for b in expert_weights_buffers
]
# 3. Initiate receiving of weights.
experts_recv_loc: dict[int, int] = {}
for dst in range(num_local_experts):
if is_received_locally[dst]:
continue
expert = new_indices[local2global(dst)]
if expert == -1:
continue
if expert in experts_recv_loc:
continue
experts_recv_loc[expert] = dst
# We need to sort here to match send/recv
for expert, dst in sorted(experts_recv_loc.items()):
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
expert,
num_local_experts,
old_indices,
new_indices,
)
# Calculate the rank to recv by this rank
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
recver_pos = ranks_to_recv.index(ep_rank)
remainder_start = len(ranks_to_send) * num_dst_per_sender
if recver_pos < remainder_start:
src = ranks_to_send[recver_pos // num_dst_per_sender]
else:
src = ranks_to_send[recver_pos - remainder_start]
src_global = get_global_rank(ep_group, src)
p2p_ops += [
P2POp(
torch.distributed.irecv,
weight[dst],
src_global,
)
for weight in expert_weights_buffer
]
# 4. Execute the P2P operations. The real communication happens here.
if p2p_ops and cuda_stream is not None:
with torch.cuda.stream(cuda_stream):
@ -237,51 +342,107 @@ def move_to_buffer(
for req in reqs:
req.wait()
# wait for the communication to finish
return is_unchanged, is_received_locally, experts_recv_loc
return (
is_unchanged,
is_received_locally,
RecvMetadata(
recv_primary_mask=recv_primary_mask,
recv_count=recv_count,
recv_expert_ids=recv_expert_ids,
recv_dst_rows=recv_dst_rows,
),
)
def move_from_buffer(
expert_weights: Iterable[torch.Tensor],
expert_weights_buffer: list[torch.Tensor],
is_unchanged: list[bool],
is_received_locally: list[bool],
experts_recv_loc: dict[int, int],
new_indices: Sequence[int],
ep_group: ProcessGroup,
expert_weights_buffers: list[torch.Tensor],
is_unchanged: np.ndarray,
is_received_locally: np.ndarray,
recv_metadata: RecvMetadata,
new_indices: np.ndarray,
ep_rank: int,
) -> None:
ep_rank = ep_group.rank()
num_local_experts = len(is_unchanged)
"""
Copies expert weights from communication buffers back to the target weight tensors
after EPLB rebalancing.
local2global = partial(
idx_local_to_global, local_cnt=num_local_experts, ep_rank=ep_rank
Args:
expert_weights: List of the actual MoE layer weights used in the execution.
expert_weights_buffers: Intermediate buffers containing the expert weights
after the transfer is completed.
is_unchanged: (num_local_experts,), True where an expert row is unchanged.
is_received_locally: (num_local_experts,), True where a row is updated locally.
recv_metadata: RecvMetadata containing remote receive metadata.
new_indices: (num_experts_total,) mapping from local rows to desired
(possibly global) expert id, after rebalance.
ep_rank: Rank of the process in the expert parallel group.
"""
recv_primary_mask = recv_metadata.recv_primary_mask
recv_count = recv_metadata.recv_count
recv_expert_ids = recv_metadata.recv_expert_ids
recv_dst_rows = recv_metadata.recv_dst_rows
num_local_experts = is_unchanged.shape[0]
# Mask for rows to copy back from buffers:
# copy if locally received OR remote primary recv
copy_mask = np.logical_or(is_received_locally, recv_primary_mask)
dest_mask_np = np.logical_and(~is_unchanged, copy_mask)
if bool(dest_mask_np.any()):
dest_indices = np.nonzero(dest_mask_np)[0].tolist()
for dst in dest_indices:
for w, b in zip(expert_weights, expert_weights_buffers):
w[dst].copy_(b[dst], non_blocking=True)
if recv_count == 0:
return
# Duplicate remote received rows to non-primary duplicate dsts
base = ep_rank * num_local_experts
local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)]
duplicate_mask = np.logical_and(
np.logical_and(~is_unchanged, ~is_received_locally),
np.logical_and(~recv_primary_mask, local_experts != -1),
)
# All received experts are unique in the destination, so no need to copy duplicates
if not bool(duplicate_mask.any()):
return
for dst in range(num_local_experts):
if is_unchanged[dst]:
continue
if is_received_locally[dst]:
for weight, buffer in zip(expert_weights, expert_weights_buffer):
weight[dst].copy_(buffer[dst], non_blocking=True)
else:
expert = new_indices[local2global(dst)]
if expert == -1:
continue
src = experts_recv_loc[expert]
for weight, buffer in zip(expert_weights, expert_weights_buffer):
weight[dst].copy_(buffer[src], non_blocking=True)
dup_dst_rows = np.nonzero(duplicate_mask)[0]
dup_experts = local_experts[dup_dst_rows]
prim_experts = recv_expert_ids[:recv_count]
prim_dsts = recv_dst_rows[:recv_count]
order = np.argsort(prim_experts, kind="stable")
prim_experts_sorted = prim_experts[order]
prim_dsts_sorted = prim_dsts[order]
pos = np.searchsorted(prim_experts_sorted, dup_experts)
valid = np.logical_and(
pos < prim_experts_sorted.shape[0],
prim_experts_sorted[np.minimum(pos, prim_experts_sorted.shape[0] - 1)]
== dup_experts,
)
if not bool(valid.any()):
return
matched_dst_rows = dup_dst_rows[valid]
matched_src_rows = prim_dsts_sorted[pos[valid]]
for dst, src in zip(matched_dst_rows.tolist(), matched_src_rows.tolist()):
for w in expert_weights:
w[dst].copy_(w[src], non_blocking=True)
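The duplicate fan-out above matches each duplicate row to the primary row that received its expert via a sorted lookup; a minimal sketch of that matching (values made up):

import numpy as np

prim_experts = np.array([7, 3])          # experts received remotely
prim_dsts = np.array([1, 4])             # local rows holding the fresh copies
dup_experts = np.array([3, 9])           # duplicate rows also wanting 3 and 9

order = np.argsort(prim_experts, kind="stable")
prim_e, prim_d = prim_experts[order], prim_dsts[order]   # [3 7], [4 1]

pos = np.searchsorted(prim_e, dup_experts)                # [0 2]
valid = (pos < prim_e.shape[0]) & (
    prim_e[np.minimum(pos, prim_e.shape[0] - 1)] == dup_experts
)
print(valid)               # [ True False] -> expert 9 was never received
print(prim_d[pos[valid]])  # [4] -> copy expert 3's duplicate from local row 4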
async def transfer_layer(
old_global_expert_indices: torch.Tensor,
new_global_expert_indices: torch.Tensor,
expert_weights: Sequence[Iterable[torch.Tensor]],
old_layer_indices: torch.Tensor,
new_layer_indices: torch.Tensor,
expert_weights: Iterable[torch.Tensor],
expert_weights_buffer: Sequence[torch.Tensor],
ep_group: ProcessGroup,
is_profile: bool = False,
layer: int = 0,
cuda_stream: torch.cuda.Stream | None = None,
rank_mapping: dict[int, int] | None = None,
) -> tuple[list[bool], list[bool], dict[int, int]]:
) -> MoveToBufferResult:
"""
Rearranges the expert weights in place according to the new expert indices.
@ -289,50 +450,68 @@ async def transfer_layer(
while keys are physical.
Args:
old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
expert_weights: A sequence of shape (num_moe_layers)(weight_count)
of tensors of shape (num_local_physical_experts, hidden_size_i).
For example, a linear layer may have up and down projection,
so weight_count = 2. Each weight's hidden size can be different.
old_layer_indices: Shape (num_physical_experts,).
new_layer_indices: Shape (num_physical_experts,).
expert_weights: Iterable of weight tensors for this layer, each with shape
(num_local_physical_experts, hidden_size_i).
For example, a linear layer may have up and down projection.
expert_weights_buffer: Intermediate buffers (one per weight tensor).
ep_group: The device process group for expert parallelism.
is_profile (bool): If `True`, do not perform any actual weight copy.
This is used during profile run, where we only perform dummy
communications to reserve enough memory for the buffers.
cuda_stream: CUDA stream for async copies (can be None for sync mode).
rank_mapping: Optional rank mapping for elastic expert parallelism.
Returns:
is_unchanged (np.ndarray): (num_local_experts,), True where expert
is left unchanged.
is_received_locally (np.ndarray): (num_local_experts,), True where expert
can be received locally.
RecvMetadata: Metadata needed for completing remote weight transfers.
"""
ep_size = ep_group.size()
if rank_mapping is not None:
# Add a layer dimension for compatibility with mapping functions
old_layer_indices_2d = old_layer_indices.unsqueeze(0)
new_layer_indices_2d = new_layer_indices.unsqueeze(0)
if len(rank_mapping) == ep_group.size():
# scale down
new_global_expert_indices = _map_new_expert_indices_with_rank_mapping(
new_global_expert_indices,
new_layer_indices_2d = _map_new_expert_indices_with_rank_mapping(
new_layer_indices_2d,
rank_mapping,
)
else:
# scale up
old_global_expert_indices = _map_old_expert_indices_with_rank_mapping(
old_global_expert_indices,
old_layer_indices_2d = _map_old_expert_indices_with_rank_mapping(
old_layer_indices_2d,
rank_mapping,
ep_group.size(),
)
assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]
num_moe_layers, num_physical_experts = old_global_expert_indices.shape
assert len(expert_weights) == num_moe_layers
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
# Remove the layer dimension
old_layer_indices = old_layer_indices_2d.squeeze(0)
new_layer_indices = new_layer_indices_2d.squeeze(0)
assert old_layer_indices.shape == new_layer_indices.shape
num_physical_experts = old_layer_indices.shape[0]
num_local_physical_experts = next(iter(expert_weights)).shape[0]
assert num_physical_experts == ep_size * num_local_physical_experts
is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
old_layer_indices_np = old_layer_indices.cpu().numpy()
new_layer_indices_np = new_layer_indices.cpu().numpy()
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts,
old_indices=old_global_expert_indices[layer].tolist(),
new_indices=new_global_expert_indices[layer].tolist(),
expert_weights=expert_weights[layer],
expert_weights_buffer=expert_weights_buffer,
old_indices=old_layer_indices_np,
new_indices=new_layer_indices_np,
expert_weights=expert_weights,
expert_weights_buffers=expert_weights_buffer,
cuda_stream=cuda_stream,
ep_group=ep_group,
)
return is_unchanged, is_received_locally, experts_recv_loc
return is_unchanged, is_received_locally, recv_metadata
def rearrange_expert_weights_inplace(
@ -388,19 +567,17 @@ def rearrange_expert_weights_inplace(
ep_size = ep_group.size()
assert num_physical_experts == ep_size * num_local_physical_experts
# A buffer to hold the expert weights in one layer during the exchange.
first_layer_weights = list(expert_weights[0])
# Buffers to hold the expert weights during the exchange.
# NOTE: Currently we assume the same weights across different layers
# have the same shape.
expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]]
weights_buffer: list[torch.Tensor] = [
torch.empty_like(w) for w in first_layer_weights
]
if is_profile:
# Maximum send size is to send all local experts to all ranks,
# So we use a dummy `all_gather` to reserve enough communication buffer
for weight, buffer in zip(expert_weights[0], expert_weights_buffer):
# A `/dev/null`-like buffer to avoid real memory allocation
# Reserve communication buffers via a minimal dummy all_gather on first layer
for weight, buffer in zip(expert_weights[0], weights_buffer):
dummy_recv_buffer = [buffer for _ in range(ep_size)]
# NOTE(bowen): Needed this barrier to avoid OOM during actual
# execution. I'm not very sure why this is needed
torch.distributed.barrier()
all_gather(
dummy_recv_buffer,
@ -409,32 +586,32 @@ def rearrange_expert_weights_inplace(
)
return
old_global_expert_indices_cpu = old_global_expert_indices.cpu()
new_global_expert_indices_cpu = new_global_expert_indices.cpu()
# NOTE(bowen): We need this synchronize to run, but I don't know why.
# If you figure out the reason, please let me know -- thank you!
torch.cuda.synchronize()
for layer in range(num_moe_layers):
is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
for layer_idx in range(num_moe_layers):
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts,
old_indices=old_global_expert_indices_cpu[layer].tolist(),
new_indices=new_global_expert_indices_cpu[layer].tolist(),
expert_weights=expert_weights[layer],
expert_weights_buffer=expert_weights_buffer,
old_indices=old_global_expert_indices_cpu[layer_idx],
new_indices=new_global_expert_indices_cpu[layer_idx],
expert_weights=expert_weights[layer_idx],
expert_weights_buffers=weights_buffer,
cuda_stream=None,
ep_group=ep_group,
)
move_from_buffer(
expert_weights=expert_weights[layer],
expert_weights_buffer=expert_weights_buffer,
expert_weights=expert_weights[layer_idx],
expert_weights_buffers=weights_buffer,
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
experts_recv_loc=experts_recv_loc,
new_indices=new_global_expert_indices[layer].tolist(),
ep_group=ep_group,
recv_metadata=recv_metadata,
new_indices=new_global_expert_indices_cpu[layer_idx],
ep_rank=ep_group.rank(),
)
@ -526,4 +703,4 @@ def _map_new_expert_indices_with_rank_mapping(
return mapped_expert_indices
__all__ = ["transfer_layer", "move_from_buffer"]
__all__ = ["transfer_layer", "move_from_buffer", "RecvMetadata"]