Remove layer grouping

Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
ilmarkov 2025-12-11 10:38:32 +00:00
parent fc54d760a6
commit f28720db88
5 changed files with 168 additions and 281 deletions

View File

@ -15,7 +15,7 @@ from vllm.utils.system_utils import update_environment_variables
mp.set_start_method("spawn", force=True) mp.set_start_method("spawn", force=True)
def distributed_run(fn, world_size, *args, max_grouped_layers=1): def distributed_run(fn, world_size, *args):
number_of_processes = world_size number_of_processes = world_size
processes: list[mp.Process] = [] processes: list[mp.Process] = []
for i in range(number_of_processes): for i in range(number_of_processes):
@ -26,7 +26,6 @@ def distributed_run(fn, world_size, *args, max_grouped_layers=1):
env["LOCAL_WORLD_SIZE"] = str(number_of_processes) env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env["MASTER_ADDR"] = "localhost" env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = "12345" env["MASTER_PORT"] = "12345"
env["VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS"] = str(max_grouped_layers)
p = mp.Process(target=fn, args=(env, world_size, *args)) p = mp.Process(target=fn, args=(env, world_size, *args))
processes.append(p) processes.append(p)
p.start() p.start()

View File

@ -306,12 +306,12 @@ def _test_async_transfer_layer_without_mtp_worker(
) )
cuda_stream.synchronize() cuda_stream.synchronize()
move_from_buffer( move_from_buffer(
weights_group=[expert_weights[layer_idx]], expert_weights=expert_weights[layer_idx],
buffers_group=[expert_buffer], expert_weights_buffers=expert_buffer,
is_unchanged=is_unchanged, is_unchanged=is_unchanged,
is_received_locally=is_received_locally, is_received_locally=is_received_locally,
recv_metadata=recv_metadata, recv_metadata=recv_metadata,
new_indices_group=new_indices_cpu[layer_idx : layer_idx + 1], new_indices=new_indices_cpu[layer_idx],
ep_group=ep_group, ep_group=ep_group,
) )
@ -427,9 +427,8 @@ def _test_rearrange_expert_weights_with_redundancy(
(4, 8, 8, 16), (4, 8, 8, 16),
], ],
) )
@pytest.mark.parametrize("group_layers", [1, 2])
def test_rearrange_expert_weights_with_redundancy( def test_rearrange_expert_weights_with_redundancy(
world_size, num_layers, num_local_experts, num_logical_experts, group_layers world_size, num_layers, num_local_experts, num_logical_experts
): ):
"""Test the functionality of rearranging expert weights with redundancy.""" """Test the functionality of rearranging expert weights with redundancy."""
@ -441,7 +440,6 @@ def test_rearrange_expert_weights_with_redundancy(
num_layers, num_layers,
num_local_experts, num_local_experts,
num_logical_experts, num_logical_experts,
max_grouped_layers=group_layers,
) )

View File

@ -514,7 +514,7 @@ class EplbState:
is_received_locally=np.array([]), is_received_locally=np.array([]),
recv_metadata=RecvMetadata( recv_metadata=RecvMetadata(
recv_primary_mask=np.array([]), recv_primary_mask=np.array([]),
recv_counts=np.array([]), recv_count=0,
recv_expert_ids=np.array([]), recv_expert_ids=np.array([]),
recv_dst_rows=np.array([]), recv_dst_rows=np.array([]),
), ),
@ -989,24 +989,22 @@ class EplbState:
stream = torch.cuda.current_stream(device=device_index) stream = torch.cuda.current_stream(device=device_index)
stream.wait_event(model_state.buffer_ready_event) stream.wait_event(model_state.buffer_ready_event)
model_state.buffer_ready_event = None model_state.buffer_ready_event = None
weights_group = [ expert_weights = model_state.model.expert_weights[
model_state.model.expert_weights[model_state.layer_to_transfer] model_state.layer_to_transfer
] ]
buffers_group = [model_state.expert_buffer] expert_weights_buffer = model_state.expert_buffer
new_indices_group = ( new_indices = (
model_state.new_physical_to_logical_map[ model_state.new_physical_to_logical_map[model_state.layer_to_transfer]
model_state.layer_to_transfer : model_state.layer_to_transfer + 1
]
.cpu() .cpu()
.numpy() .numpy()
) )
move_from_buffer( move_from_buffer(
weights_group=weights_group, expert_weights=expert_weights,
buffers_group=buffers_group, expert_weights_buffers=expert_weights_buffer,
is_unchanged=model_state.is_unchanged, is_unchanged=model_state.is_unchanged,
is_received_locally=model_state.is_received_locally, is_received_locally=model_state.is_received_locally,
recv_metadata=model_state.recv_metadata, recv_metadata=model_state.recv_metadata,
new_indices_group=new_indices_group, new_indices=new_indices,
ep_group=ep_group, ep_group=ep_group,
) )
transferred_layer = model_state.layer_to_transfer transferred_layer = model_state.layer_to_transfer

View File

@ -19,7 +19,6 @@ from torch.distributed import (
get_global_rank, get_global_rank,
) )
import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
logger = init_logger(__name__) logger = init_logger(__name__)
@ -30,15 +29,13 @@ class RecvMetadata:
"""Metadata describing remote receives during EPLB rebalancing.""" """Metadata describing remote receives during EPLB rebalancing."""
recv_primary_mask: np.ndarray recv_primary_mask: np.ndarray
"""Mask of (layer_group_size, num_local_experts) """Mask of (num_local_experts,) indicating primary experts received."""
indicating primary experts received.""" recv_count: int
recv_counts: np.ndarray """Number of received experts for the layer."""
"""Number of received experts for each layer."""
recv_expert_ids: np.ndarray recv_expert_ids: np.ndarray
"""Expert ids (layer_group_size, num_local_experts) of remote primary experts.""" """Expert ids (num_local_experts,) of remote primary experts."""
recv_dst_rows: np.ndarray recv_dst_rows: np.ndarray
"""Target expert indices (layer_group_size, num_local_experts) """Target expert indices (num_local_experts,) in local tensors to send."""
in local tensors to send."""
# Type alias for the result of move_to_buffer or transfer_layer # Type alias for the result of move_to_buffer or transfer_layer
@ -154,146 +151,105 @@ def get_ep_ranks_with_experts_batch(
def move_to_buffer( def move_to_buffer(
num_local_experts: int, num_local_experts: int,
old_indices_group: np.ndarray, old_indices: np.ndarray,
new_indices_group: np.ndarray, new_indices: np.ndarray,
expert_weights_group: Sequence[Iterable[torch.Tensor]], expert_weights: Iterable[torch.Tensor],
buffers_group: Sequence[Sequence[torch.Tensor]], expert_weights_buffers: Sequence[torch.Tensor],
cuda_stream: torch.cuda.Stream | None, cuda_stream: torch.cuda.Stream | None,
ep_group: ProcessGroup, ep_group: ProcessGroup,
) -> MoveToBufferResult: ) -> MoveToBufferResult:
""" """
Rearranges expert weights across a group of layers Rearranges expert weights during EPLB rebalancing.
during mixture-of-experts (MoE) expert parallel rebalancing.
Args: Args:
num_local_experts: Number of local experts. num_local_experts: Number of local experts.
old_indices_group: (num_layers, num_experts_total) ndarray of current old_indices: (num_experts_total,) ndarray of current (old)
(old) global-to-local expert assignments. global-to-local expert assignments.
new_indices_group: (num_layers, num_experts_total) ndarray of desired new_indices: (num_experts_total,) ndarray of desired (new)
(new) global-to-local assignments after rebalance. global-to-local assignments after rebalance.
expert_weights_group: Original expert weights for each layer. expert_weights: Original expert weights for the layer.
buffers_group: List of per-layer intermediate buffers (one per tensor). expert_weights_buffers: Intermediate buffers (one per tensor).
cuda_stream: CUDA stream for async copies (can be None for sync mode). cuda_stream: CUDA stream for async copies (can be None for sync mode).
ep_group: Distributed process group for expert parallel comms. ep_group: Distributed process group for expert parallel comms.
Returns: Returns:
is_unchanged (np.ndarray): (num_layers, num_local_experts), True where an is_unchanged (np.ndarray): (num_local_experts,), True where an expert row
expert row is unchanged after rebalance. is unchanged after rebalance.
is_received_locally (np.ndarray): (num_layers, num_local_experts), True is_received_locally (np.ndarray): (num_local_experts,), True where a row
where a row can be updated from local data. can be updated from local data.
RecvMetadata: Metadata needed for completing remote weight transfers. RecvMetadata: Metadata needed for completing remote weight transfers.
""" """
assert len(old_indices_group) == len(new_indices_group) == len(expert_weights_group) assert old_indices.shape == new_indices.shape
group_size = len(old_indices_group)
ep_rank = ep_group.rank() ep_rank = ep_group.rank()
# Pre-allocate per-layer compact maps/masks (numpy) recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_)
is_unchanged = np.zeros((group_size, num_local_experts), dtype=np.bool_) send_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
is_received_locally = np.zeros((group_size, num_local_experts), dtype=np.bool_) send_src_rows = np.full((num_local_experts,), -1, dtype=np.int32)
recv_primary_mask = np.zeros((group_size, num_local_experts), dtype=np.bool_) recv_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
# Cache desired new expert ids per local row, for all layers recv_dst_rows = np.full((num_local_experts,), -1, dtype=np.int32)
new_local_expert_ids_mat = np.full(
(group_size, num_local_experts), -1, dtype=np.int64
)
send_counts = np.zeros(group_size, dtype=np.int32)
send_expert_ids = np.full((group_size, num_local_experts), -1, dtype=np.int64)
send_src_rows = np.full((group_size, num_local_experts), -1, dtype=np.int32)
recv_counts = np.zeros(group_size, dtype=np.int32)
recv_expert_ids = np.full((group_size, num_local_experts), -1, dtype=np.int64)
recv_dst_rows = np.full((group_size, num_local_experts), -1, dtype=np.int32)
base = ep_rank * num_local_experts base = ep_rank * num_local_experts
local_rows = np.arange(num_local_experts, dtype=np.int32) local_rows = np.arange(num_local_experts, dtype=np.int32)
local_global = base + local_rows local_global = base + local_rows
# Build masks and expert maps per layer old_local_expert_ids = old_indices[local_global]
for layer_idx in range(group_size): new_local_expert_ids = new_indices[local_global]
old_indices = old_indices_group[layer_idx]
layer_new_indices = new_indices_group[layer_idx]
old_local_expert_ids = old_indices[local_global] # Unchanged mask
new_local_expert_ids = layer_new_indices[local_global] is_unchanged = old_local_expert_ids == new_local_expert_ids
new_local_expert_ids_mat[layer_idx, :] = new_local_expert_ids
# Unchanged per-dst mask # Local receive eligibility
unchanged_mask = old_local_expert_ids == new_local_expert_ids new_valid = new_local_expert_ids != -1
is_unchanged[layer_idx, :] = unchanged_mask can_recv_local = np.isin(
new_local_expert_ids, old_local_expert_ids, assume_unique=False
)
is_received_locally = np.logical_or(
is_unchanged, np.logical_and(new_valid, can_recv_local)
)
# Local receive eligibility # Send map: first src row per unique expert present locally in old mapping
new_valid = new_local_expert_ids != -1 send_count = 0
can_recv_local = np.isin( valid_old = old_local_expert_ids != -1
new_local_expert_ids, old_local_expert_ids, assume_unique=False if np.any(valid_old):
uniq_experts, first_idx = np.unique(
old_local_expert_ids[valid_old], return_index=True
) )
is_local_recv = np.logical_or( filtered_rows = local_rows[valid_old]
unchanged_mask, np.logical_and(new_valid, can_recv_local) src_rows = filtered_rows[first_idx]
) send_count = int(uniq_experts.shape[0])
is_received_locally[layer_idx, :] = is_local_recv send_expert_ids[:send_count] = uniq_experts
send_src_rows[:send_count] = src_rows
# Send map: first src row per unique expert present locally in old mapping # Recv map: primary dst per unique expert needed remotely
valid_old = old_local_expert_ids != -1 recv_count = 0
if np.any(valid_old): need_recv_mask = np.logical_and(~is_received_locally, new_valid)
uniq_experts, first_idx = np.unique( if np.any(need_recv_mask):
old_local_expert_ids[valid_old], return_index=True desired_experts = new_local_expert_ids[need_recv_mask]
) desired_dsts = local_rows[need_recv_mask]
filtered_rows = local_rows[valid_old] uniq_recv_experts, uniq_indices = np.unique(desired_experts, return_index=True)
src_rows = filtered_rows[first_idx] dst_rows = desired_dsts[uniq_indices]
layer_send_count = int(uniq_experts.shape[0]) recv_count = int(uniq_recv_experts.shape[0])
send_counts[layer_idx] = layer_send_count recv_expert_ids[:recv_count] = uniq_recv_experts
send_expert_ids[layer_idx, :layer_send_count] = uniq_experts recv_dst_rows[:recv_count] = dst_rows
send_src_rows[layer_idx, :layer_send_count] = src_rows recv_primary_mask[dst_rows] = True
else:
send_counts[layer_idx] = 0
# Recv map: primary dst per unique expert needed remotely
need_recv_mask = np.logical_and(~is_local_recv, new_valid)
if np.any(need_recv_mask):
desired_experts = new_local_expert_ids[need_recv_mask]
desired_dsts = local_rows[need_recv_mask]
uniq_recv_experts, uniq_indices = np.unique(
desired_experts, return_index=True
)
dst_rows = desired_dsts[uniq_indices]
layer_recv_count = int(uniq_recv_experts.shape[0])
recv_counts[layer_idx] = layer_recv_count
recv_expert_ids[layer_idx, :layer_recv_count] = uniq_recv_experts
recv_dst_rows[layer_idx, :layer_recv_count] = dst_rows
recv_primary_mask[layer_idx, dst_rows] = True
else:
recv_counts[layer_idx] = 0
# Precompute per-layer destination mask that actually needs local buffering:
# need change, received locally, and valid target expert id
eligible_local_buffer_mask = np.logical_and( eligible_local_buffer_mask = np.logical_and(
np.logical_and(~is_unchanged, is_received_locally), np.logical_and(~is_unchanged, is_received_locally),
new_local_expert_ids_mat != -1, new_local_expert_ids != -1,
) )
# 1. Local moves into tmp buffers # 1. Local moves into tmp buffers
for layer_idx in range(group_size): if bool(eligible_local_buffer_mask.any()) and send_count > 0:
layer_send_count = int(send_counts[layer_idx]) dest_indices = np.nonzero(eligible_local_buffer_mask)[0].tolist()
if layer_send_count <= 0: expert_to_src_map = dict(
continue zip(send_expert_ids[:send_count], send_src_rows[:send_count])
)
layer_send_experts = send_expert_ids[layer_idx, :layer_send_count]
layer_send_srcs = send_src_rows[layer_idx, :layer_send_count]
layer_weights_list = list(expert_weights_group[layer_idx])
layer_buffers_list = list(buffers_group[layer_idx])
new_local_expert_ids = new_local_expert_ids_mat[layer_idx, :]
# Only consider destination rows that are eligible for local buffering
eligible_mask = eligible_local_buffer_mask[layer_idx, :]
if not bool(eligible_mask.any()):
continue
dest_indices = np.nonzero(eligible_mask)[0].tolist()
# Build a map from expert_id to its source row.
expert_to_src_map = dict(zip(layer_send_experts, layer_send_srcs))
for dst in dest_indices: for dst in dest_indices:
expert = new_local_expert_ids[dst] expert = new_local_expert_ids[dst]
src_local = expert_to_src_map.get(expert, -1) src_local = expert_to_src_map.get(expert, -1)
if src_local != -1: if src_local != -1:
for w, b in zip(layer_weights_list, layer_buffers_list): for w, b in zip(expert_weights, expert_weights_buffers):
b[dst].copy_(w[src_local]) b[dst].copy_(w[src_local], non_blocking=True)
p2p_ops: list[P2POp] = [] p2p_ops: list[P2POp] = []
@ -301,16 +257,10 @@ def move_to_buffer(
ep_size = ep_group.size() ep_size = ep_group.size()
rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)} rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
# 2. Post sends per layer # 2. Post sends
for layer_idx in range(group_size): if send_count > 0:
old_indices = old_indices_group[layer_idx] experts = send_expert_ids[:send_count]
layer_new_indices = new_indices_group[layer_idx] srcs = send_src_rows[:send_count]
layer_weights_list = list(expert_weights_group[layer_idx])
layer_send_count = int(send_counts[layer_idx])
if layer_send_count == 0:
continue
experts = send_expert_ids[layer_idx, :layer_send_count]
srcs = send_src_rows[layer_idx, :layer_send_count]
order = np.argsort(experts, kind="stable") order = np.argsort(experts, kind="stable")
experts = experts[order] experts = experts[order]
srcs = srcs[order] srcs = srcs[order]
@ -319,7 +269,7 @@ def move_to_buffer(
experts, experts,
num_local_experts, num_local_experts,
old_indices, old_indices,
layer_new_indices, new_indices,
) )
for expert, src in zip(experts.tolist(), srcs.tolist()): for expert, src in zip(experts.tolist(), srcs.tolist()):
@ -344,29 +294,22 @@ def move_to_buffer(
w[src], w[src],
dst_global, dst_global,
) )
for w in layer_weights_list for w in expert_weights
] ]
# 3. Post recvs per layer # 3. Post recvs
for layer_idx in range(group_size): if recv_count > 0:
old_indices = old_indices_group[layer_idx] experts = recv_expert_ids[:recv_count]
layer_new_indices = new_indices_group[layer_idx] dsts = recv_dst_rows[:recv_count]
layer_buffers_list = list(buffers_group[layer_idx])
layer_recv_count = int(recv_counts[layer_idx])
if layer_recv_count == 0:
continue
experts = recv_expert_ids[layer_idx, :layer_recv_count]
dsts = recv_dst_rows[layer_idx, :layer_recv_count]
order = np.argsort(experts, kind="stable") order = np.argsort(experts, kind="stable")
experts = experts[order] experts = experts[order]
dsts = dsts[order] dsts = dsts[order]
# Batch query all experts for this layer
send_map, recv_map = get_ep_ranks_with_experts_batch( send_map, recv_map = get_ep_ranks_with_experts_batch(
experts, experts,
num_local_experts, num_local_experts,
old_indices, old_indices,
layer_new_indices, new_indices,
) )
for expert, dst in zip(experts.tolist(), dsts.tolist()): for expert, dst in zip(experts.tolist(), dsts.tolist()):
@ -388,7 +331,7 @@ def move_to_buffer(
b[dst], b[dst],
src_global, src_global,
) )
for b in layer_buffers_list for b in expert_weights_buffers
] ]
# 4. Execute the P2P operations. The real communication happens here. # 4. Execute the P2P operations. The real communication happens here.
@ -407,7 +350,7 @@ def move_to_buffer(
is_received_locally, is_received_locally,
RecvMetadata( RecvMetadata(
recv_primary_mask=recv_primary_mask, recv_primary_mask=recv_primary_mask,
recv_counts=recv_counts, recv_count=recv_count,
recv_expert_ids=recv_expert_ids, recv_expert_ids=recv_expert_ids,
recv_dst_rows=recv_dst_rows, recv_dst_rows=recv_dst_rows,
), ),
@ -415,111 +358,82 @@ def move_to_buffer(
def move_from_buffer( def move_from_buffer(
weights_group: Sequence[Iterable[torch.Tensor]], expert_weights: Iterable[torch.Tensor],
buffers_group: Sequence[Sequence[torch.Tensor]], expert_weights_buffers: list[torch.Tensor],
is_unchanged: np.ndarray, is_unchanged: np.ndarray,
is_received_locally: np.ndarray, is_received_locally: np.ndarray,
recv_metadata: RecvMetadata, recv_metadata: RecvMetadata,
new_indices_group: np.ndarray, new_indices: np.ndarray,
ep_group: ProcessGroup, ep_group: ProcessGroup,
) -> None: ) -> None:
""" """
Copies expert weights from communication buffers back to the target weight tensors, Copies expert weights from communication buffers back to the target weight tensors
after EPLB rebalancing. after EPLB rebalancing.
Args: Args:
weights_group: Groups of consecutive MoE layers, each containing one or more expert_weights: List of weight tensors for a MoE layer.
weight tensors. expert_weights_buffers: Intermediate buffers matching ``expert_weights``.
buffers_group: Intermediate buffers matching weights_group.. is_unchanged: (num_local_experts,), True where an expert row is unchanged.
is_unchanged: (num_layers, num_local_experts), True is_received_locally: (num_local_experts,), True where a row is updated locally.
where an expert row is unchanged after rebalance.
is_received_locally: (num_layers, num_local_experts), True
where a row can be updated from local data.
recv_metadata: RecvMetadata containing remote receive metadata. recv_metadata: RecvMetadata containing remote receive metadata.
new_indices_group: np.ndarray giving for each layer the mapping from local rows new_indices: (num_experts_total,) mapping from local rows to desired
to desired (possibly global) expert id, after rebalance. (possibly global) expert id, after rebalance.
ep_group: torch.distributed.ProcessGroup for expert parallel communication ep_group: torch.distributed.ProcessGroup for expert parallel communication
domain. domain.
""" """
assert (
len(weights_group)
== len(buffers_group)
== len(is_unchanged)
== len(is_received_locally)
== len(recv_metadata.recv_primary_mask)
== len(new_indices_group)
), "Unmatching layer group size"
ep_rank = ep_group.rank() ep_rank = ep_group.rank()
group_size = len(is_unchanged)
recv_primary_mask = recv_metadata.recv_primary_mask recv_primary_mask = recv_metadata.recv_primary_mask
recv_counts = recv_metadata.recv_counts recv_count = recv_metadata.recv_count
recv_expert_ids = recv_metadata.recv_expert_ids recv_expert_ids = recv_metadata.recv_expert_ids
recv_dst_rows = recv_metadata.recv_dst_rows recv_dst_rows = recv_metadata.recv_dst_rows
num_local_experts = is_unchanged.shape[1] num_local_experts = is_unchanged.shape[0]
# Mask for rows to copy back from buffers: # Mask for rows to copy back from buffers:
# copy if locally received OR remote primary recv # copy if locally received OR remote primary recv
copy_mask = np.logical_or(is_received_locally, recv_primary_mask) copy_mask = np.logical_or(is_received_locally, recv_primary_mask)
# Copy back local buffered rows into destination weights dest_mask_np = np.logical_and(~is_unchanged, copy_mask)
for layer_idx in range(group_size): if bool(dest_mask_np.any()):
layer_is_unchanged = is_unchanged[layer_idx, :]
layer_copy_mask = copy_mask[layer_idx, :]
weights_list = list(weights_group[layer_idx])
buffers_list = list(buffers_group[layer_idx])
# rows to copy = (~unchanged) & copy_mask
dest_mask_np = np.logical_and(~layer_is_unchanged, layer_copy_mask)
if not bool(dest_mask_np.any()):
continue
dest_indices = np.nonzero(dest_mask_np)[0].tolist() dest_indices = np.nonzero(dest_mask_np)[0].tolist()
for dst in dest_indices: for dst in dest_indices:
for w, b in zip(weights_list, buffers_list): for w, b in zip(expert_weights, expert_weights_buffers):
w[dst].copy_(b[dst]) w[dst].copy_(b[dst])
# Duplicate remote received rows to non-primary duplicate dsts # Duplicate remote received rows to non-primary duplicate dsts
for layer_idx in range(group_size): if recv_count == 0:
layer_is_unchanged = is_unchanged[layer_idx, :] return
layer_is_received_locally = is_received_locally[layer_idx, :]
new_indices = new_indices_group[layer_idx]
weights_list = list(weights_group[layer_idx])
count_recv = int(recv_counts[layer_idx]) base = ep_rank * num_local_experts
if count_recv == 0: local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)]
# No remote primaries on this layer → no remote duplicates to materialize duplicate_mask = np.logical_and(
continue np.logical_and(~is_unchanged, ~is_received_locally),
# Local view of desired expert ids per local row np.logical_and(~recv_primary_mask, local_experts != -1),
base = ep_rank * num_local_experts )
local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)] if not bool(duplicate_mask.any()):
# Duplicate rows mask: need remote, not primary, and valid expert id return
duplicate_mask = np.logical_and(
np.logical_and(~layer_is_unchanged, ~layer_is_received_locally),
np.logical_and(~recv_primary_mask[layer_idx, :], local_experts != -1),
)
if not bool(duplicate_mask.any()):
continue
dup_dst_rows = np.nonzero(duplicate_mask)[0]
dup_experts = local_experts[dup_dst_rows]
# Build primary mapping arrays (expert -> primary dst) and vector-match dup_dst_rows = np.nonzero(duplicate_mask)[0]
prim_experts = recv_expert_ids[layer_idx, :count_recv] dup_experts = local_experts[dup_dst_rows]
prim_dsts = recv_dst_rows[layer_idx, :count_recv]
order = np.argsort(prim_experts, kind="stable")
prim_experts_sorted = prim_experts[order]
prim_dsts_sorted = prim_dsts[order]
pos = np.searchsorted(prim_experts_sorted, dup_experts)
# Filter to experts that have a matching primary entry
valid = np.logical_and(
pos < prim_experts_sorted.shape[0],
prim_experts_sorted[np.minimum(pos, prim_experts_sorted.shape[0] - 1)]
== dup_experts,
)
if not bool(valid.any()):
continue
matched_dst_rows = dup_dst_rows[valid]
matched_src_rows = prim_dsts_sorted[pos[valid]]
# Perform row copies per (dst, src) pair without tensor indexing prim_experts = recv_expert_ids[:recv_count]
for dst, src in zip(matched_dst_rows.tolist(), matched_src_rows.tolist()): prim_dsts = recv_dst_rows[:recv_count]
for w in weights_list: order = np.argsort(prim_experts, kind="stable")
w[dst].copy_(w[src]) prim_experts_sorted = prim_experts[order]
prim_dsts_sorted = prim_dsts[order]
pos = np.searchsorted(prim_experts_sorted, dup_experts)
valid = np.logical_and(
pos < prim_experts_sorted.shape[0],
prim_experts_sorted[np.minimum(pos, prim_experts_sorted.shape[0] - 1)]
== dup_experts,
)
if not bool(valid.any()):
return
matched_dst_rows = dup_dst_rows[valid]
matched_src_rows = prim_dsts_sorted[pos[valid]]
for dst, src in zip(matched_dst_rows.tolist(), matched_src_rows.tolist()):
for w in expert_weights:
w[dst].copy_(w[src])
async def transfer_layer( async def transfer_layer(
@ -555,7 +469,7 @@ async def transfer_layer(
is_unchanged (np.ndarray): (1, num_local_experts), True where expert is_unchanged (np.ndarray): (1, num_local_experts), True where expert
is left unchanged. is left unchanged.
is_received_locally (np.ndarray): (1, num_local_experts), True where expert is_received_locally (np.ndarray): (1, num_local_experts), True where expert
is not copied locally. can be received locally.
RecvMetadata: Metadata needed for completing remote weight transfers. RecvMetadata: Metadata needed for completing remote weight transfers.
""" """
ep_size = ep_group.size() ep_size = ep_group.size()
@ -586,10 +500,10 @@ async def transfer_layer(
is_unchanged, is_received_locally, recv_metadata = move_to_buffer( is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts, num_local_experts=num_local_physical_experts,
old_indices_group=old_global_expert_indices_np[layer : layer + 1], old_indices=old_global_expert_indices_np[layer],
new_indices_group=new_global_expert_indices_np[layer : layer + 1], new_indices=new_global_expert_indices_np[layer],
expert_weights_group=[expert_weights[layer]], expert_weights=expert_weights[layer],
buffers_group=[expert_weights_buffer], expert_weights_buffers=expert_weights_buffer,
cuda_stream=cuda_stream, cuda_stream=cuda_stream,
ep_group=ep_group, ep_group=ep_group,
) )
@ -649,33 +563,24 @@ def rearrange_expert_weights_inplace(
ep_size = ep_group.size() ep_size = ep_group.size()
assert num_physical_experts == ep_size * num_local_physical_experts assert num_physical_experts == ep_size * num_local_physical_experts
# Max number of layers to group for communication
max_group_layers = envs.VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS
max_group_layers = max(min(max_group_layers, num_moe_layers), 1)
first_layer_weights = list(expert_weights[0]) first_layer_weights = list(expert_weights[0])
# Buffers to hold the expert weights during the exchange. # Buffers to hold the expert weights during the exchange.
# NOTE: Currently we assume the same weights across different layers # NOTE: Currently we assume the same weights across different layers
# have the same shape. # have the same shape.
weights_buffers: list[list[torch.Tensor]] = [ weights_buffer: list[torch.Tensor] = [
[torch.empty_like(w) for w in first_layer_weights] torch.empty_like(w) for w in first_layer_weights
for _ in range(max_group_layers)
] ]
if is_profile: if is_profile:
# Reserve communication buffers via a minimal dummy all_gather on first layer # Reserve communication buffers via a minimal dummy all_gather on first layer
for layer_idx in range(max_group_layers): for weight, buffer in zip(expert_weights[0], weights_buffer):
for weight, buffer in zip(expert_weights[0], weights_buffers[layer_idx]): dummy_recv_buffer = [buffer for _ in range(ep_size)]
dummy_recv_buffer = [buffer for _ in range(ep_size)] torch.distributed.barrier()
torch.distributed.barrier() all_gather(
all_gather( dummy_recv_buffer,
dummy_recv_buffer, weight,
weight, group=ep_group,
group=ep_group, )
)
return return
logger.info_once(
f"EPLB Sync: rearrange max_group_layers: {max_group_layers}", scope="global"
)
# NOTE(bowen): We need this synchronize to run, but I don't know why. # NOTE(bowen): We need this synchronize to run, but I don't know why.
# If you figure out the reason, please let me know -- thank you! # If you figure out the reason, please let me know -- thank you!
@ -684,34 +589,26 @@ def rearrange_expert_weights_inplace(
old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy() old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy() new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
start = 0 for layer_idx in range(num_moe_layers):
while start < num_moe_layers:
end = min(start + max_group_layers, num_moe_layers)
old_group = old_global_expert_indices_cpu[start:end]
new_group = new_global_expert_indices_cpu[start:end]
weights_group = [expert_weights[i] for i in range(start, end)]
buffers_group = weights_buffers[: (end - start)]
is_unchanged, is_received_locally, recv_metadata = move_to_buffer( is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts, num_local_experts=num_local_physical_experts,
old_indices_group=old_group, old_indices=old_global_expert_indices_cpu[layer_idx],
new_indices_group=new_group, new_indices=new_global_expert_indices_cpu[layer_idx],
expert_weights_group=weights_group, expert_weights=expert_weights[layer_idx],
buffers_group=buffers_group, expert_weights_buffers=weights_buffer,
cuda_stream=None, cuda_stream=None,
ep_group=ep_group, ep_group=ep_group,
) )
move_from_buffer( move_from_buffer(
weights_group=weights_group, expert_weights=expert_weights[layer_idx],
buffers_group=buffers_group, expert_weights_buffers=weights_buffer,
is_unchanged=is_unchanged, is_unchanged=is_unchanged,
is_received_locally=is_received_locally, is_received_locally=is_received_locally,
recv_metadata=recv_metadata, recv_metadata=recv_metadata,
new_indices_group=new_group, new_indices=new_global_expert_indices_cpu[layer_idx],
ep_group=ep_group, ep_group=ep_group,
) )
start = end
def _map_old_expert_indices_with_rank_mapping( def _map_old_expert_indices_with_rank_mapping(

View File

@ -244,7 +244,6 @@ if TYPE_CHECKING:
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256 VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary" VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
VLLM_USE_V2_MODEL_RUNNER: bool = False VLLM_USE_V2_MODEL_RUNNER: bool = False
VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS: int = 1
def get_default_cache_root(): def get_default_cache_root():
@ -1564,10 +1563,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_V2_MODEL_RUNNER": lambda: bool( "VLLM_USE_V2_MODEL_RUNNER": lambda: bool(
int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0")) int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0"))
), ),
# Max number of layers to group in synchronous EPLB weight communication.
"VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS": lambda: int(
os.getenv("VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS", "1")
),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]