mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-05 05:47:03 +08:00
Vectorize get_ep_ranks_with_experts
Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
parent
c6f14d1a27
commit
08083749be
@ -24,80 +24,111 @@ from vllm.logger import init_logger
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def idx_local_to_global(
    local_idx: int,
    local_cnt: int,
    ep_rank: int,
) -> int:
    """
    Convert a local expert index to a global expert index.

    The global index is computed as `ep_rank * local_cnt + local_idx`, i.e.
    each rank is treated as owning a contiguous run of `local_cnt` global
    slots.

    Args:
        local_idx: Index of the expert within its rank.
        local_cnt: Number of expert slots per rank.
        ep_rank: Rank that holds the expert.

    Returns:
        The global expert index.
    """
    return ep_rank * local_cnt + local_idx
|
||||
|
||||
|
||||
def idx_global_to_local(
    global_idx: int,
    local_cnt: int,
    ep_rank: int,
) -> int:
    """
    Convert a global expert index to a local expert index.

    The local index is computed as `global_idx - ep_rank * local_cnt`, i.e.
    the offset of `global_idx` from the first global slot of `ep_rank`.
    No range check is performed: if `global_idx` does not belong to
    `ep_rank`, the result falls outside `[0, local_cnt)`.

    Args:
        global_idx: Global index of the expert.
        local_cnt: Number of expert slots per rank.
        ep_rank: Rank relative to which the local index is expressed.

    Returns:
        The local expert index on `ep_rank`.
    """
    return global_idx - ep_rank * local_cnt
|
||||
|
||||
|
||||
def global_idx_to_rank(
    global_idx: int,
    local_cnt: int,
) -> int:
    """
    Convert a global expert index to a rank index.

    The rank is the floor division `global_idx // local_cnt`.

    Args:
        global_idx: Global index of the expert.
        local_cnt: Number of expert slots per rank; must be non-zero.

    Returns:
        The index of the rank that owns `global_idx`.
    """
    return global_idx // local_cnt
|
||||
|
||||
|
||||
def _group_ranks_by_expert(
    indices: np.ndarray,
    expert_ids: np.ndarray,
    num_local_experts: int,
) -> dict[int, list[int]]:
    """
    Map each queried expert found in `indices` to the ranks that hold it.

    Args:
        indices: 1D array of expert ids, one per global slot. The slot at
            position `p` is attributed to rank `p // num_local_experts`.
        expert_ids: Expert ids to look for.
        num_local_experts: The number of local experts (slots per rank).

    Returns:
        A dictionary from expert id to its de-duplicated ranks, listed in
        order of first appearance in `indices`. Experts that do not occur
        in `indices` have no entry.
    """
    indices = np.asarray(indices)
    # Positions are derived from this array's own mask, so the lookup does
    # not depend on the length of any other placement array.
    positions = np.flatnonzero(np.isin(indices, expert_ids))
    if positions.size == 0:
        return {}

    experts = indices[positions]
    ranks = positions // num_local_experts

    # `positions` is ascending, so a stable sort on the expert id alone keeps
    # every expert's slots in position order (first-appearance order).
    order = np.argsort(experts, kind="stable")
    sorted_experts = experts[order]
    sorted_ranks = ranks[order]

    # Offsets at which the expert id changes delimit one run per expert.
    cuts = np.flatnonzero(np.diff(sorted_experts)) + 1
    run_starts = np.concatenate(([0], cuts))
    run_experts = sorted_experts[run_starts].tolist()

    # dict.fromkeys drops repeated ranks while preserving their order and
    # yields plain Python ints.
    return {
        int(expert): list(dict.fromkeys(run.tolist()))
        for expert, run in zip(run_experts, np.split(sorted_ranks, cuts))
    }


def get_ep_ranks_with_experts_batch(
    expert_ids: np.ndarray,
    num_local_experts: int,
    old_indices: np.ndarray,
    new_indices: np.ndarray,
) -> tuple[dict[int, list[int]], dict[int, list[int]]]:
    """
    Get the ranks of the experts that need to be exchanged.

    Args:
        expert_ids: 1D array of expert indices to query. Duplicates are
            collapsed.
        num_local_experts: The number of local experts.
        old_indices: The old indices of the experts.
        new_indices: The new indices of the experts.

    Returns:
        A tuple of two dictionaries mapping expert_id to:
        - ranks_to_send: The ranks that have this expert and need to send.
        - ranks_to_recv: The ranks that need to receive this expert. Ranks
          that already hold the expert (i.e. appear in ranks_to_send) are
          excluded.
        Every queried expert has an entry in both dictionaries; the entry
        is an empty list when the expert is absent from the corresponding
        placement.
    """
    expert_ids = np.asarray(expert_ids)

    # Fast path: if no experts, return empty dicts
    if expert_ids.size == 0:
        return {}, {}

    unique_experts = np.unique(expert_ids)
    senders_by_expert = _group_ranks_by_expert(
        old_indices, unique_experts, num_local_experts
    )
    receivers_by_expert = _group_ranks_by_expert(
        new_indices, unique_experts, num_local_experts
    )

    ranks_to_send_map: dict[int, list[int]] = {}
    ranks_to_recv_map: dict[int, list[int]] = {}
    for expert in unique_experts.tolist():
        expert = int(expert)
        senders = senders_by_expert.get(expert, [])
        # Remove ranks that have local copies to avoid unnecessary recv
        already_local = set(senders)
        ranks_to_send_map[expert] = senders
        ranks_to_recv_map[expert] = [
            rank
            for rank in receivers_by_expert.get(expert, [])
            if rank not in already_local
        ]

    return ranks_to_send_map, ranks_to_recv_map


def get_ep_ranks_with_expert(
    idx: int,
    num_local_experts: int,
    old_indices: np.ndarray,
    new_indices: np.ndarray,
) -> tuple[list[int], list[int]]:
    """
    Get the ranks of a single expert that need to be exchanged.

    Single-expert form of `get_ep_ranks_with_experts_batch`, to which it
    delegates.

    Args:
        idx: The index of the expert.
        num_local_experts: The number of local experts.
        old_indices: The old indices of the experts.
        new_indices: The new indices of the experts.

    Returns:
        A tuple of two lists:
        - The ranks of the experts that need to be sent.
        - The ranks of the experts that need to be received.
    """
    send_map, recv_map = get_ep_ranks_with_experts_batch(
        np.asarray([idx]),
        num_local_experts,
        old_indices,
        new_indices,
    )
    key = int(idx)
    return send_map[key], recv_map[key]
|
||||
|
||||
|
||||
def move_to_buffer(
|
||||
@ -228,6 +259,10 @@ def move_to_buffer(
|
||||
|
||||
p2p_ops: list[P2POp] = []
|
||||
|
||||
# Pre-compute global ranks mapping
|
||||
ep_size = ep_group.size()
|
||||
rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
|
||||
|
||||
# 2. Post sends per layer
|
||||
for layer_idx in range(group_size):
|
||||
old_indices = old_indices_group[layer_idx]
|
||||
@ -241,13 +276,17 @@ def move_to_buffer(
|
||||
order = np.argsort(experts, kind="stable")
|
||||
experts = experts[order]
|
||||
srcs = srcs[order]
|
||||
|
||||
send_map, recv_map = get_ep_ranks_with_experts_batch(
|
||||
experts,
|
||||
num_local_experts,
|
||||
old_indices,
|
||||
layer_new_indices,
|
||||
)
|
||||
|
||||
for expert, src in zip(experts.tolist(), srcs.tolist()):
|
||||
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
|
||||
expert,
|
||||
num_local_experts,
|
||||
old_indices,
|
||||
layer_new_indices,
|
||||
)
|
||||
ranks_to_send = send_map[expert]
|
||||
ranks_to_recv = recv_map[expert]
|
||||
if not ranks_to_send or not ranks_to_recv:
|
||||
continue
|
||||
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
|
||||
@ -260,7 +299,7 @@ def move_to_buffer(
|
||||
if recver_pos < len(ranks_to_recv):
|
||||
recv_ranks.append(ranks_to_recv[recver_pos])
|
||||
for dst in recv_ranks:
|
||||
dst_global = get_global_rank(ep_group, dst)
|
||||
dst_global = rank_to_global[dst]
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.isend,
|
||||
@ -283,13 +322,18 @@ def move_to_buffer(
|
||||
order = np.argsort(experts, kind="stable")
|
||||
experts = experts[order]
|
||||
dsts = dsts[order]
|
||||
|
||||
# Batch query all experts for this layer
|
||||
send_map, recv_map = get_ep_ranks_with_experts_batch(
|
||||
experts,
|
||||
num_local_experts,
|
||||
old_indices,
|
||||
layer_new_indices,
|
||||
)
|
||||
|
||||
for expert, dst in zip(experts.tolist(), dsts.tolist()):
|
||||
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
|
||||
expert,
|
||||
num_local_experts,
|
||||
old_indices,
|
||||
layer_new_indices,
|
||||
)
|
||||
ranks_to_send = send_map[expert]
|
||||
ranks_to_recv = recv_map[expert]
|
||||
if not ranks_to_send or not ranks_to_recv:
|
||||
continue
|
||||
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
|
||||
@ -299,7 +343,7 @@ def move_to_buffer(
|
||||
src = ranks_to_send[recver_pos // num_dst_per_sender]
|
||||
else:
|
||||
src = ranks_to_send[recver_pos - remainder_start]
|
||||
src_global = get_global_rank(ep_group, src)
|
||||
src_global = rank_to_global[src]
|
||||
p2p_ops += [
|
||||
P2POp(
|
||||
torch.distributed.irecv,
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user