Optimize weight rearrangement with NumPy

Signed-off-by: ilmarkov <markovilya197@gmail.com>
ilmarkov 2025-11-25 14:34:22 +00:00
parent f4df2af946
commit a46c72ac71
6 changed files with 394 additions and 226 deletions


@@ -15,7 +15,7 @@ from vllm.utils.system_utils import update_environment_variables
mp.set_start_method("spawn", force=True)
def distributed_run(fn, world_size, *args):
def distributed_run(fn, world_size, *args, max_grouped_layers=1):
number_of_processes = world_size
processes: list[mp.Process] = []
for i in range(number_of_processes):
@@ -26,6 +26,7 @@ def distributed_run(fn, world_size, *args):
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = "12345"
env["VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS"] = str(max_grouped_layers)
p = mp.Process(target=fn, args=(env, world_size, *args))
processes.append(p)
p.start()
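For context, a minimal sketch of driving this helper from a test; the worker name and the sizes below are illustrative, not part of this diff:

from vllm.utils.system_utils import update_environment_variables

def _example_worker(env, world_size):
    # Each spawned process applies the prepared environment first, so
    # VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS is visible before vLLM reads it.
    update_environment_variables(env)

distributed_run(_example_worker, world_size=2, max_grouped_layers=2)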


@@ -286,15 +286,17 @@ def _test_async_transfer_layer_without_mtp_worker(
device,
old_indices,
)
old_indices_cpu = old_indices.cpu()
new_indices_cpu = new_indices.cpu()
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
cuda_stream = torch.cuda.Stream(device=device)
for layer_idx in range(num_layers):
is_unchanged, is_received_locally, experts_recv_loc = asyncio.run(
is_unchanged, is_received_locally, recv_metadata = asyncio.run(
transfer_layer(
old_global_expert_indices=old_indices,
new_global_expert_indices=new_indices,
old_global_expert_indices=old_indices_cpu,
new_global_expert_indices=new_indices_cpu,
expert_weights=expert_weights,
expert_weights_buffer=expert_buffer,
ep_group=ep_group,
@@ -302,15 +304,14 @@ def _test_async_transfer_layer_without_mtp_worker(
cuda_stream=cuda_stream,
)
)
cuda_stream.synchronize()
move_from_buffer(
expert_weights=expert_weights[layer_idx],
expert_weights_buffer=expert_buffer,
weights_group=[expert_weights[layer_idx]],
buffers_group=[expert_buffer],
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
experts_recv_loc=experts_recv_loc,
new_indices=new_indices[layer_idx].tolist(),
recv_metadata=recv_metadata,
new_indices_group=new_indices_cpu[layer_idx : layer_idx + 1],
ep_group=ep_group,
)
@@ -426,8 +427,9 @@ def _test_rearrange_expert_weights_with_redundancy(
(4, 8, 8, 16),
],
)
@pytest.mark.parametrize("group_layers", [1, 2])
def test_rearrange_expert_weights_with_redundancy(
world_size, num_layers, num_local_experts, num_logical_experts
world_size, num_layers, num_local_experts, num_logical_experts, group_layers
):
"""Test the functionality of rearranging expert weights with redundancy."""
@@ -439,6 +441,7 @@ def test_rearrange_expert_weights_with_redundancy(
num_layers,
num_local_experts,
num_logical_experts,
max_grouped_layers=group_layers,
)


@@ -89,7 +89,7 @@ async def transfer_run_periodically(
(
model_state.is_unchanged,
model_state.is_received_locally,
model_state.experts_recv_loc,
model_state.recv_metadata,
) = await transfer_layer(
old_global_expert_indices=model_state.physical_to_logical_map,
new_global_expert_indices=model_state.new_physical_to_logical_map,


@@ -31,6 +31,7 @@ import time
from collections.abc import Sequence
from dataclasses import dataclass
import numpy as np
import torch
from torch.distributed import ProcessGroup, all_reduce
@@ -164,20 +165,24 @@ class EplbModelState:
"""
Whether the async EPLB needs to poll peers for buffer readiness.
"""
is_unchanged: list[bool]
is_unchanged: np.ndarray
"""
Intermediate result shared between `move_to_buffer` and `move_from_buffer`.
The size equals the number of physical experts in the current layer.
"""
is_received_locally: list[bool]
is_received_locally: np.ndarray
"""
Intermediate result shared between `move_to_buffer` and `move_from_buffer`.
The size equals the number of physical experts in the current layer.
"""
experts_recv_loc: dict[int, int]
recv_metadata: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
"""
Intermediate result shared between `move_to_buffer` and `move_from_buffer`.
The size equals the number of physical experts in the current layer.
The tuple contains:
- recv_primary_mask: np.ndarray, shape (group_size, num_local_experts)
- recv_counts: np.ndarray, shape (group_size,)
- recv_expert_ids: np.ndarray, shape (group_size, num_local_experts)
- recv_dst_rows: np.ndarray, shape (group_size, num_local_experts)
"""
is_async_enabled: bool
"""
@@ -498,9 +503,9 @@ class EplbState:
layer_to_transfer=0,
rebalanced=False,
pending_global_ready_check=False,
is_unchanged=[],
is_received_locally=[],
experts_recv_loc={},
is_unchanged=np.array([]),
is_received_locally=np.array([]),
recv_metadata=(np.array([]), np.array([]), np.array([]), np.array([])),
is_async_enabled=self.is_async,
cuda_device_index=self.cuda_device_index,
new_physical_to_logical_map=new_physical_to_logical_map,
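To make the `recv_metadata` layout documented above concrete, a small self-contained sketch of reading it; the values are made up, and only the first `recv_counts[layer]` columns of each row are meaningful (the rest are -1 padding):

import numpy as np

# One layer in the group, four local expert slots
recv_primary_mask = np.array([[True, False, False, True]])
recv_counts = np.array([2], dtype=np.int32)
recv_expert_ids = np.array([[5, 9, -1, -1]], dtype=np.int64)
recv_dst_rows = np.array([[0, 3, -1, -1]], dtype=np.int32)

layer = 0
n = int(recv_counts[layer])
for expert, dst in zip(recv_expert_ids[layer, :n], recv_dst_rows[layer, :n]):
    # expert 5 lands in local buffer row 0, expert 9 in row 3
    print(f"expert {expert} received into buffer row {dst}")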
@@ -847,8 +852,6 @@ class EplbState:
time_end - time_start,
)
else:
device = eplb_model_state.physical_to_logical_map.device
new_physical = new_physical_to_logical_map.to(device)
max_slots = eplb_model_state.logical_to_physical_map.shape[-1]
padded_logical = torch.nn.functional.pad(
new_logical_to_physical_map,
@@ -859,7 +862,10 @@ class EplbState:
eplb_model_state.logical_replica_count.device
)
eplb_model_state.new_physical_to_logical_map = new_physical
# Move the map to CPU in advance
eplb_model_state.new_physical_to_logical_map = (
new_physical_to_logical_map.cpu()
)
eplb_model_state.new_logical_to_physical_map = padded_logical
eplb_model_state.new_logical_replica_count = new_replica
@@ -958,17 +964,21 @@ class EplbState:
stream = torch.cuda.current_stream(device=device_index)
stream.wait_event(model_state.buffer_ready_event)
model_state.buffer_ready_event = None
weights_group = [
model_state.model.expert_weights[model_state.layer_to_transfer]
]
buffers_group = [model_state.expert_buffer]
move_from_buffer(
expert_weights=model_state.model.expert_weights[
model_state.layer_to_transfer
],
expert_weights_buffer=model_state.expert_buffer,
weights_group=weights_group,
buffers_group=buffers_group,
is_unchanged=model_state.is_unchanged,
is_received_locally=model_state.is_received_locally,
experts_recv_loc=model_state.experts_recv_loc,
new_indices=model_state.new_physical_to_logical_map[
model_state.layer_to_transfer
].tolist(),
recv_metadata=model_state.recv_metadata,
new_indices_group=model_state.new_physical_to_logical_map[
model_state.layer_to_transfer : model_state.layer_to_transfer + 1
]
.cpu()
.numpy(),
ep_group=ep_group,
)
transferred_layer = model_state.layer_to_transfer


@@ -6,9 +6,10 @@ The actual execution of the rearrangement.
This involves the exchange of expert weights between GPUs.
"""
from collections.abc import Iterable, MutableSequence, Sequence
from collections.abc import Iterable, Sequence
from functools import partial
import numpy as np
import torch
from torch.distributed import (
P2POp,
@@ -18,6 +19,11 @@ from torch.distributed import (
get_global_rank,
)
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
def idx_local_to_global(
local_idx: int,
@@ -54,9 +60,9 @@ def global_idx_to_rank(
def get_ep_ranks_with_expert(
idx: int,
num_local_experts: int,
old_indices: Sequence[int],
new_indices: Sequence[int],
) -> tuple[MutableSequence[int], MutableSequence[int]]:
old_indices: np.ndarray,
new_indices: np.ndarray,
) -> tuple[list[int], list[int]]:
"""
Get the ranks of the experts that need to be exchanged.
@@ -71,161 +77,227 @@ def get_ep_ranks_with_expert(
- The ranks of the experts that need to be sent.
- The ranks of the experts that need to be received.
"""
global2rank = partial(
global_idx_to_rank,
local_cnt=num_local_experts,
)
ranks_to_send: list[int] = []
ranks_to_recv: list[int] = []
for i, e in enumerate(old_indices):
if e == idx:
rank = global2rank(i)
if not ranks_to_send or ranks_to_send[-1] != rank:
ranks_to_send.append(rank)
for i, e in enumerate(new_indices):
if e == idx:
rank = global2rank(i)
if not ranks_to_recv or ranks_to_recv[-1] != rank:
ranks_to_recv.append(rank)
# Remove those ranks that can get this expert locally.
# Indices where expert idx appears
old_pos = np.nonzero(old_indices == idx)[0]
new_pos = np.nonzero(new_indices == idx)[0]
# Map positions to ranks
if old_pos.size > 0:
old_ranks = old_pos // num_local_experts
uniq_send, first_idx_send = np.unique(old_ranks, return_index=True)
order_send = np.argsort(first_idx_send)
ranks_to_send = uniq_send[order_send].astype(int).tolist()
else:
ranks_to_send = []
if new_pos.size > 0:
new_ranks = new_pos // num_local_experts
uniq_recv, first_idx_recv = np.unique(new_ranks, return_index=True)
order_recv = np.argsort(first_idx_recv)
ranks_to_recv = uniq_recv[order_recv].astype(int).tolist()
else:
ranks_to_recv = []
# Remove ranks that have local copies to avoid unnecessary recv
ranks_to_send_set = set(ranks_to_send)
ranks_to_recv_actual = [
rank for rank in ranks_to_recv if rank not in ranks_to_send_set
]
ranks_to_recv_actual = [r for r in ranks_to_recv if r not in ranks_to_send_set]
return ranks_to_send, ranks_to_recv_actual
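The `np.unique(..., return_index=True)` plus `argsort` combination above is needed because `np.unique` returns values sorted by value, while the replaced loop collected ranks in order of first appearance. A minimal demonstration of the pattern:

import numpy as np

ranks = np.array([3, 1, 3, 0, 1])  # rank of each position holding the expert
uniq, first_idx = np.unique(ranks, return_index=True)  # uniq is sorted by value
in_order = uniq[np.argsort(first_idx)].astype(int).tolist()
assert in_order == [3, 1, 0]  # first-appearance order restored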
def move_to_buffer(
num_local_experts: int,
old_indices: Sequence[int],
new_indices: Sequence[int],
expert_weights: Iterable[torch.Tensor],
expert_weights_buffer: Sequence[torch.Tensor],
old_indices_group: np.ndarray,
new_indices_group: np.ndarray,
expert_weights_group: Sequence[Iterable[torch.Tensor]],
buffers_group: Sequence[Sequence[torch.Tensor]],
cuda_stream: torch.cuda.Stream | None,
ep_group: ProcessGroup,
) -> tuple[list[bool], list[bool], dict[int, int]]:
) -> tuple[
np.ndarray, np.ndarray, tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
]:
"""
Perform expert weights rearrangement of one layer.
Perform expert weights rearrangement of a group of layers.
"""
assert len(old_indices_group) == len(new_indices_group) == len(expert_weights_group)
group_size = len(old_indices_group)
ep_rank = ep_group.rank()
local2global = partial(
idx_local_to_global,
local_cnt=num_local_experts,
ep_rank=ep_rank,
)
# 0. Do nothing for experts that did not change.
is_unchanged = [
old_indices[local2global(i)] == new_indices[local2global(i)]
for i in range(num_local_experts)
]
# Pre-allocate per-layer compact maps/masks (numpy)
is_unchanged = np.zeros((group_size, num_local_experts), dtype=np.bool_)
is_received_locally = np.zeros((group_size, num_local_experts), dtype=np.bool_)
recv_primary_mask = np.zeros((group_size, num_local_experts), dtype=np.bool_)
send_counts = np.zeros(group_size, dtype=np.int32)
send_expert_ids = np.full((group_size, num_local_experts), -1, dtype=np.int64)
send_src_rows = np.full((group_size, num_local_experts), -1, dtype=np.int32)
recv_counts = np.zeros(group_size, dtype=np.int32)
recv_expert_ids = np.full((group_size, num_local_experts), -1, dtype=np.int64)
recv_dst_rows = np.full((group_size, num_local_experts), -1, dtype=np.int32)
base = ep_rank * num_local_experts
local_rows = np.arange(num_local_experts, dtype=np.int32)
local_global = base + local_rows
# 1. Perform weight copy inside the local rank.
is_received_locally = is_unchanged[:]
for src in range(num_local_experts):
src_global = local2global(src)
# Build masks and expert maps per layer
for layer_idx in range(group_size):
old_indices = old_indices_group[layer_idx]
layer_new_indices = new_indices_group[layer_idx]
old_local_expert_ids = old_indices[local_global]
new_local_expert_ids = layer_new_indices[local_global]
# Unchanged per-dst mask
unchanged_mask = old_local_expert_ids == new_local_expert_ids
is_unchanged[layer_idx, :] = unchanged_mask
# Local receive eligibility
new_valid = new_local_expert_ids != -1
can_recv_local = np.isin(
new_local_expert_ids, old_local_expert_ids, assume_unique=False
)
is_local_recv = np.logical_or(
unchanged_mask, np.logical_and(new_valid, can_recv_local)
)
is_received_locally[layer_idx, :] = is_local_recv
# Send map: first src row per unique expert present locally in old mapping
valid_old = old_local_expert_ids != -1
if np.any(valid_old):
uniq_experts, first_idx = np.unique(
old_local_expert_ids[valid_old], return_index=True
)
filtered_rows = local_rows[valid_old]
src_rows = filtered_rows[first_idx]
layer_send_count = int(uniq_experts.shape[0])
send_counts[layer_idx] = layer_send_count
send_expert_ids[layer_idx, :layer_send_count] = uniq_experts
send_src_rows[layer_idx, :layer_send_count] = src_rows
else:
send_counts[layer_idx] = 0
# Recv map: primary dst per unique expert needed remotely
need_recv_mask = np.logical_and(~is_local_recv, new_valid)
if np.any(need_recv_mask):
desired_experts = new_local_expert_ids[need_recv_mask]
desired_dsts = local_rows[need_recv_mask]
uniq_recv_experts, uniq_indices = np.unique(
desired_experts, return_index=True
)
dst_rows = desired_dsts[uniq_indices]
layer_recv_count = int(uniq_recv_experts.shape[0])
recv_counts[layer_idx] = layer_recv_count
recv_expert_ids[layer_idx, :layer_recv_count] = uniq_recv_experts
recv_dst_rows[layer_idx, :layer_recv_count] = dst_rows
recv_primary_mask[layer_idx, dst_rows] = True
else:
recv_counts[layer_idx] = 0
# 1. Local moves into tmp buffers
for layer_idx in range(group_size):
layer_is_unchanged = is_unchanged[layer_idx, :]
layer_is_received_locally = is_received_locally[layer_idx, :]
layer_new_indices = new_indices_group[layer_idx]
layer_send_count = int(send_counts[layer_idx])
layer_send_experts = send_expert_ids[layer_idx, :layer_send_count]
layer_send_srcs = send_src_rows[layer_idx, :layer_send_count]
local2global = partial(
idx_local_to_global,
local_cnt=num_local_experts,
ep_rank=ep_rank,
)
layer_weights_list = list(expert_weights_group[layer_idx])
layer_buffers_list = list(buffers_group[layer_idx])
for dst in range(num_local_experts):
if layer_is_unchanged[dst] or not layer_is_received_locally[dst]:
continue
dst_global = local2global(dst)
if is_received_locally[dst]:
expert = layer_new_indices[dst_global]
if expert == -1:
continue
if old_indices[src_global] == -1 or new_indices[dst_global] == -1:
matches = np.nonzero(layer_send_experts == expert)[0]
if matches.size == 0:
continue
if old_indices[src_global] == new_indices[dst_global]:
is_received_locally[dst] = True
for weight, buffer in zip(expert_weights, expert_weights_buffer):
with torch.cuda.stream(cuda_stream):
buffer[dst].copy_(weight[src], non_blocking=True)
src_local = int(layer_send_srcs[matches[0]])
for w, b in zip(layer_weights_list, layer_buffers_list):
b[dst].copy_(w[src_local])
p2p_ops: list[P2POp] = []
# 2. Initiate sending of weights.
experts_send_loc: dict[int, int] = {}
for src in range(num_local_experts):
expert = old_indices[local2global(src)]
if expert == -1:
# 2. Post sends per layer
for layer_idx in range(group_size):
old_indices = old_indices_group[layer_idx]
layer_new_indices = new_indices_group[layer_idx]
layer_weights_list = list(expert_weights_group[layer_idx])
layer_send_count = int(send_counts[layer_idx])
if layer_send_count == 0:
continue
if expert in experts_send_loc:
experts = send_expert_ids[layer_idx, :layer_send_count]
srcs = send_src_rows[layer_idx, :layer_send_count]
order = np.argsort(experts, kind="stable")
experts = experts[order]
srcs = srcs[order]
for expert, src in zip(experts.tolist(), srcs.tolist()):
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
expert,
num_local_experts,
old_indices,
layer_new_indices,
)
if not ranks_to_send or not ranks_to_recv:
continue
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
sender_pos = ranks_to_send.index(ep_rank)
recv_begin = sender_pos * num_dst_per_sender
recv_end = recv_begin + num_dst_per_sender
recv_ranks = ranks_to_recv[recv_begin:recv_end]
remainder_start = len(ranks_to_send) * num_dst_per_sender
recver_pos = remainder_start + sender_pos
if recver_pos < len(ranks_to_recv):
recv_ranks.append(ranks_to_recv[recver_pos])
for dst in recv_ranks:
dst_global = get_global_rank(ep_group, dst)
p2p_ops += [
P2POp(
torch.distributed.isend,
w[src],
dst_global,
)
for w in layer_weights_list
]
# 3. Post recvs per layer
for layer_idx in range(group_size):
old_indices = old_indices_group[layer_idx]
layer_new_indices = new_indices_group[layer_idx]
layer_buffers_list = list(buffers_group[layer_idx])
layer_recv_count = int(recv_counts[layer_idx])
if layer_recv_count == 0:
continue
experts_send_loc[expert] = src
# We need to sort here to match send/recv
for expert, src in sorted(experts_send_loc.items()):
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
expert,
num_local_experts,
old_indices,
new_indices,
)
# Calculate the ranks to send by this rank
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
sender_pos = ranks_to_send.index(ep_rank)
recv_begin = sender_pos * num_dst_per_sender
recv_end = recv_begin + num_dst_per_sender
recv_ranks = ranks_to_recv[recv_begin:recv_end]
# Tackle remainders
remainder_start = len(ranks_to_send) * num_dst_per_sender
recver_pos = remainder_start + sender_pos
if recver_pos < len(ranks_to_recv):
recv_ranks.append(ranks_to_recv[recver_pos])
for dst in recv_ranks:
dst_global = get_global_rank(ep_group, dst)
experts = recv_expert_ids[layer_idx, :layer_recv_count]
dsts = recv_dst_rows[layer_idx, :layer_recv_count]
order = np.argsort(experts, kind="stable")
experts = experts[order]
dsts = dsts[order]
for expert, dst in zip(experts.tolist(), dsts.tolist()):
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
expert,
num_local_experts,
old_indices,
layer_new_indices,
)
if not ranks_to_send or not ranks_to_recv:
continue
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
recver_pos = ranks_to_recv.index(ep_rank)
remainder_start = len(ranks_to_send) * num_dst_per_sender
if recver_pos < remainder_start:
src = ranks_to_send[recver_pos // num_dst_per_sender]
else:
src = ranks_to_send[recver_pos - remainder_start]
src_global = get_global_rank(ep_group, src)
p2p_ops += [
P2POp(
torch.distributed.isend,
weight[src],
dst_global,
torch.distributed.irecv,
b[dst],
src_global,
)
for weight in expert_weights
for b in layer_buffers_list
]
# 3. Initiate receiving of weights.
experts_recv_loc: dict[int, int] = {}
for dst in range(num_local_experts):
if is_received_locally[dst]:
continue
expert = new_indices[local2global(dst)]
if expert == -1:
continue
if expert in experts_recv_loc:
continue
experts_recv_loc[expert] = dst
# We need to sort here to match send/recv
for expert, dst in sorted(experts_recv_loc.items()):
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
expert,
num_local_experts,
old_indices,
new_indices,
)
# Calculate the rank to recv by this rank
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
recver_pos = ranks_to_recv.index(ep_rank)
remainder_start = len(ranks_to_send) * num_dst_per_sender
if recver_pos < remainder_start:
src = ranks_to_send[recver_pos // num_dst_per_sender]
else:
src = ranks_to_send[recver_pos - remainder_start]
src_global = get_global_rank(ep_group, src)
p2p_ops += [
P2POp(
torch.distributed.irecv,
weight[dst],
src_global,
)
for weight in expert_weights_buffer
]
# 4. Execute the P2P operations. The real communication happens here.
if p2p_ops and cuda_stream is not None:
with torch.cuda.stream(cuda_stream):
@@ -237,38 +309,98 @@ def move_to_buffer(
for req in reqs:
req.wait()
# wait for the communication to finish
return is_unchanged, is_received_locally, experts_recv_loc
return (
is_unchanged,
is_received_locally,
(recv_primary_mask, recv_counts, recv_expert_ids, recv_dst_rows),
)
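Steps 2 and 3 recompute the same deterministic sender-to-receiver assignment on both ends of every transfer. A self-contained sketch of that pairing and its consistency; the helper names are hypothetical, the arithmetic mirrors the code above:

def sender_targets(ranks_to_send, ranks_to_recv, sender):
    # Send side: each sender serves a contiguous block of receivers,
    # plus at most one remainder receiver.
    n = len(ranks_to_recv) // len(ranks_to_send)
    pos = ranks_to_send.index(sender)
    targets = ranks_to_recv[pos * n : pos * n + n]
    rem = len(ranks_to_send) * n + pos
    if rem < len(ranks_to_recv):
        targets.append(ranks_to_recv[rem])
    return targets

def receiver_source(ranks_to_send, ranks_to_recv, recver):
    # Recv side: the inverse computation.
    n = len(ranks_to_recv) // len(ranks_to_send)
    pos = ranks_to_recv.index(recver)
    rem_start = len(ranks_to_send) * n
    return ranks_to_send[pos // n] if pos < rem_start else ranks_to_send[pos - rem_start]

send, recv = [0, 1], [2, 3, 4]
assert sender_targets(send, recv, 0) == [2, 4]
assert sender_targets(send, recv, 1) == [3]
for r in recv:  # every receiver's chosen source indeed sends to it
    assert r in sender_targets(send, recv, receiver_source(send, recv, r))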
def move_from_buffer(
expert_weights: Iterable[torch.Tensor],
expert_weights_buffer: list[torch.Tensor],
is_unchanged: list[bool],
is_received_locally: list[bool],
experts_recv_loc: dict[int, int],
new_indices: Sequence[int],
weights_group: Sequence[Iterable[torch.Tensor]],
buffers_group: Sequence[Sequence[torch.Tensor]],
is_unchanged: np.ndarray,
is_received_locally: np.ndarray,
recv_metadata: tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray],
new_indices_group: np.ndarray,
ep_group: ProcessGroup,
) -> None:
assert (
len(weights_group)
== len(buffers_group)
== len(is_unchanged)
== len(is_received_locally)
== len(recv_metadata[0])
== len(new_indices_group)
), "Unmatching layer group size"
ep_rank = ep_group.rank()
num_local_experts = len(is_unchanged)
local2global = partial(
idx_local_to_global, local_cnt=num_local_experts, ep_rank=ep_rank
)
for dst in range(num_local_experts):
if is_unchanged[dst]:
group_size = len(is_unchanged)
recv_primary_mask, recv_counts, recv_expert_ids, recv_dst_rows = recv_metadata
num_local_experts = is_unchanged.shape[1]
# Mask for rows to copy back from buffers:
# copy if locally received OR remote primary recv
copy_mask = np.logical_or(is_received_locally, recv_primary_mask)
# Copy back local buffered rows into destination weights
for layer_idx in range(group_size):
layer_is_unchanged = is_unchanged[layer_idx, :]
layer_copy_mask = copy_mask[layer_idx, :]
weights_list = list(weights_group[layer_idx])
buffers_list = list(buffers_group[layer_idx])
# rows to copy = (~unchanged) & copy_mask
dest_mask_np = np.logical_and(~layer_is_unchanged, layer_copy_mask)
if not bool(dest_mask_np.any()):
continue
if is_received_locally[dst]:
for weight, buffer in zip(expert_weights, expert_weights_buffer):
weight[dst].copy_(buffer[dst], non_blocking=True)
else:
expert = new_indices[local2global(dst)]
if expert == -1:
continue
src = experts_recv_loc[expert]
for weight, buffer in zip(expert_weights, expert_weights_buffer):
weight[dst].copy_(buffer[src], non_blocking=True)
dest_indices = np.nonzero(dest_mask_np)[0].tolist()
for dst in dest_indices:
for w, b in zip(weights_list, buffers_list):
w[dst].copy_(b[dst])
# Duplicate remote received rows to non-primary duplicate dsts
for layer_idx in range(group_size):
layer_is_unchanged = is_unchanged[layer_idx, :]
layer_is_received_locally = is_received_locally[layer_idx, :]
new_indices = new_indices_group[layer_idx]
weights_list = list(weights_group[layer_idx])
count_recv = int(recv_counts[layer_idx])
if count_recv == 0:
# No remote primaries on this layer → no remote duplicates to materialize
continue
# Local view of desired expert ids per local row
base = ep_rank * num_local_experts
local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)]
# Duplicate rows mask: need remote, not primary, and valid expert id
duplicate_mask = np.logical_and(
np.logical_and(~layer_is_unchanged, ~layer_is_received_locally),
np.logical_and(~recv_primary_mask[layer_idx, :], local_experts != -1),
)
if not bool(duplicate_mask.any()):
continue
dup_dst_rows = np.nonzero(duplicate_mask)[0]
dup_experts = local_experts[dup_dst_rows]
# Build primary mapping arrays (expert -> primary dst) and vector-match
prim_experts = recv_expert_ids[layer_idx, :count_recv]
prim_dsts = recv_dst_rows[layer_idx, :count_recv]
order = np.argsort(prim_experts, kind="stable")
prim_experts_sorted = prim_experts[order]
prim_dsts_sorted = prim_dsts[order]
pos = np.searchsorted(prim_experts_sorted, dup_experts)
# Filter to experts that have a matching primary entry
valid = np.logical_and(
pos < prim_experts_sorted.shape[0],
prim_experts_sorted[np.minimum(pos, prim_experts_sorted.shape[0] - 1)]
== dup_experts,
)
if not bool(valid.any()):
continue
matched_dst_rows = dup_dst_rows[valid]
matched_src_rows = prim_dsts_sorted[pos[valid]]
# Perform row copies per (dst, src) pair without tensor indexing
for dst, src in zip(matched_dst_rows.tolist(), matched_src_rows.tolist()):
for w in weights_list:
w[dst].copy_(w[src])
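The duplicate pass above matches each duplicate row's expert id against the sorted primary table via `np.searchsorted`; experts without a primary entry drop out through the `valid` mask. A small standalone check of that matching:

import numpy as np

prim_experts = np.array([9, 2, 5])  # experts that have a primary recv row
prim_dsts = np.array([1, 0, 3])     # their primary destination rows
order = np.argsort(prim_experts, kind="stable")
experts_sorted = prim_experts[order]  # [2, 5, 9]
dsts_sorted = prim_dsts[order]        # [0, 3, 1]

dup_experts = np.array([5, 7, 9])  # experts wanted by duplicate rows
pos = np.searchsorted(experts_sorted, dup_experts)
valid = (pos < experts_sorted.shape[0]) & (
    experts_sorted[np.minimum(pos, experts_sorted.shape[0] - 1)] == dup_experts
)
assert dup_experts[valid].tolist() == [5, 9]       # expert 7 is filtered out
assert dsts_sorted[pos[valid]].tolist() == [3, 1]  # copy-source rows for the dups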
async def transfer_layer(
@@ -281,7 +413,9 @@ async def transfer_layer(
layer: int = 0,
cuda_stream: torch.cuda.Stream | None = None,
rank_mapping: dict[int, int] | None = None,
) -> tuple[list[bool], list[bool], dict[int, int]]:
) -> tuple[
np.ndarray, np.ndarray, tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]
]:
"""
Rearranges the expert weights in place according to the new expert indices.
@@ -322,20 +456,20 @@ async def transfer_layer(
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
assert num_physical_experts == ep_size * num_local_physical_experts
# A buffer to hold the expert weights in one layer during the exchange.
# NOTE: Currently we assume the same weights across different layers
# have the same shape.
is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
old_global_expert_indices_np = old_global_expert_indices.cpu().numpy()
new_global_expert_indices_np = new_global_expert_indices.cpu().numpy()
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts,
old_indices=old_global_expert_indices[layer].tolist(),
new_indices=new_global_expert_indices[layer].tolist(),
expert_weights=expert_weights[layer],
expert_weights_buffer=expert_weights_buffer,
old_indices_group=old_global_expert_indices_np[layer : layer + 1],
new_indices_group=new_global_expert_indices_np[layer : layer + 1],
expert_weights_group=[expert_weights[layer]],
buffers_group=[expert_weights_buffer],
cuda_stream=cuda_stream,
ep_group=ep_group,
)
return is_unchanged, is_received_locally, experts_recv_loc
return is_unchanged, is_received_locally, recv_metadata
def rearrange_expert_weights_inplace(
@@ -391,54 +525,69 @@ def rearrange_expert_weights_inplace(
ep_size = ep_group.size()
assert num_physical_experts == ep_size * num_local_physical_experts
# A buffer to hold the expert weights in one layer during the exchange.
# Max number of layers to group for communication
max_group_layers = envs.VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS
max_group_layers = max(min(max_group_layers, num_moe_layers), 1)
logger.info_once(
f"EPLB Sync: rearrange max_group_layers: {max_group_layers}", scope="global"
)
first_layer_weights = list(expert_weights[0])
# Buffers to hold the expert weights during the exchange.
# NOTE: Currently we assume the same weights across different layers
# have the same shape.
expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]]
weights_buffers: list[list[torch.Tensor]] = [
[torch.empty_like(w) for w in first_layer_weights]
for _ in range(max_group_layers)
]
if is_profile:
# Maximum send size is to send all local experts to all ranks,
# So we use a dummy `all_gather` to reserve enough communication buffer
for weight, buffer in zip(expert_weights[0], expert_weights_buffer):
# A `/dev/null`-like buffer to avoid real memory allocation
dummy_recv_buffer = [buffer for _ in range(ep_size)]
# NOTE(bowen): Needed this barrier to avoid OOM during actual
# execution. I'm not very sure why this is needed
torch.distributed.barrier()
all_gather(
dummy_recv_buffer,
weight,
group=ep_group,
)
# Reserve communication buffers via a minimal dummy all_gather on first layer
for layer_idx in range(max_group_layers):
for weight, buffer in zip(expert_weights[0], weights_buffers[layer_idx]):
dummy_recv_buffer = [buffer for _ in range(ep_size)]
torch.distributed.barrier()
all_gather(
dummy_recv_buffer,
weight,
group=ep_group,
)
return
old_global_expert_indices_cpu = old_global_expert_indices.cpu()
new_global_expert_indices_cpu = new_global_expert_indices.cpu()
# NOTE(bowen): We need this synchronize to run, but I don't know why.
# If you figure out the reason, please let me know -- thank you!
torch.cuda.synchronize()
for layer in range(num_moe_layers):
is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
start = 0
while start < num_moe_layers:
end = min(start + max_group_layers, num_moe_layers)
old_group = old_global_expert_indices_cpu[start:end]
new_group = new_global_expert_indices_cpu[start:end]
weights_group = [expert_weights[i] for i in range(start, end)]
buffers_group = weights_buffers[: (end - start)]
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts,
old_indices=old_global_expert_indices_cpu[layer].tolist(),
new_indices=new_global_expert_indices_cpu[layer].tolist(),
expert_weights=expert_weights[layer],
expert_weights_buffer=expert_weights_buffer,
old_indices_group=old_group,
new_indices_group=new_group,
expert_weights_group=weights_group,
buffers_group=buffers_group,
cuda_stream=None,
ep_group=ep_group,
)
move_from_buffer(
expert_weights=expert_weights[layer],
expert_weights_buffer=expert_weights_buffer,
weights_group=weights_group,
buffers_group=buffers_group,
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
experts_recv_loc=experts_recv_loc,
new_indices=new_global_expert_indices_cpu[layer].tolist(),
recv_metadata=recv_metadata,
new_indices_group=new_group,
ep_group=ep_group,
)
start = end
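The grouped loop walks the layers in chunks of at most `max_group_layers`; a tiny sketch of the chunk boundaries it produces:

num_moe_layers, max_group_layers = 7, 3
chunks, start = [], 0
while start < num_moe_layers:
    end = min(start + max_group_layers, num_moe_layers)
    chunks.append((start, end))
    start = end
assert chunks == [(0, 3), (3, 6), (6, 7)]  # the final chunk may be shorter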
def _map_old_expert_indices_with_rank_mapping(


@@ -232,6 +232,7 @@ if TYPE_CHECKING:
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
VLLM_USE_V2_MODEL_RUNNER: bool = False
VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS: int = 1
def get_default_cache_root():
@@ -1526,6 +1527,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_V2_MODEL_RUNNER": lambda: bool(
int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0"))
),
# Max number of layers to group in synchronous EPLB weight communication.
"VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS": lambda: int(
os.getenv("VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS", "1")
),
}
# --8<-- [end:env-vars-definition]
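A hedged usage sketch: since the accessor above reads `os.getenv` lazily, exporting the variable before the EPLB rearrangement runs is sufficient:

import os

# Group up to 4 MoE layers per synchronous EPLB transfer (default: 1)
os.environ["VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS"] = "4"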