commit 040ae89c5e (parent a5ecdc18c0)

    Address review comments

    Signed-off-by: ilmarkov <markovilya197@gmail.com>
@@ -93,7 +93,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):

         Returns:
             phy2log: [X, num_phy], logical expert id of each physical expert
-            rank: [X, num_phy], the replica rank
+            replica_idx: [X, num_phy], the index of the replica for each logical expert
             logcnt: [X, num_log], number of replicas for each logical expert
         """
         n, num_log = weight.shape
@@ -101,15 +101,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         assert num_redundant >= 0
         device = weight.device
         phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
-        rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
+        replica_idx = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
         logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
         arangen = torch.arange(n, dtype=torch.int64, device=device)
         for i in range(num_log, num_phy):
             redundant_indices = (weight / logcnt).max(dim=-1).indices
             phy2log[:, i] = redundant_indices
-            rank[:, i] = logcnt[arangen, redundant_indices]
+            replica_idx[:, i] = logcnt[arangen, redundant_indices]
             logcnt[arangen, redundant_indices] += 1
-        return phy2log, rank, logcnt
+        return phy2log, replica_idx, logcnt

     @classmethod
     def rebalance_experts_hierarchical(
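The loop above is a greedy replication heuristic: each extra physical slot goes to the logical expert with the highest remaining load per replica. A minimal standalone sketch under that reading (toy weights; `replicate_greedy` is a hypothetical name, not vLLM API):

```python
import torch

def replicate_greedy(weight: torch.Tensor, num_phy: int):
    """Toy restatement of the greedy replication loop in the hunk above."""
    n, num_log = weight.shape
    phy2log = torch.arange(num_phy, dtype=torch.int64).repeat(n, 1)
    replica_idx = torch.zeros(n, num_phy, dtype=torch.int64)
    logcnt = torch.ones(n, num_log, dtype=torch.int64)
    arangen = torch.arange(n, dtype=torch.int64)
    for i in range(num_log, num_phy):
        # The expert whose per-replica load is currently highest gets slot i.
        redundant = (weight / logcnt).max(dim=-1).indices
        phy2log[:, i] = redundant
        replica_idx[:, i] = logcnt[arangen, redundant]
        logcnt[arangen, redundant] += 1
    return phy2log, replica_idx, logcnt

# One layer, loads [9, 3, 3], 5 physical slots: both extra slots go to expert 0.
phy2log, replica_idx, logcnt = replicate_greedy(torch.tensor([[9.0, 3.0, 3.0]]), 5)
# phy2log     -> [[0, 1, 2, 0, 0]]
# replica_idx -> [[0, 0, 0, 1, 2]]
# logcnt      -> [[3, 1, 1]]
```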
@@ -132,7 +132,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         Returns:
             phy2log: [layers, num_replicas], the expert
                 index of each replica
-            log2phy: [layers, num_logical_experts, X],
+            pphy_replicas_idx: [layers, num_replicas],
                 the replica indices for each expert
             logcnt: [layers, num_logical_experts], number of
                 physical replicas for each logical expert
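To make the renamed return value concrete, a toy reading of the three tensors for one layer with 2 logical experts and 3 replicas (illustrative values only):

```python
# Hypothetical single layer, 2 logical experts, 3 physical replicas:
phy2log = [[0, 1, 0]]            # logical expert served by each physical slot
pphy_replicas_idx = [[0, 0, 1]]  # which replica of its expert each slot holds
logcnt = [[2, 1]]                # replica count per logical expert
```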
@@ -177,7 +177,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         tokens_per_mlog = weight.gather(-1, mlog2log).view(
             -1, num_logical_experts // num_nodes
         )
-        phy2mlog, phyrank, mlogcnt = cls.replicate_experts(
+        phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts(
             tokens_per_mlog, num_physical_experts // num_nodes
         )

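The gather-then-view idiom above regroups per-expert loads so that each (layer, node) pair becomes one row before balancing within nodes. A hedged toy illustration (the permutation in `mlog2log` is invented for the example):

```python
import torch

# 1 layer, 4 logical experts, 2 nodes; node 0 owns experts [0, 2], node 1 owns [1, 3].
weight = torch.tensor([[10.0, 1.0, 5.0, 4.0]])
mlog2log = torch.tensor([[0, 2, 1, 3]])  # node-major "meta-logical" -> logical ids

tokens_per_mlog = weight.gather(-1, mlog2log).view(-1, 4 // 2)
# -> tensor([[10., 5.],
#            [ 1., 4.]])   one row of loads per (layer, node)
```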
@@ -203,15 +203,15 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
             ).view(1, -1, 1)
         ).flatten(-2)
         pphy2log = mlog2log.gather(-1, pphy2mlog)
-        pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
+        pphy_replicas_idx = replicas_idx.gather(-1, pphy2phy).view(num_layers, -1)
         logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
-        return pphy2log, pphyrank, logcnt
+        return pphy2log, pphy_replicas_idx, logcnt

     @classmethod
     def preserve_intragpu_slots(
         cls,
         phy2log: torch.Tensor,
-        phyrank: torch.Tensor,
+        phy_replicas_idx: torch.Tensor,
         num_ranks: int,
         old_global_expert_indices: torch.Tensor,
     ) -> tuple[torch.Tensor, torch.Tensor]:
@@ -223,56 +223,52 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         the old and new mappings.
         """
         device = phy2log.device
-        new_num_phy = phy2log.shape[1]
-        old_num_phy = old_global_expert_indices.shape[1]
-        if (
-            num_ranks <= 0
-            or new_num_phy % num_ranks != 0
-            or old_num_phy % num_ranks != 0
-            or (new_num_phy // num_ranks) != (old_num_phy // num_ranks)
-        ):
-            return phy2log, phyrank
+        num_phy_experts = phy2log.shape[1]
+        if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
+            return phy2log, phy_replicas_idx

         # Move to CPU and convert to NumPy for processing
-        phy2log_np = phy2log.cpu().numpy()
-        phyrank_np = phyrank.cpu().numpy()
-        old_np = old_global_expert_indices.cpu().numpy()
+        new_phy2log_np = phy2log.cpu().numpy()
+        replicas_idx_np = phy_replicas_idx.cpu().numpy()
+        old_phy2log_np = old_global_expert_indices.cpu().numpy()

-        slots_per_gpu = new_num_phy // num_ranks
-        num_layers = phy2log_np.shape[0]
+        slots_per_gpu = num_phy_experts // num_ranks
+        num_layers = new_phy2log_np.shape[0]

-        post_phy2log_np = phy2log_np.copy()
-        post_phyrank_np = phyrank_np.copy()
+        post_phy2log_np = new_phy2log_np.copy()
+        post_phy_replicas_idx_np = replicas_idx_np.copy()

         for gpu_idx in range(num_ranks):
             start = gpu_idx * slots_per_gpu
             end = start + slots_per_gpu
-            # Segments across all layers for this GPU
-            old_seg = old_np[:, start:end]  # [L, S]
-            new_seg = phy2log_np[:, start:end]  # [L, S]
-            new_rnk = phyrank_np[:, start:end]  # [L, S]
+            # Experts across all layers for this GPU
+            old_local = old_phy2log_np[:, start:end]  # [layers, slots]
+            new_local = new_phy2log_np[:, start:end]  # [layers, slots]
+            new_ridx = replicas_idx_np[:, start:end]  # [layers, slots]

             used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
             preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)

             # First pass: preserve same-logical experts in their previous slots
-            for pos in range(slots_per_gpu):
-                # matches: [L, S], True where new_seg has the same logical value
-                # as the old slot 'pos' and not used
-                matches = (new_seg == old_seg[:, pos][:, None]) & (~used_new_indices)
+            for slot_idx in range(slots_per_gpu):
+                # matches: [layers, slots], True where new local experts have
+                # the same logical value as the old from 'slot_idx' and not checked yet
+                matches = (new_local == old_local[:, slot_idx][:, None]) & (
+                    ~used_new_indices
+                )
                 has_any = matches.any(axis=1)
                 if np.any(has_any):
                     first_idx = np.argmax(matches, axis=1)
                     layer_indices = np.nonzero(has_any)[0]
                     matched_new_positions = first_idx[layer_indices]
-                    post_phy2log_np[layer_indices, start + pos] = new_seg[
-                        layer_indices, matched_new_positions
-                    ]
-                    post_phyrank_np[layer_indices, start + pos] = new_rnk[
+                    post_phy2log_np[layer_indices, start + slot_idx] = new_local[
                         layer_indices, matched_new_positions
                     ]
+                    post_phy_replicas_idx_np[layer_indices, start + slot_idx] = (
+                        new_ridx[layer_indices, matched_new_positions]
+                    )
                     used_new_indices[layer_indices, matched_new_positions] = True
-                    preserved_positions[layer_indices, pos] = True
+                    preserved_positions[layer_indices, slot_idx] = True

             # Second pass: fill remaining slots with remaining new experts
             remaining_mask = ~used_new_indices  # [L, S]
@@ -299,17 +295,17 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
                     continue
                 src_pos = remaining_indices[layer_idx, :k]
                 dst_pos = fill_indices[layer_idx, :k]
-                post_phy2log_np[layer_idx, start + dst_pos] = new_seg[
+                post_phy2log_np[layer_idx, start + dst_pos] = new_local[
                     layer_idx, src_pos
                 ]
-                post_phyrank_np[layer_idx, start + dst_pos] = new_rnk[
+                post_phy_replicas_idx_np[layer_idx, start + dst_pos] = new_ridx[
                     layer_idx, src_pos
                 ]

         # Convert back to torch and move to original device
         post_phy2log = torch.from_numpy(post_phy2log_np).to(device)
-        post_phyrank = torch.from_numpy(post_phyrank_np).to(device)
-        return post_phy2log, post_phyrank
+        post_phy_replicas_idx = torch.from_numpy(post_phy_replicas_idx_np).to(device)
+        return post_phy2log, post_phy_replicas_idx

     @classmethod
     def rebalance_experts(
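As a sanity check on the two-pass scheme, here is a single-GPU, single-layer walkthrough with toy ids (plain Python, not the vectorized vLLM code): pass one pins experts that stay on the GPU to their old slots, pass two fills the holes.

```python
import numpy as np

old_local = np.array([3, 1, 2, 1])  # experts previously held by this GPU's slots
new_local = np.array([2, 1, 1, 5])  # experts assigned to it by the rebalance

out = new_local.copy()
used = np.zeros(4, dtype=bool)    # entries of new_local already placed
filled = np.zeros(4, dtype=bool)  # slots pinned by the first pass

# First pass: keep an expert in its old slot if it also appears in new_local.
for slot, expert in enumerate(old_local):
    cand = np.nonzero((new_local == expert) & ~used)[0]
    if cand.size:
        out[slot] = expert
        used[cand[0]] = True
        filled[slot] = True

# Second pass: place the unmatched new entries into the remaining slots.
out[~filled] = new_local[~used]
print(out)  # [5 1 2 1]: experts 1, 2, 1 keep their old slots; newcomer 5 takes slot 0
```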
@@ -348,12 +344,12 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         weight = weight.float()
         if num_groups % num_nodes == 0:
             # use hierarchical load-balance policy
-            phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
+            phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical(
                 weight, num_replicas, num_groups, num_nodes, num_ranks
             )
         else:
             # use global load-balance policy
-            phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
+            phy2log, phy_replicas_idx, logcnt = cls.rebalance_experts_hierarchical(
                 weight, num_replicas, 1, 1, num_ranks
             )
         # Optional postprocessing to preserve slots for experts moving
@@ -362,8 +358,8 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         # Helps to avoid unnecessary weight copying when experts move
         # within the same GPU.
         if old_global_expert_indices is not None:
-            phy2log, phyrank = cls.preserve_intragpu_slots(
-                phy2log, phyrank, num_ranks, old_global_expert_indices
+            phy2log, phy_replicas_idx = cls.preserve_intragpu_slots(
+                phy2log, phy_replicas_idx, num_ranks, old_global_expert_indices
             )
         num_redundant_experts = num_replicas - num_logical_experts
         maxlogcnt = num_redundant_experts + 1
@@ -375,7 +371,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
         )
         log2phy.view(num_layers, -1).scatter_(
             -1,
-            phy2log * maxlogcnt + phyrank,
+            phy2log * maxlogcnt + phy_replicas_idx,
             torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
                 num_layers, -1
             ),
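The scatter above packs (logical expert, replica index) into the flat key `phy2log * maxlogcnt + phy_replicas_idx`, so every replica's physical slot is written in one call. A hedged miniature of the same construction:

```python
import torch

# 1 layer, 2 logical experts, 3 replicas, maxlogcnt = 2 (at most 1 redundant copy).
phy2log = torch.tensor([[0, 1, 0]])
phy_replicas_idx = torch.tensor([[0, 0, 1]])
num_layers, num_replicas, maxlogcnt = 1, 3, 2

log2phy = torch.full((num_layers, 2, maxlogcnt), -1, dtype=torch.int64)
log2phy.view(num_layers, -1).scatter_(
    -1,
    phy2log * maxlogcnt + phy_replicas_idx,  # flat key: expert * maxlogcnt + replica
    torch.arange(num_replicas, dtype=torch.int64).expand(num_layers, -1),
)
# log2phy -> [[[0, 2], [1, -1]]]: expert 0 sits in physical slots 0 and 2,
# expert 1 in slot 1; -1 marks unused replica entries.
```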
@@ -361,24 +361,23 @@ def move_from_buffer(
     is_received_locally: np.ndarray,
     recv_metadata: RecvMetadata,
     new_indices: np.ndarray,
-    ep_group: ProcessGroup,
+    ep_rank: int,
 ) -> None:
     """
     Copies expert weights from communication buffers back to the target weight tensors
     after EPLB rebalancing.

     Args:
-        expert_weights: List of weight tensors for a MoE layer.
-        expert_weights_buffers: Intermediate buffers matching ``expert_weights``.
+        expert_weights: List of the actual MoE layer weights used in execution.
+        expert_weights_buffers: Intermediate buffers containing the expert weights
+            after the transfer is completed.
         is_unchanged: (num_local_experts,), True where an expert row is unchanged.
         is_received_locally: (num_local_experts,), True where a row is updated locally.
         recv_metadata: RecvMetadata containing remote receive metadata.
         new_indices: (num_experts_total,) mapping from local rows to desired
             (possibly global) expert id, after rebalance.
-        ep_group: torch.distributed.ProcessGroup for expert parallel communication
-            domain.
+        ep_rank: Rank of the process in the expert parallel group.
     """
-    ep_rank = ep_group.rank()
     recv_primary_mask = recv_metadata.recv_primary_mask
     recv_count = recv_metadata.recv_count
     recv_expert_ids = recv_metadata.recv_expert_ids
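The hunk only names three fields of `RecvMetadata`; its definition lives elsewhere in the codebase. A hypothetical sketch, purely to make the types used below concrete (field meanings are inferred from their use here, so treat this as an assumption):

```python
from dataclasses import dataclass
import numpy as np

@dataclass
class RecvMetadata:  # hypothetical shape; the real class is defined elsewhere in vLLM
    recv_primary_mask: np.ndarray  # (num_local_experts,) True where a slot is the
                                   # primary destination of a remotely received row
    recv_count: int                # number of rows received from remote ranks
    recv_expert_ids: np.ndarray    # expert ids of the received rows
```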
@@ -395,16 +394,17 @@ def move_from_buffer(
     for w, b in zip(expert_weights, expert_weights_buffers):
         w[dst].copy_(b[dst], non_blocking=True)

-    # Duplicate remote received rows to non-primary duplicate dsts
     if recv_count == 0:
         return

+    # Duplicate remote received rows to non-primary duplicate dsts
     base = ep_rank * num_local_experts
     local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)]
     duplicate_mask = np.logical_and(
         np.logical_and(~is_unchanged, ~is_received_locally),
         np.logical_and(~recv_primary_mask, local_experts != -1),
     )
+    # All received experts are unique in the destination, so no need to copy duplicates
     if not bool(duplicate_mask.any()):
         return

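The duplicate mask combines four conditions; a toy numpy restatement (illustrative values) shows which slots still need a local copy of a received row:

```python
import numpy as np

is_unchanged        = np.array([True,  False, False, False])
is_received_locally = np.array([False, True,  False, False])
recv_primary_mask   = np.array([False, False, True,  False])
local_experts       = np.array([7,     4,     4,     4])  # -1 marks an empty slot

# A slot needs a duplicated copy only if it changed, was not filled locally,
# is not the primary landing slot of a remote receive, and is actually in use.
duplicate_mask = (~is_unchanged) & (~is_received_locally) & \
                 (~recv_primary_mask) & (local_experts != -1)
print(duplicate_mask)  # [False False False  True]: only slot 3 needs a copy
```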
@@ -607,7 +607,7 @@ def rearrange_expert_weights_inplace(
         is_received_locally=is_received_locally,
         recv_metadata=recv_metadata,
         new_indices=new_global_expert_indices_cpu[layer_idx],
-        ep_group=ep_group,
+        ep_rank=ep_group.rank(),
     )
