Optimize after codex review

Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
ilmarkov 2025-11-26 17:31:13 +00:00
parent 30bab971c0
commit c6f14d1a27

View File

@ -7,7 +7,6 @@ This involves the exchange of expert weights between GPUs.
""" """
from collections.abc import Iterable, Sequence from collections.abc import Iterable, Sequence
from functools import partial
import numpy as np import numpy as np
import torch import torch
@ -123,6 +122,10 @@ def move_to_buffer(
is_unchanged = np.zeros((group_size, num_local_experts), dtype=np.bool_) is_unchanged = np.zeros((group_size, num_local_experts), dtype=np.bool_)
is_received_locally = np.zeros((group_size, num_local_experts), dtype=np.bool_) is_received_locally = np.zeros((group_size, num_local_experts), dtype=np.bool_)
recv_primary_mask = np.zeros((group_size, num_local_experts), dtype=np.bool_) recv_primary_mask = np.zeros((group_size, num_local_experts), dtype=np.bool_)
# Cache desired new expert ids per local row, for all layers
new_local_expert_ids_mat = np.full(
(group_size, num_local_experts), -1, dtype=np.int64
)
send_counts = np.zeros(group_size, dtype=np.int32) send_counts = np.zeros(group_size, dtype=np.int32)
send_expert_ids = np.full((group_size, num_local_experts), -1, dtype=np.int64) send_expert_ids = np.full((group_size, num_local_experts), -1, dtype=np.int64)
send_src_rows = np.full((group_size, num_local_experts), -1, dtype=np.int32) send_src_rows = np.full((group_size, num_local_experts), -1, dtype=np.int32)
@ -140,6 +143,7 @@ def move_to_buffer(
old_local_expert_ids = old_indices[local_global] old_local_expert_ids = old_indices[local_global]
new_local_expert_ids = layer_new_indices[local_global] new_local_expert_ids = layer_new_indices[local_global]
new_local_expert_ids_mat[layer_idx, :] = new_local_expert_ids
# Unchanged per-dst mask # Unchanged per-dst mask
unchanged_mask = old_local_expert_ids == new_local_expert_ids unchanged_mask = old_local_expert_ids == new_local_expert_ids
@ -187,34 +191,41 @@ def move_to_buffer(
else: else:
recv_counts[layer_idx] = 0 recv_counts[layer_idx] = 0
# Precompute per-layer destination mask that actually needs local buffering:
# need change, received locally, and valid target expert id
eligible_local_buffer_mask = np.logical_and(
np.logical_and(~is_unchanged, is_received_locally),
new_local_expert_ids_mat != -1,
)
# 1. Local moves into tmp buffers # 1. Local moves into tmp buffers
for layer_idx in range(group_size): for layer_idx in range(group_size):
layer_is_unchanged = is_unchanged[layer_idx, :]
layer_is_received_locally = is_received_locally[layer_idx, :]
layer_new_indices = new_indices_group[layer_idx]
layer_send_count = int(send_counts[layer_idx]) layer_send_count = int(send_counts[layer_idx])
if layer_send_count <= 0:
continue
layer_send_experts = send_expert_ids[layer_idx, :layer_send_count] layer_send_experts = send_expert_ids[layer_idx, :layer_send_count]
layer_send_srcs = send_src_rows[layer_idx, :layer_send_count] layer_send_srcs = send_src_rows[layer_idx, :layer_send_count]
local2global = partial(
idx_local_to_global,
local_cnt=num_local_experts,
ep_rank=ep_rank,
)
layer_weights_list = list(expert_weights_group[layer_idx]) layer_weights_list = list(expert_weights_group[layer_idx])
layer_buffers_list = list(buffers_group[layer_idx]) layer_buffers_list = list(buffers_group[layer_idx])
for dst in range(num_local_experts): new_local_expert_ids = new_local_expert_ids_mat[layer_idx, :]
if layer_is_unchanged[dst] or not layer_is_received_locally[dst]:
continue # Only consider destination rows that are eligible for local buffering
dst_global = local2global(dst) eligible_mask = eligible_local_buffer_mask[layer_idx, :]
expert = layer_new_indices[dst_global] if not bool(eligible_mask.any()):
if expert == -1: continue
continue
matches = np.nonzero(layer_send_experts == expert)[0] dest_indices = np.nonzero(eligible_mask)[0].tolist()
if matches.size == 0: # Build a map from expert_id to its source row.
continue expert_to_src_map = dict(zip(layer_send_experts, layer_send_srcs))
src_local = int(layer_send_srcs[matches[0]])
for w, b in zip(layer_weights_list, layer_buffers_list): for dst in dest_indices:
b[dst].copy_(w[src_local]) expert = new_local_expert_ids[dst]
src_local = expert_to_src_map.get(expert, -1)
if src_local != -1:
for w, b in zip(layer_weights_list, layer_buffers_list):
b[dst].copy_(w[src_local])
p2p_ops: list[P2POp] = [] p2p_ops: list[P2POp] = []
# 2. Post sends per layer # 2. Post sends per layer