Remove layer grouping

Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
ilmarkov 2025-12-11 10:38:32 +00:00
parent fc54d760a6
commit f28720db88
5 changed files with 168 additions and 281 deletions

View File

@@ -15,7 +15,7 @@ from vllm.utils.system_utils import update_environment_variables
mp.set_start_method("spawn", force=True)
def distributed_run(fn, world_size, *args, max_grouped_layers=1):
def distributed_run(fn, world_size, *args):
number_of_processes = world_size
processes: list[mp.Process] = []
for i in range(number_of_processes):
@@ -26,7 +26,6 @@ def distributed_run(fn, world_size, *args, max_grouped_layers=1):
env["LOCAL_WORLD_SIZE"] = str(number_of_processes)
env["MASTER_ADDR"] = "localhost"
env["MASTER_PORT"] = "12345"
env["VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS"] = str(max_grouped_layers)
p = mp.Process(target=fn, args=(env, world_size, *args))
processes.append(p)
p.start()

View File

@@ -306,12 +306,12 @@ def _test_async_transfer_layer_without_mtp_worker(
)
cuda_stream.synchronize()
move_from_buffer(
weights_group=[expert_weights[layer_idx]],
buffers_group=[expert_buffer],
expert_weights=expert_weights[layer_idx],
expert_weights_buffers=expert_buffer,
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
recv_metadata=recv_metadata,
new_indices_group=new_indices_cpu[layer_idx : layer_idx + 1],
new_indices=new_indices_cpu[layer_idx],
ep_group=ep_group,
)
@@ -427,9 +427,8 @@ def _test_rearrange_expert_weights_with_redundancy(
(4, 8, 8, 16),
],
)
@pytest.mark.parametrize("group_layers", [1, 2])
def test_rearrange_expert_weights_with_redundancy(
world_size, num_layers, num_local_experts, num_logical_experts, group_layers
world_size, num_layers, num_local_experts, num_logical_experts
):
"""Test the functionality of rearranging expert weights with redundancy."""
@@ -441,7 +440,6 @@ def test_rearrange_expert_weights_with_redundancy(
num_layers,
num_local_experts,
num_logical_experts,
max_grouped_layers=group_layers,
)

View File

@@ -514,7 +514,7 @@ class EplbState:
is_received_locally=np.array([]),
recv_metadata=RecvMetadata(
recv_primary_mask=np.array([]),
recv_counts=np.array([]),
recv_count=0,
recv_expert_ids=np.array([]),
recv_dst_rows=np.array([]),
),
@@ -989,24 +989,22 @@ class EplbState:
stream = torch.cuda.current_stream(device=device_index)
stream.wait_event(model_state.buffer_ready_event)
model_state.buffer_ready_event = None
weights_group = [
model_state.model.expert_weights[model_state.layer_to_transfer]
expert_weights = model_state.model.expert_weights[
model_state.layer_to_transfer
]
buffers_group = [model_state.expert_buffer]
new_indices_group = (
model_state.new_physical_to_logical_map[
model_state.layer_to_transfer : model_state.layer_to_transfer + 1
]
expert_weights_buffer = model_state.expert_buffer
new_indices = (
model_state.new_physical_to_logical_map[model_state.layer_to_transfer]
.cpu()
.numpy()
)
move_from_buffer(
weights_group=weights_group,
buffers_group=buffers_group,
expert_weights=expert_weights,
expert_weights_buffers=expert_weights_buffer,
is_unchanged=model_state.is_unchanged,
is_received_locally=model_state.is_received_locally,
recv_metadata=model_state.recv_metadata,
new_indices_group=new_indices_group,
new_indices=new_indices,
ep_group=ep_group,
)
transferred_layer = model_state.layer_to_transfer

View File

@@ -19,7 +19,6 @@ from torch.distributed import (
get_global_rank,
)
import vllm.envs as envs
from vllm.logger import init_logger
logger = init_logger(__name__)
@@ -30,15 +29,13 @@ class RecvMetadata:
"""Metadata describing remote receives during EPLB rebalancing."""
recv_primary_mask: np.ndarray
"""Mask of (layer_group_size, num_local_experts)
indicating primary experts received."""
recv_counts: np.ndarray
"""Number of received experts for each layer."""
"""Mask of (num_local_experts,) indicating primary experts received."""
recv_count: int
"""Number of received experts for the layer."""
recv_expert_ids: np.ndarray
"""Expert ids (layer_group_size, num_local_experts) of remote primary experts."""
"""Expert ids (num_local_experts,) of remote primary experts."""
recv_dst_rows: np.ndarray
"""Target expert indices (layer_group_size, num_local_experts)
in local tensors to send."""
"""Target expert indices (num_local_experts,) in local tensors to send."""
# Type alias for the result of move_to_buffer or transfer_layer
@@ -154,146 +151,105 @@ def get_ep_ranks_with_experts_batch(
def move_to_buffer(
num_local_experts: int,
old_indices_group: np.ndarray,
new_indices_group: np.ndarray,
expert_weights_group: Sequence[Iterable[torch.Tensor]],
buffers_group: Sequence[Sequence[torch.Tensor]],
old_indices: np.ndarray,
new_indices: np.ndarray,
expert_weights: Iterable[torch.Tensor],
expert_weights_buffers: Sequence[torch.Tensor],
cuda_stream: torch.cuda.Stream | None,
ep_group: ProcessGroup,
) -> MoveToBufferResult:
"""
Rearranges expert weights across a group of layers
during mixture-of-experts (MoE) expert parallel rebalancing.
Rearranges expert weights during EPLB rebalancing.
Args:
num_local_experts: Number of local experts.
old_indices_group: (num_layers, num_experts_total) ndarray of current
(old) global-to-local expert assignments.
new_indices_group: (num_layers, num_experts_total) ndarray of desired
(new) global-to-local assignments after rebalance.
expert_weights_group: Original expert weights for each layer.
buffers_group: List of per-layer intermediate buffers (one per tensor).
old_indices: (num_experts_total,) ndarray of current (old)
global-to-local expert assignments.
new_indices: (num_experts_total,) ndarray of desired (new)
global-to-local assignments after rebalance.
expert_weights: Original expert weights for the layer.
expert_weights_buffers: Intermediate buffers (one per tensor).
cuda_stream: CUDA stream for async copies (can be None for sync mode).
ep_group: Distributed process group for expert parallel comms.
Returns:
is_unchanged (np.ndarray): (num_layers, num_local_experts), True where an
expert row is unchanged after rebalance.
is_received_locally (np.ndarray): (num_layers, num_local_experts), True
where a row can be updated from local data.
is_unchanged (np.ndarray): (num_local_experts,), True where an expert row
is unchanged after rebalance.
is_received_locally (np.ndarray): (num_local_experts,), True where a row
can be updated from local data.
RecvMetadata: Metadata needed for completing remote weight transfers.
"""
assert len(old_indices_group) == len(new_indices_group) == len(expert_weights_group)
group_size = len(old_indices_group)
assert old_indices.shape == new_indices.shape
ep_rank = ep_group.rank()
# Pre-allocate per-layer compact maps/masks (numpy)
is_unchanged = np.zeros((group_size, num_local_experts), dtype=np.bool_)
is_received_locally = np.zeros((group_size, num_local_experts), dtype=np.bool_)
recv_primary_mask = np.zeros((group_size, num_local_experts), dtype=np.bool_)
# Cache desired new expert ids per local row, for all layers
new_local_expert_ids_mat = np.full(
(group_size, num_local_experts), -1, dtype=np.int64
)
send_counts = np.zeros(group_size, dtype=np.int32)
send_expert_ids = np.full((group_size, num_local_experts), -1, dtype=np.int64)
send_src_rows = np.full((group_size, num_local_experts), -1, dtype=np.int32)
recv_counts = np.zeros(group_size, dtype=np.int32)
recv_expert_ids = np.full((group_size, num_local_experts), -1, dtype=np.int64)
recv_dst_rows = np.full((group_size, num_local_experts), -1, dtype=np.int32)
recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_)
send_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
send_src_rows = np.full((num_local_experts,), -1, dtype=np.int32)
recv_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
recv_dst_rows = np.full((num_local_experts,), -1, dtype=np.int32)
base = ep_rank * num_local_experts
local_rows = np.arange(num_local_experts, dtype=np.int32)
local_global = base + local_rows
# Build masks and expert maps per layer
for layer_idx in range(group_size):
old_indices = old_indices_group[layer_idx]
layer_new_indices = new_indices_group[layer_idx]
old_local_expert_ids = old_indices[local_global]
new_local_expert_ids = new_indices[local_global]
old_local_expert_ids = old_indices[local_global]
new_local_expert_ids = layer_new_indices[local_global]
new_local_expert_ids_mat[layer_idx, :] = new_local_expert_ids
# Unchanged mask
is_unchanged = old_local_expert_ids == new_local_expert_ids
# Unchanged per-dst mask
unchanged_mask = old_local_expert_ids == new_local_expert_ids
is_unchanged[layer_idx, :] = unchanged_mask
# Local receive eligibility
new_valid = new_local_expert_ids != -1
can_recv_local = np.isin(
new_local_expert_ids, old_local_expert_ids, assume_unique=False
)
is_received_locally = np.logical_or(
is_unchanged, np.logical_and(new_valid, can_recv_local)
)
# Local receive eligibility
new_valid = new_local_expert_ids != -1
can_recv_local = np.isin(
new_local_expert_ids, old_local_expert_ids, assume_unique=False
# Send map: first src row per unique expert present locally in old mapping
send_count = 0
valid_old = old_local_expert_ids != -1
if np.any(valid_old):
uniq_experts, first_idx = np.unique(
old_local_expert_ids[valid_old], return_index=True
)
is_local_recv = np.logical_or(
unchanged_mask, np.logical_and(new_valid, can_recv_local)
)
is_received_locally[layer_idx, :] = is_local_recv
filtered_rows = local_rows[valid_old]
src_rows = filtered_rows[first_idx]
send_count = int(uniq_experts.shape[0])
send_expert_ids[:send_count] = uniq_experts
send_src_rows[:send_count] = src_rows
# Send map: first src row per unique expert present locally in old mapping
valid_old = old_local_expert_ids != -1
if np.any(valid_old):
uniq_experts, first_idx = np.unique(
old_local_expert_ids[valid_old], return_index=True
)
filtered_rows = local_rows[valid_old]
src_rows = filtered_rows[first_idx]
layer_send_count = int(uniq_experts.shape[0])
send_counts[layer_idx] = layer_send_count
send_expert_ids[layer_idx, :layer_send_count] = uniq_experts
send_src_rows[layer_idx, :layer_send_count] = src_rows
else:
send_counts[layer_idx] = 0
# Recv map: primary dst per unique expert needed remotely
recv_count = 0
need_recv_mask = np.logical_and(~is_received_locally, new_valid)
if np.any(need_recv_mask):
desired_experts = new_local_expert_ids[need_recv_mask]
desired_dsts = local_rows[need_recv_mask]
uniq_recv_experts, uniq_indices = np.unique(desired_experts, return_index=True)
dst_rows = desired_dsts[uniq_indices]
recv_count = int(uniq_recv_experts.shape[0])
recv_expert_ids[:recv_count] = uniq_recv_experts
recv_dst_rows[:recv_count] = dst_rows
recv_primary_mask[dst_rows] = True
# Recv map: primary dst per unique expert needed remotely
need_recv_mask = np.logical_and(~is_local_recv, new_valid)
if np.any(need_recv_mask):
desired_experts = new_local_expert_ids[need_recv_mask]
desired_dsts = local_rows[need_recv_mask]
uniq_recv_experts, uniq_indices = np.unique(
desired_experts, return_index=True
)
dst_rows = desired_dsts[uniq_indices]
layer_recv_count = int(uniq_recv_experts.shape[0])
recv_counts[layer_idx] = layer_recv_count
recv_expert_ids[layer_idx, :layer_recv_count] = uniq_recv_experts
recv_dst_rows[layer_idx, :layer_recv_count] = dst_rows
recv_primary_mask[layer_idx, dst_rows] = True
else:
recv_counts[layer_idx] = 0
# Precompute per-layer destination mask that actually needs local buffering:
# need change, received locally, and valid target expert id
eligible_local_buffer_mask = np.logical_and(
np.logical_and(~is_unchanged, is_received_locally),
new_local_expert_ids_mat != -1,
new_local_expert_ids != -1,
)
# 1. Local moves into tmp buffers
for layer_idx in range(group_size):
layer_send_count = int(send_counts[layer_idx])
if layer_send_count <= 0:
continue
layer_send_experts = send_expert_ids[layer_idx, :layer_send_count]
layer_send_srcs = send_src_rows[layer_idx, :layer_send_count]
layer_weights_list = list(expert_weights_group[layer_idx])
layer_buffers_list = list(buffers_group[layer_idx])
new_local_expert_ids = new_local_expert_ids_mat[layer_idx, :]
# Only consider destination rows that are eligible for local buffering
eligible_mask = eligible_local_buffer_mask[layer_idx, :]
if not bool(eligible_mask.any()):
continue
dest_indices = np.nonzero(eligible_mask)[0].tolist()
# Build a map from expert_id to its source row.
expert_to_src_map = dict(zip(layer_send_experts, layer_send_srcs))
if bool(eligible_local_buffer_mask.any()) and send_count > 0:
dest_indices = np.nonzero(eligible_local_buffer_mask)[0].tolist()
expert_to_src_map = dict(
zip(send_expert_ids[:send_count], send_src_rows[:send_count])
)
for dst in dest_indices:
expert = new_local_expert_ids[dst]
src_local = expert_to_src_map.get(expert, -1)
if src_local != -1:
for w, b in zip(layer_weights_list, layer_buffers_list):
b[dst].copy_(w[src_local])
for w, b in zip(expert_weights, expert_weights_buffers):
b[dst].copy_(w[src_local], non_blocking=True)
p2p_ops: list[P2POp] = []
@@ -301,16 +257,10 @@ def move_to_buffer(
ep_size = ep_group.size()
rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
# 2. Post sends per layer
for layer_idx in range(group_size):
old_indices = old_indices_group[layer_idx]
layer_new_indices = new_indices_group[layer_idx]
layer_weights_list = list(expert_weights_group[layer_idx])
layer_send_count = int(send_counts[layer_idx])
if layer_send_count == 0:
continue
experts = send_expert_ids[layer_idx, :layer_send_count]
srcs = send_src_rows[layer_idx, :layer_send_count]
# 2. Post sends
if send_count > 0:
experts = send_expert_ids[:send_count]
srcs = send_src_rows[:send_count]
order = np.argsort(experts, kind="stable")
experts = experts[order]
srcs = srcs[order]
@@ -319,7 +269,7 @@ def move_to_buffer(
experts,
num_local_experts,
old_indices,
layer_new_indices,
new_indices,
)
for expert, src in zip(experts.tolist(), srcs.tolist()):
@@ -344,29 +294,22 @@ def move_to_buffer(
w[src],
dst_global,
)
for w in layer_weights_list
for w in expert_weights
]
# 3. Post recvs per layer
for layer_idx in range(group_size):
old_indices = old_indices_group[layer_idx]
layer_new_indices = new_indices_group[layer_idx]
layer_buffers_list = list(buffers_group[layer_idx])
layer_recv_count = int(recv_counts[layer_idx])
if layer_recv_count == 0:
continue
experts = recv_expert_ids[layer_idx, :layer_recv_count]
dsts = recv_dst_rows[layer_idx, :layer_recv_count]
# 3. Post recvs
if recv_count > 0:
experts = recv_expert_ids[:recv_count]
dsts = recv_dst_rows[:recv_count]
order = np.argsort(experts, kind="stable")
experts = experts[order]
dsts = dsts[order]
# Batch query all experts for this layer
send_map, recv_map = get_ep_ranks_with_experts_batch(
experts,
num_local_experts,
old_indices,
layer_new_indices,
new_indices,
)
for expert, dst in zip(experts.tolist(), dsts.tolist()):
@@ -388,7 +331,7 @@ def move_to_buffer(
b[dst],
src_global,
)
for b in layer_buffers_list
for b in expert_weights_buffers
]
# 4. Execute the P2P operations. The real communication happens here.
@@ -407,7 +350,7 @@ def move_to_buffer(
is_received_locally,
RecvMetadata(
recv_primary_mask=recv_primary_mask,
recv_counts=recv_counts,
recv_count=recv_count,
recv_expert_ids=recv_expert_ids,
recv_dst_rows=recv_dst_rows,
),
@@ -415,111 +358,82 @@
def move_from_buffer(
weights_group: Sequence[Iterable[torch.Tensor]],
buffers_group: Sequence[Sequence[torch.Tensor]],
expert_weights: Iterable[torch.Tensor],
expert_weights_buffers: list[torch.Tensor],
is_unchanged: np.ndarray,
is_received_locally: np.ndarray,
recv_metadata: RecvMetadata,
new_indices_group: np.ndarray,
new_indices: np.ndarray,
ep_group: ProcessGroup,
) -> None:
"""
Copies expert weights from communication buffers back to the target weight tensors,
Copies expert weights from communication buffers back to the target weight tensors
after EPLB rebalancing.
Args:
weights_group: Groups of consecutive MoE layers, each containing one or more
weight tensors.
buffers_group: Intermediate buffers matching weights_group.
is_unchanged: (num_layers, num_local_experts), True
where an expert row is unchanged after rebalance.
is_received_locally: (num_layers, num_local_experts), True
where a row can be updated from local data.
expert_weights: List of weight tensors for a MoE layer.
expert_weights_buffers: Intermediate buffers matching ``expert_weights``.
is_unchanged: (num_local_experts,), True where an expert row is unchanged.
is_received_locally: (num_local_experts,), True where a row is updated locally.
recv_metadata: RecvMetadata containing remote receive metadata.
new_indices_group: np.ndarray giving for each layer the mapping from local rows
to desired (possibly global) expert id, after rebalance.
new_indices: (num_experts_total,) mapping from local rows to desired
(possibly global) expert id, after rebalance.
ep_group: torch.distributed.ProcessGroup for expert parallel communication
domain.
"""
assert (
len(weights_group)
== len(buffers_group)
== len(is_unchanged)
== len(is_received_locally)
== len(recv_metadata.recv_primary_mask)
== len(new_indices_group)
), "Unmatching layer group size"
ep_rank = ep_group.rank()
group_size = len(is_unchanged)
recv_primary_mask = recv_metadata.recv_primary_mask
recv_counts = recv_metadata.recv_counts
recv_count = recv_metadata.recv_count
recv_expert_ids = recv_metadata.recv_expert_ids
recv_dst_rows = recv_metadata.recv_dst_rows
num_local_experts = is_unchanged.shape[1]
num_local_experts = is_unchanged.shape[0]
# Mask for rows to copy back from buffers:
# copy if locally received OR remote primary recv
copy_mask = np.logical_or(is_received_locally, recv_primary_mask)
# Copy back local buffered rows into destination weights
for layer_idx in range(group_size):
layer_is_unchanged = is_unchanged[layer_idx, :]
layer_copy_mask = copy_mask[layer_idx, :]
weights_list = list(weights_group[layer_idx])
buffers_list = list(buffers_group[layer_idx])
# rows to copy = (~unchanged) & copy_mask
dest_mask_np = np.logical_and(~layer_is_unchanged, layer_copy_mask)
if not bool(dest_mask_np.any()):
continue
dest_mask_np = np.logical_and(~is_unchanged, copy_mask)
if bool(dest_mask_np.any()):
dest_indices = np.nonzero(dest_mask_np)[0].tolist()
for dst in dest_indices:
for w, b in zip(weights_list, buffers_list):
for w, b in zip(expert_weights, expert_weights_buffers):
w[dst].copy_(b[dst])
# Duplicate remote received rows to non-primary duplicate dsts
for layer_idx in range(group_size):
layer_is_unchanged = is_unchanged[layer_idx, :]
layer_is_received_locally = is_received_locally[layer_idx, :]
new_indices = new_indices_group[layer_idx]
weights_list = list(weights_group[layer_idx])
if recv_count == 0:
return
count_recv = int(recv_counts[layer_idx])
if count_recv == 0:
# No remote primaries on this layer → no remote duplicates to materialize
continue
# Local view of desired expert ids per local row
base = ep_rank * num_local_experts
local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)]
# Duplicate rows mask: need remote, not primary, and valid expert id
duplicate_mask = np.logical_and(
np.logical_and(~layer_is_unchanged, ~layer_is_received_locally),
np.logical_and(~recv_primary_mask[layer_idx, :], local_experts != -1),
)
if not bool(duplicate_mask.any()):
continue
dup_dst_rows = np.nonzero(duplicate_mask)[0]
dup_experts = local_experts[dup_dst_rows]
base = ep_rank * num_local_experts
local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)]
duplicate_mask = np.logical_and(
np.logical_and(~is_unchanged, ~is_received_locally),
np.logical_and(~recv_primary_mask, local_experts != -1),
)
if not bool(duplicate_mask.any()):
return
# Build primary mapping arrays (expert -> primary dst) and vector-match
prim_experts = recv_expert_ids[layer_idx, :count_recv]
prim_dsts = recv_dst_rows[layer_idx, :count_recv]
order = np.argsort(prim_experts, kind="stable")
prim_experts_sorted = prim_experts[order]
prim_dsts_sorted = prim_dsts[order]
pos = np.searchsorted(prim_experts_sorted, dup_experts)
# Filter to experts that have a matching primary entry
valid = np.logical_and(
pos < prim_experts_sorted.shape[0],
prim_experts_sorted[np.minimum(pos, prim_experts_sorted.shape[0] - 1)]
== dup_experts,
)
if not bool(valid.any()):
continue
matched_dst_rows = dup_dst_rows[valid]
matched_src_rows = prim_dsts_sorted[pos[valid]]
dup_dst_rows = np.nonzero(duplicate_mask)[0]
dup_experts = local_experts[dup_dst_rows]
# Perform row copies per (dst, src) pair without tensor indexing
for dst, src in zip(matched_dst_rows.tolist(), matched_src_rows.tolist()):
for w in weights_list:
w[dst].copy_(w[src])
prim_experts = recv_expert_ids[:recv_count]
prim_dsts = recv_dst_rows[:recv_count]
order = np.argsort(prim_experts, kind="stable")
prim_experts_sorted = prim_experts[order]
prim_dsts_sorted = prim_dsts[order]
pos = np.searchsorted(prim_experts_sorted, dup_experts)
valid = np.logical_and(
pos < prim_experts_sorted.shape[0],
prim_experts_sorted[np.minimum(pos, prim_experts_sorted.shape[0] - 1)]
== dup_experts,
)
if not bool(valid.any()):
return
matched_dst_rows = dup_dst_rows[valid]
matched_src_rows = prim_dsts_sorted[pos[valid]]
for dst, src in zip(matched_dst_rows.tolist(), matched_src_rows.tolist()):
for w in expert_weights:
w[dst].copy_(w[src])
async def transfer_layer(
@@ -555,7 +469,7 @@ async def transfer_layer(
is_unchanged (np.ndarray): (1, num_local_experts), True where expert
is left unchanged.
is_received_locally (np.ndarray): (1, num_local_experts), True where expert
is not copied locally.
can be received locally.
RecvMetadata: Metadata needed for completing remote weight transfers.
"""
ep_size = ep_group.size()
@@ -586,10 +500,10 @@ async def transfer_layer(
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts,
old_indices_group=old_global_expert_indices_np[layer : layer + 1],
new_indices_group=new_global_expert_indices_np[layer : layer + 1],
expert_weights_group=[expert_weights[layer]],
buffers_group=[expert_weights_buffer],
old_indices=old_global_expert_indices_np[layer],
new_indices=new_global_expert_indices_np[layer],
expert_weights=expert_weights[layer],
expert_weights_buffers=expert_weights_buffer,
cuda_stream=cuda_stream,
ep_group=ep_group,
)
@@ -649,33 +563,24 @@ def rearrange_expert_weights_inplace(
ep_size = ep_group.size()
assert num_physical_experts == ep_size * num_local_physical_experts
# Max number of layers to group for communication
max_group_layers = envs.VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS
max_group_layers = max(min(max_group_layers, num_moe_layers), 1)
first_layer_weights = list(expert_weights[0])
# Buffers to hold the expert weights during the exchange.
# NOTE: Currently we assume the same weights across different layers
# have the same shape.
weights_buffers: list[list[torch.Tensor]] = [
[torch.empty_like(w) for w in first_layer_weights]
for _ in range(max_group_layers)
weights_buffer: list[torch.Tensor] = [
torch.empty_like(w) for w in first_layer_weights
]
if is_profile:
# Reserve communication buffers via a minimal dummy all_gather on first layer
for layer_idx in range(max_group_layers):
for weight, buffer in zip(expert_weights[0], weights_buffers[layer_idx]):
dummy_recv_buffer = [buffer for _ in range(ep_size)]
torch.distributed.barrier()
all_gather(
dummy_recv_buffer,
weight,
group=ep_group,
)
for weight, buffer in zip(expert_weights[0], weights_buffer):
dummy_recv_buffer = [buffer for _ in range(ep_size)]
torch.distributed.barrier()
all_gather(
dummy_recv_buffer,
weight,
group=ep_group,
)
return
logger.info_once(
f"EPLB Sync: rearrange max_group_layers: {max_group_layers}", scope="global"
)
# NOTE(bowen): We need this synchronize to run, but I don't know why.
# If you figure out the reason, please let me know -- thank you!
@@ -684,34 +589,26 @@ def rearrange_expert_weights_inplace(
old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
start = 0
while start < num_moe_layers:
end = min(start + max_group_layers, num_moe_layers)
old_group = old_global_expert_indices_cpu[start:end]
new_group = new_global_expert_indices_cpu[start:end]
weights_group = [expert_weights[i] for i in range(start, end)]
buffers_group = weights_buffers[: (end - start)]
for layer_idx in range(num_moe_layers):
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
num_local_experts=num_local_physical_experts,
old_indices_group=old_group,
new_indices_group=new_group,
expert_weights_group=weights_group,
buffers_group=buffers_group,
old_indices=old_global_expert_indices_cpu[layer_idx],
new_indices=new_global_expert_indices_cpu[layer_idx],
expert_weights=expert_weights[layer_idx],
expert_weights_buffers=weights_buffer,
cuda_stream=None,
ep_group=ep_group,
)
move_from_buffer(
weights_group=weights_group,
buffers_group=buffers_group,
expert_weights=expert_weights[layer_idx],
expert_weights_buffers=weights_buffer,
is_unchanged=is_unchanged,
is_received_locally=is_received_locally,
recv_metadata=recv_metadata,
new_indices_group=new_group,
new_indices=new_global_expert_indices_cpu[layer_idx],
ep_group=ep_group,
)
start = end
def _map_old_expert_indices_with_rank_mapping(

View File

@@ -244,7 +244,6 @@ if TYPE_CHECKING:
VLLM_SHARED_EXPERTS_STREAM_TOKEN_THRESHOLD: int = 256
VLLM_COMPILE_CACHE_SAVE_FORMAT: Literal["binary", "unpacked"] = "binary"
VLLM_USE_V2_MODEL_RUNNER: bool = False
VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS: int = 1
def get_default_cache_root():
@@ -1564,10 +1563,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_V2_MODEL_RUNNER": lambda: bool(
int(os.getenv("VLLM_USE_V2_MODEL_RUNNER", "0"))
),
# Max number of layers to group in synchronous EPLB weight communication.
"VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS": lambda: int(
os.getenv("VLLM_EPLB_SYNC_MAX_GROUPED_LAYERS", "1")
),
}
# --8<-- [end:env-vars-definition]