mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-03-25 10:53:41 +08:00
Address review comments
Signed-off-by: ilmarkov <markovilya197@gmail.com>
This commit is contained in:
parent
a5ecdc18c0
commit
040ae89c5e
@ -93,7 +93,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
||||
|
||||
Returns:
|
||||
phy2log: [X, num_phy], logical expert id of each physical expert
|
||||
rank: [X, num_phy], the replica rank
|
||||
replica_idx: [X, num_phy], the replica index of each physical expert
|
||||
logcnt: [X, num_log], number of replicas for each logical expert
|
||||
"""
|
||||
n, num_log = weight.shape
|
||||
@ -101,15 +101,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
||||
assert num_redundant >= 0
|
||||
device = weight.device
|
||||
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
|
||||
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
|
||||
replica_idx = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
|
||||
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
|
||||
arangen = torch.arange(n, dtype=torch.int64, device=device)
|
||||
for i in range(num_log, num_phy):
|
||||
redundant_indices = (weight / logcnt).max(dim=-1).indices
|
||||
phy2log[:, i] = redundant_indices
|
||||
rank[:, i] = logcnt[arangen, redundant_indices]
|
||||
replica_idx[:, i] = logcnt[arangen, redundant_indices]
|
||||
logcnt[arangen, redundant_indices] += 1
|
||||
return phy2log, rank, logcnt
|
||||
return phy2log, replica_idx, logcnt
|
||||
|
||||
@classmethod
|
||||
def rebalance_experts_hierarchical(
|
||||
@ -132,7 +132,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
||||
Returns:
|
||||
phy2log: [layers, num_replicas], the expert
|
||||
index of each replica
|
||||
log2phy: [layers, num_logical_experts, X],
|
||||
pphy_replicas_idx: [layers, num_replicas],
|
||||
the replica indices for each expert
|
||||
logcnt: [layers, num_logical_experts], number of
|
||||
physical replicas for each logical expert
|
||||
@ -177,7 +177,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
||||
tokens_per_mlog = weight.gather(-1, mlog2log).view(
|
||||
-1, num_logical_experts // num_nodes
|
||||
)
|
||||
phy2mlog, phyrank, mlogcnt = cls.replicate_experts(
|
||||
phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts(
|
||||
tokens_per_mlog, num_physical_experts // num_nodes
|
||||
)
|
||||
|
||||
@ -203,15 +203,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
||||
).view(1, -1, 1)
|
||||
).flatten(-2)
|
||||
pphy2log = mlog2log.gather(-1, pphy2mlog)
|
||||
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
|
||||
pphy_replicas_idx = replicas_idx.gather(-1, pphy2phy).view(num_layers, -1)
|
||||
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
|
||||
return pphy2log, pphyrank, logcnt
|
||||
return pphy2log, pphy_replicas_idx, logcnt
|
||||
|
||||
@classmethod
|
||||
def preserve_intragpu_slots(
|
||||
cls,
|
||||
phy2log: torch.Tensor,
|
||||
phyrank: torch.Tensor,
|
||||
phy_replicas_idx: torch.Tensor,
|
||||
num_ranks: int,
|
||||
old_global_expert_indices: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
@ -223,56 +223,52 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
||||
the old and new mappings.
|
||||
"""
|
||||
device = phy2log.device
|
||||
new_num_phy = phy2log.shape[1]
|
||||
old_num_phy = old_global_expert_indices.shape[1]
|
||||
if (
|
||||
num_ranks <= 0
|
||||
or new_num_phy % num_ranks != 0
|
||||
or old_num_phy % num_ranks != 0
|
||||
or (new_num_phy // num_ranks) != (old_num_phy // num_ranks)
|
||||
):
|
||||
return phy2log, phyrank
|
||||
num_phy_experts = phy2log.shape[1]
|
||||
if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
|
||||
return phy2log, phy_replicas_idx
|
||||
|
||||
# Move to CPU and convert to NumPy for processing
|
||||
phy2log_np = phy2log.cpu().numpy()
|
||||
phyrank_np = phyrank.cpu().numpy()
|
||||
old_np = old_global_expert_indices.cpu().numpy()
|
||||
new_phy2log_np = phy2log.cpu().numpy()
|
||||
replicas_idx_np = phy_replicas_idx.cpu().numpy()
|
||||
old_phy2log_np = old_global_expert_indices.cpu().numpy()
|
||||
|
||||
slots_per_gpu = new_num_phy // num_ranks
|
||||
num_layers = phy2log_np.shape[0]
|
||||
slots_per_gpu = num_phy_experts // num_ranks
|
||||
num_layers = new_phy2log_np.shape[0]
|
||||
|
||||
post_phy2log_np = phy2log_np.copy()
|
||||
post_phyrank_np = phyrank_np.copy()
|
||||
post_phy2log_np = new_phy2log_np.copy()
|
||||
post_phy_replicas_idx_np = replicas_idx_np.copy()
|
||||
|
||||
for gpu_idx in range(num_ranks):
|
||||
start = gpu_idx * slots_per_gpu
|
||||
end = start + slots_per_gpu
|
||||
# Segments across all layers for this GPU
|
||||
old_seg = old_np[:, start:end] # [L, S]
|
||||
new_seg = phy2log_np[:, start:end] # [L, S]
|
||||
new_rnk = phyrank_np[:, start:end] # [L, S]
|
||||
# Experts across all layers for this GPU
|
||||
old_local = old_phy2log_np[:, start:end] # [layers, slots]
|
||||
new_local = new_phy2log_np[:, start:end] # [layers, slots]
|
||||
new_ridx = replicas_idx_np[:, start:end] # [layers, slots]
|
||||
|
||||
used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
|
||||
preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)
|
||||
|
||||
# First pass: preserve same-logical experts in their previous slots
|
||||
for pos in range(slots_per_gpu):
|
||||
# matches: [L, S], True where new_seg has the same logical value
|
||||
# as the old slot 'pos' and not used
|
||||
matches = (new_seg == old_seg[:, pos][:, None]) & (~used_new_indices)
|
||||
for slot_idx in range(slots_per_gpu):
|
||||
# matches: [layers, slots], True where new local experts have
|
||||
# the same logical value as the old expert at 'slot_idx' and are not used yet
|
||||
matches = (new_local == old_local[:, slot_idx][:, None]) & (
|
||||
~used_new_indices
|
||||
)
|
||||
has_any = matches.any(axis=1)
|
||||
if np.any(has_any):
|
||||
first_idx = np.argmax(matches, axis=1)
|
||||
layer_indices = np.nonzero(has_any)[0]
|
||||
matched_new_positions = first_idx[layer_indices]
|
||||
post_phy2log_np[layer_indices, start + pos] = new_seg[
|
||||
layer_indices, matched_new_positions
|
||||
]
|
||||
post_phyrank_np[layer_indices, start + pos] = new_rnk[
|
||||
post_phy2log_np[layer_indices, start + slot_idx] = new_local[
|
||||
layer_indices, matched_new_positions
|
||||
]
|
||||
post_phy_replicas_idx_np[layer_indices, start + slot_idx] = (
|
||||
new_ridx[layer_indices, matched_new_positions]
|
||||
)
|
||||
used_new_indices[layer_indices, matched_new_positions] = True
|
||||
preserved_positions[layer_indices, pos] = True
|
||||
preserved_positions[layer_indices, slot_idx] = True
|
||||
|
||||
# Second pass: fill remaining slots with remaining new experts
|
||||
remaining_mask = ~used_new_indices # [L, S]
|
||||
@ -299,17 +295,17 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
||||
continue
|
||||
src_pos = remaining_indices[layer_idx, :k]
|
||||
dst_pos = fill_indices[layer_idx, :k]
|
||||
post_phy2log_np[layer_idx, start + dst_pos] = new_seg[
|
||||
post_phy2log_np[layer_idx, start + dst_pos] = new_local[
|
||||
layer_idx, src_pos
|
||||
]
|
||||
post_phyrank_np[layer_idx, start + dst_pos] = new_rnk[
|
||||
post_phy_replicas_idx_np[layer_idx, start + dst_pos] = new_ridx[
|
||||
layer_idx, src_pos
|
||||
]
|
||||
|
||||
# Convert back to torch and move to original device
|
||||
post_phy2log = torch.from_numpy(post_phy2log_np).to(device)
|
||||
post_phyrank = torch.from_numpy(post_phyrank_np).to(device)
|
||||
return post_phy2log, post_phyrank
|
||||
post_phy_replicas_idx = torch.from_numpy(post_phy_replicas_idx_np).to(device)
|
||||
return post_phy2log, post_phy_replicas_idx
|
||||
|
||||
@classmethod
|
||||
def rebalance_experts(
|
||||
@ -348,12 +344,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
||||
weight = weight.float()
|
||||
if num_groups % num_nodes == 0:
|
||||
# use hierarchical load-balance policy
|
||||
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
|
||||
phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical(
|
||||
weight, num_replicas, num_groups, num_nodes, num_ranks
|
||||
)
|
||||
else:
|
||||
# use global load-balance policy
|
||||
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
|
||||
phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical(
|
||||
weight, num_replicas, 1, 1, num_ranks
|
||||
)
|
||||
# Optional postprocessing to preserve slots for experts moving
|
||||
@ -362,8 +358,8 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
||||
# Helps to avoid unnecessary weight copying when experts move
|
||||
# within the same GPU.
|
||||
if old_global_expert_indices is not None:
|
||||
phy2log, phyrank = cls.preserve_intragpu_slots(
|
||||
phy2log, phyrank, num_ranks, old_global_expert_indices
|
||||
phy2log, phy_replicas_idx = cls.preserve_intragpu_slots(
|
||||
phy2log, phy_replicas_idx, num_ranks, old_global_expert_indices
|
||||
)
|
||||
num_redundant_experts = num_replicas - num_logical_experts
|
||||
maxlogcnt = num_redundant_experts + 1
|
||||
@ -375,7 +371,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
||||
)
|
||||
log2phy.view(num_layers, -1).scatter_(
|
||||
-1,
|
||||
phy2log * maxlogcnt + phyrank,
|
||||
phy2log * maxlogcnt + phy_replicas_idx,
|
||||
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
|
||||
num_layers, -1
|
||||
),
|
||||
|
||||
@ -361,24 +361,23 @@ def move_from_buffer(
|
||||
is_received_locally: np.ndarray,
|
||||
recv_metadata: RecvMetadata,
|
||||
new_indices: np.ndarray,
|
||||
ep_group: ProcessGroup,
|
||||
ep_rank: int,
|
||||
) -> None:
|
||||
"""
|
||||
Copies expert weights from communication buffers back to the target weight tensors
|
||||
after EPLB rebalancing.
|
||||
|
||||
Args:
|
||||
expert_weights: List of weight tensors for a MoE layer.
|
||||
expert_weights_buffers: Intermediate buffers matching ``expert_weights``.
|
||||
expert_weights: List of the actual MoE layer weights used in the execution.
|
||||
expert_weights_buffers: Intermediate buffers containing the expert weights
|
||||
after the transfer is completed.
|
||||
is_unchanged: (num_local_experts,), True where an expert row is unchanged.
|
||||
is_received_locally: (num_local_experts,), True where a row is updated locally.
|
||||
recv_metadata: RecvMetadata containing remote receive metadata.
|
||||
new_indices: (num_experts_total,) mapping from local rows to desired
|
||||
(possibly global) expert id, after rebalance.
|
||||
ep_group: torch.distributed.ProcessGroup for expert parallel communication
|
||||
domain.
|
||||
ep_rank: Rank of the process in the expert parallel group.
|
||||
"""
|
||||
ep_rank = ep_group.rank()
|
||||
recv_primary_mask = recv_metadata.recv_primary_mask
|
||||
recv_count = recv_metadata.recv_count
|
||||
recv_expert_ids = recv_metadata.recv_expert_ids
|
||||
@ -395,16 +394,17 @@ def move_from_buffer(
|
||||
for w, b in zip(expert_weights, expert_weights_buffers):
|
||||
w[dst].copy_(b[dst], non_blocking=True)
|
||||
|
||||
# Duplicate remote received rows to non-primary duplicate dsts
|
||||
if recv_count == 0:
|
||||
return
|
||||
|
||||
# Duplicate remote received rows to non-primary duplicate dsts
|
||||
base = ep_rank * num_local_experts
|
||||
local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)]
|
||||
duplicate_mask = np.logical_and(
|
||||
np.logical_and(~is_unchanged, ~is_received_locally),
|
||||
np.logical_and(~recv_primary_mask, local_experts != -1),
|
||||
)
|
||||
# All received experts are unique in the destination, so no need to copy duplicates
|
||||
if not bool(duplicate_mask.any()):
|
||||
return
|
||||
|
||||
@ -607,7 +607,7 @@ def rearrange_expert_weights_inplace(
|
||||
is_received_locally=is_received_locally,
|
||||
recv_metadata=recv_metadata,
|
||||
new_indices=new_global_expert_indices_cpu[layer_idx],
|
||||
ep_group=ep_group,
|
||||
ep_rank=ep_group.rank(),
|
||||
)
|
||||
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user