Vectorize get_ep_ranks_with_expert (replaced by get_ep_ranks_with_experts_batch)

Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
ilmarkov 2025-11-26 21:50:09 +00:00
parent c6f14d1a27
commit 08083749be

View File

@ -24,80 +24,111 @@ from vllm.logger import init_logger
# Module-level logger, created via vllm.logger.init_logger for this module's name.
logger = init_logger(__name__)
def idx_local_to_global(
    local_idx: int,
    local_cnt: int,
    ep_rank: int,
) -> int:
    """
    Convert a local expert index to a global expert index.

    Args:
        local_idx: Index of the expert within its rank.
        local_cnt: Number of expert slots per rank.
        ep_rank: Rank that holds the expert.

    Returns:
        ``ep_rank * local_cnt + local_idx``.
    """
    return ep_rank * local_cnt + local_idx
def idx_global_to_local(
    global_idx: int,
    local_cnt: int,
    ep_rank: int,
) -> int:
    """
    Convert a global expert index to a local expert index.

    Args:
        global_idx: Global index of the expert.
        local_cnt: Number of expert slots per rank.
        ep_rank: Rank whose local numbering is wanted.

    Returns:
        ``global_idx - ep_rank * local_cnt``. No range check is performed,
        so the result lies outside ``[0, local_cnt)`` when ``global_idx``
        does not belong to ``ep_rank``.
    """
    return global_idx - ep_rank * local_cnt
def global_idx_to_rank(
    global_idx: int,
    local_cnt: int,
) -> int:
    """
    Convert a global expert index to a rank index.

    Args:
        global_idx: Global index of the expert.
        local_cnt: Number of expert slots per rank.

    Returns:
        ``global_idx // local_cnt`` (floor division).
    """
    return global_idx // local_cnt
def _group_ranks_by_expert(
    indices: np.ndarray,
    query_experts: np.ndarray,
    num_local_experts: int,
) -> dict[int, list[int]]:
    """
    Map each queried expert found in ``indices`` to the ranks that hold it.

    Args:
        indices: Expert id stored at each position; treated as a 1D array.
        query_experts: Expert ids to look for.
        num_local_experts: Number of positions per rank; a position ``p``
            belongs to rank ``p // num_local_experts``.

    Returns:
        A dict with one entry per queried expert that occurs in ``indices``.
        Each value lists the distinct ranks in ascending order, which is also
        their order of first appearance. Experts that do not occur are absent.
    """
    indices = np.asarray(indices)
    # Each array gets its own position list, so callers may pass arrays of
    # different lengths without a boolean-mask shape mismatch.
    positions = np.flatnonzero(np.isin(indices, query_experts))
    if positions.size == 0:
        return {}

    experts = indices[positions]
    # A stable sort by expert keeps positions ascending inside each group.
    order = np.argsort(experts, kind="stable")
    experts = experts[order]
    ranks = positions[order] // num_local_experts

    # Ranks are non-decreasing within a group, so dropping adjacent repeats
    # of (expert, rank) leaves the unique ranks in first-appearance order.
    keep = np.ones(experts.size, dtype=bool)
    keep[1:] = (experts[1:] != experts[:-1]) | (ranks[1:] != ranks[:-1])
    experts = experts[keep]
    ranks = ranks[keep]

    # Locate the first entry of every expert group and split there.
    is_group_start = np.ones(experts.size, dtype=bool)
    is_group_start[1:] = experts[1:] != experts[:-1]
    starts = np.flatnonzero(is_group_start)

    return {
        int(expert): chunk.tolist()
        for expert, chunk in zip(experts[starts], np.split(ranks, starts[1:]))
    }


def get_ep_ranks_with_experts_batch(
    expert_ids: np.ndarray,
    num_local_experts: int,
    old_indices: np.ndarray,
    new_indices: np.ndarray,
) -> tuple[dict[int, list[int]], dict[int, list[int]]]:
    """
    Get the ranks of the experts that need to be exchanged.

    Args:
        expert_ids: 1D array of expert indices to query.
        num_local_experts: The number of local experts.
        old_indices: The old indices of the experts.
        new_indices: The new indices of the experts.

    Returns:
        A tuple of two dictionaries mapping expert_id to:
        - ranks_to_send: The ranks that have this expert and need to send.
        - ranks_to_recv: The ranks that need to receive this expert.

        Every distinct id in ``expert_ids`` is a key in both dictionaries;
        an expert missing from ``old_indices`` (``new_indices``) maps to an
        empty send (recv) list. Both dictionaries are empty when
        ``expert_ids`` is empty.
    """
    ranks_to_send_map: dict[int, list[int]] = {}
    ranks_to_recv_map: dict[int, list[int]] = {}

    expert_ids = np.asarray(expert_ids)
    # Fast path: if no experts, return empty dicts
    if expert_ids.size == 0:
        return ranks_to_send_map, ranks_to_recv_map

    unique_experts = np.unique(expert_ids)
    holders_by_expert = _group_ranks_by_expert(
        old_indices, unique_experts, num_local_experts
    )
    targets_by_expert = _group_ranks_by_expert(
        new_indices, unique_experts, num_local_experts
    )

    for expert in unique_experts.tolist():
        expert = int(expert)
        senders = holders_by_expert.get(expert, [])
        already_local = set(senders)
        ranks_to_send_map[expert] = senders
        # Remove ranks that have local copies to avoid unnecessary recv
        ranks_to_recv_map[expert] = [
            rank
            for rank in targets_by_expert.get(expert, [])
            if rank not in already_local
        ]

    return ranks_to_send_map, ranks_to_recv_map


def get_ep_ranks_with_expert(
    idx: int,
    num_local_experts: int,
    old_indices: np.ndarray,
    new_indices: np.ndarray,
) -> tuple[list[int], list[int]]:
    """
    Get the ranks of a single expert that needs to be exchanged.

    Single-expert wrapper around ``get_ep_ranks_with_experts_batch``, kept so
    existing callers of this name keep working.

    Args:
        idx: The index of the expert.
        num_local_experts: The number of local experts.
        old_indices: The old indices of the experts.
        new_indices: The new indices of the experts.

    Returns:
        A tuple of two lists:
        - The ranks of the experts that need to be sent.
        - The ranks of the experts that need to be received.
    """
    send_map, recv_map = get_ep_ranks_with_experts_batch(
        np.asarray([idx]),
        num_local_experts,
        old_indices,
        new_indices,
    )
    return send_map[int(idx)], recv_map[int(idx)]
def move_to_buffer(
@ -228,6 +259,10 @@ def move_to_buffer(
p2p_ops: list[P2POp] = []
# Pre-compute global ranks mapping
ep_size = ep_group.size()
rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
# 2. Post sends per layer
for layer_idx in range(group_size):
old_indices = old_indices_group[layer_idx]
@ -241,13 +276,17 @@ def move_to_buffer(
order = np.argsort(experts, kind="stable")
experts = experts[order]
srcs = srcs[order]
send_map, recv_map = get_ep_ranks_with_experts_batch(
experts,
num_local_experts,
old_indices,
layer_new_indices,
)
for expert, src in zip(experts.tolist(), srcs.tolist()):
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
expert,
num_local_experts,
old_indices,
layer_new_indices,
)
ranks_to_send = send_map[expert]
ranks_to_recv = recv_map[expert]
if not ranks_to_send or not ranks_to_recv:
continue
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
@ -260,7 +299,7 @@ def move_to_buffer(
if recver_pos < len(ranks_to_recv):
recv_ranks.append(ranks_to_recv[recver_pos])
for dst in recv_ranks:
dst_global = get_global_rank(ep_group, dst)
dst_global = rank_to_global[dst]
p2p_ops += [
P2POp(
torch.distributed.isend,
@ -283,13 +322,18 @@ def move_to_buffer(
order = np.argsort(experts, kind="stable")
experts = experts[order]
dsts = dsts[order]
# Batch query all experts for this layer
send_map, recv_map = get_ep_ranks_with_experts_batch(
experts,
num_local_experts,
old_indices,
layer_new_indices,
)
for expert, dst in zip(experts.tolist(), dsts.tolist()):
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
expert,
num_local_experts,
old_indices,
layer_new_indices,
)
ranks_to_send = send_map[expert]
ranks_to_recv = recv_map[expert]
if not ranks_to_send or not ranks_to_recv:
continue
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
@ -299,7 +343,7 @@ def move_to_buffer(
src = ranks_to_send[recver_pos // num_dst_per_sender]
else:
src = ranks_to_send[recver_pos - remainder_start]
src_global = get_global_rank(ep_group, src)
src_global = rank_to_global[src]
p2p_ops += [
P2POp(
torch.distributed.irecv,