mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-04-05 19:47:23 +08:00
Optimize after codex review
Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
parent
30bab971c0
commit
c6f14d1a27
@ -7,7 +7,6 @@ This involves the exchange of expert weights between GPUs.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable, Sequence
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -123,6 +122,10 @@ def move_to_buffer(
|
||||
is_unchanged = np.zeros((group_size, num_local_experts), dtype=np.bool_)
|
||||
is_received_locally = np.zeros((group_size, num_local_experts), dtype=np.bool_)
|
||||
recv_primary_mask = np.zeros((group_size, num_local_experts), dtype=np.bool_)
|
||||
# Cache desired new expert ids per local row, for all layers
|
||||
new_local_expert_ids_mat = np.full(
|
||||
(group_size, num_local_experts), -1, dtype=np.int64
|
||||
)
|
||||
send_counts = np.zeros(group_size, dtype=np.int32)
|
||||
send_expert_ids = np.full((group_size, num_local_experts), -1, dtype=np.int64)
|
||||
send_src_rows = np.full((group_size, num_local_experts), -1, dtype=np.int32)
|
||||
@ -140,6 +143,7 @@ def move_to_buffer(
|
||||
|
||||
old_local_expert_ids = old_indices[local_global]
|
||||
new_local_expert_ids = layer_new_indices[local_global]
|
||||
new_local_expert_ids_mat[layer_idx, :] = new_local_expert_ids
|
||||
|
||||
# Unchanged per-dst mask
|
||||
unchanged_mask = old_local_expert_ids == new_local_expert_ids
|
||||
@ -187,34 +191,41 @@ def move_to_buffer(
|
||||
else:
|
||||
recv_counts[layer_idx] = 0
|
||||
|
||||
# Precompute per-layer destination mask that actually needs local buffering:
|
||||
# need change, received locally, and valid target expert id
|
||||
eligible_local_buffer_mask = np.logical_and(
|
||||
np.logical_and(~is_unchanged, is_received_locally),
|
||||
new_local_expert_ids_mat != -1,
|
||||
)
|
||||
|
||||
# 1. Local moves into tmp buffers
|
||||
for layer_idx in range(group_size):
|
||||
layer_is_unchanged = is_unchanged[layer_idx, :]
|
||||
layer_is_received_locally = is_received_locally[layer_idx, :]
|
||||
layer_new_indices = new_indices_group[layer_idx]
|
||||
layer_send_count = int(send_counts[layer_idx])
|
||||
if layer_send_count <= 0:
|
||||
continue
|
||||
|
||||
layer_send_experts = send_expert_ids[layer_idx, :layer_send_count]
|
||||
layer_send_srcs = send_src_rows[layer_idx, :layer_send_count]
|
||||
local2global = partial(
|
||||
idx_local_to_global,
|
||||
local_cnt=num_local_experts,
|
||||
ep_rank=ep_rank,
|
||||
)
|
||||
layer_weights_list = list(expert_weights_group[layer_idx])
|
||||
layer_buffers_list = list(buffers_group[layer_idx])
|
||||
for dst in range(num_local_experts):
|
||||
if layer_is_unchanged[dst] or not layer_is_received_locally[dst]:
|
||||
continue
|
||||
dst_global = local2global(dst)
|
||||
expert = layer_new_indices[dst_global]
|
||||
if expert == -1:
|
||||
continue
|
||||
matches = np.nonzero(layer_send_experts == expert)[0]
|
||||
if matches.size == 0:
|
||||
continue
|
||||
src_local = int(layer_send_srcs[matches[0]])
|
||||
for w, b in zip(layer_weights_list, layer_buffers_list):
|
||||
b[dst].copy_(w[src_local])
|
||||
new_local_expert_ids = new_local_expert_ids_mat[layer_idx, :]
|
||||
|
||||
# Only consider destination rows that are eligible for local buffering
|
||||
eligible_mask = eligible_local_buffer_mask[layer_idx, :]
|
||||
if not bool(eligible_mask.any()):
|
||||
continue
|
||||
|
||||
dest_indices = np.nonzero(eligible_mask)[0].tolist()
|
||||
# Build a map from expert_id to its source row.
|
||||
expert_to_src_map = dict(zip(layer_send_experts, layer_send_srcs))
|
||||
|
||||
for dst in dest_indices:
|
||||
expert = new_local_expert_ids[dst]
|
||||
src_local = expert_to_src_map.get(expert, -1)
|
||||
if src_local != -1:
|
||||
for w, b in zip(layer_weights_list, layer_buffers_list):
|
||||
b[dst].copy_(w[src_local])
|
||||
|
||||
p2p_ops: list[P2POp] = []
|
||||
|
||||
# 2. Post sends per layer
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user