mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2026-05-30 21:07:09 +08:00
Merge 5cd45c646f53923fc5b5cc046aa044c1ce94aa08 into 254f6b986720c92ddf97fbb1a6a6465da8e87e29
This commit is contained in:
commit
74ee657332
@ -1,6 +1,7 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
@ -310,3 +311,143 @@ if __name__ == "__main__":
|
|||||||
print(phy2log)
|
print(phy2log)
|
||||||
|
|
||||||
test_basic_rebalance()
|
test_basic_rebalance()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_phy_replicas_idx_from_phy2log(phy2log: np.ndarray) -> np.ndarray:
|
||||||
|
"""Create replicas indices mapping from phy2log."""
|
||||||
|
pr = np.zeros_like(phy2log, dtype=np.int64)
|
||||||
|
for layer in range(phy2log.shape[0]):
|
||||||
|
seen: dict[int, int] = {}
|
||||||
|
row = phy2log[layer].tolist()
|
||||||
|
for i, expert in enumerate(row):
|
||||||
|
r = seen.get(expert, 0)
|
||||||
|
pr[layer, i] = r
|
||||||
|
seen[expert] = r + 1
|
||||||
|
return pr
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_intragpu_rearrangement(
|
||||||
|
old_global_expert_indices: np.ndarray,
|
||||||
|
new_phy2log: np.ndarray,
|
||||||
|
new_phy_replicas_idx: np.ndarray,
|
||||||
|
post_phy2log: np.ndarray,
|
||||||
|
post_phy_replicas_idx: np.ndarray,
|
||||||
|
num_ranks: int,
|
||||||
|
slots_per_gpu: int,
|
||||||
|
):
|
||||||
|
# Per-GPU checks
|
||||||
|
for gpu_idx in range(num_ranks):
|
||||||
|
start = gpu_idx * slots_per_gpu
|
||||||
|
end = start + slots_per_gpu
|
||||||
|
old_seg = old_global_expert_indices[0, start:end]
|
||||||
|
new_seg = new_phy2log[0, start:end]
|
||||||
|
new_rnk = new_phy_replicas_idx[0, start:end]
|
||||||
|
post_seg = post_phy2log[0, start:end]
|
||||||
|
post_rnk = post_phy_replicas_idx[0, start:end]
|
||||||
|
|
||||||
|
# Pairwise equality for (expert, rank) pairs to ensure nothing is lost
|
||||||
|
def sorted_pairs(seg, rnk):
|
||||||
|
pairs = list(zip(seg.tolist(), rnk.tolist()))
|
||||||
|
pairs.sort()
|
||||||
|
return pairs
|
||||||
|
|
||||||
|
assert sorted_pairs(post_seg, post_rnk) == sorted_pairs(new_seg, new_rnk), (
|
||||||
|
f"Per-GPU pairs of (expert,rank) must match new mapping for GPU {gpu_idx}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# For experts that remain on the same GPU, the old slot is preserved
|
||||||
|
# for at least one occurrence; rank at that slot must be valid for that expert
|
||||||
|
old_list = old_seg.tolist()
|
||||||
|
new_list = new_seg.tolist()
|
||||||
|
post_list = post_seg.tolist()
|
||||||
|
remained = set(old_list) & set(new_list)
|
||||||
|
new_ranks_for_expert: dict[int, list[int]] = {}
|
||||||
|
for v, r in zip(new_list, new_rnk.tolist()):
|
||||||
|
new_ranks_for_expert.setdefault(v, []).append(r)
|
||||||
|
for expert in remained:
|
||||||
|
old_pos = old_list.index(expert)
|
||||||
|
assert post_list[old_pos] == expert, (
|
||||||
|
f"Expert {expert} on GPU {gpu_idx} should stay at old slot {old_pos}"
|
||||||
|
)
|
||||||
|
# Rank at preserved slot must be one of the ranks
|
||||||
|
# the expert has in new mapping
|
||||||
|
assert post_rnk.tolist()[old_pos] in new_ranks_for_expert[expert], (
|
||||||
|
f"Rank for expert {expert} at preserved slot on GPU {gpu_idx} "
|
||||||
|
"must come from new mapping"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"num_ranks, slots_per_gpu, old_phy2log, new_phy2log",
|
||||||
|
[
|
||||||
|
pytest.param(
|
||||||
|
# Setup: 2 GPUs, 4 slots each, 1 layer
|
||||||
|
# Old mapping: GPU0 -> [0,1,2,3], GPU1 -> [4,5,6,7]
|
||||||
|
# New mapping shuffles within GPU0 and brings 4,5 into GPU0.
|
||||||
|
# GPU0 new -> [1,5,0,4]; GPU1 new -> [6,2,7,3]
|
||||||
|
2,
|
||||||
|
4,
|
||||||
|
np.array([[0, 1, 2, 3, 4, 5, 6, 7]]),
|
||||||
|
np.array([[1, 5, 0, 4, 6, 2, 7, 3]]),
|
||||||
|
id="simple",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
# Setup: 2 GPUs, 5 slots each (total 10 physical experts), 1 layer
|
||||||
|
# Old mapping:
|
||||||
|
# GPU0 -> [0, 1, 0, 2, 3] (expert 0 duplicated)
|
||||||
|
# GPU1 -> [4, 5, 6, 1, 2]
|
||||||
|
# New mapping reorders within GPUs and moves some experts across GPUs,
|
||||||
|
# while still including duplicates:
|
||||||
|
# GPU0 new -> [0, 5, 4, 0, 1] (expert 0 duplicated, 4/5 incoming)
|
||||||
|
# GPU1 new -> [6, 2, 3, 2, 1] (expert 2 duplicated)
|
||||||
|
2,
|
||||||
|
5,
|
||||||
|
np.array([[0, 1, 0, 2, 3, 4, 5, 6, 1, 2]]),
|
||||||
|
np.array([[0, 5, 4, 0, 1, 6, 2, 3, 2, 1]]),
|
||||||
|
id="duplicates",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
# Setup: 3 GPUs, 4 slots each (total 12 physical experts), 1 layer
|
||||||
|
# Old mapping:
|
||||||
|
# GPU0 -> [0, 1, 2, 3]
|
||||||
|
# GPU1 -> [0, 1, 2, 3]
|
||||||
|
# GPU2 -> [0, 1, 2, 3]
|
||||||
|
# New mapping decides to use one expert on 2 GPUs and shuffles
|
||||||
|
# experts on the third GPU,
|
||||||
|
# GPU0 new -> [0, 0, 0, 0]
|
||||||
|
# GPU1 new -> [0, 0, 0, 0]
|
||||||
|
# GPU2 new -> [1, 2, 3, 0]
|
||||||
|
3,
|
||||||
|
4,
|
||||||
|
np.array([[0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]),
|
||||||
|
np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3, 0]]),
|
||||||
|
id="skewed_expert",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_preserve_intragpu_slots(
|
||||||
|
num_ranks: int,
|
||||||
|
slots_per_gpu: int,
|
||||||
|
old_phy2log: torch.Tensor,
|
||||||
|
new_phy2log: torch.Tensor,
|
||||||
|
):
|
||||||
|
"""Experts that stay on a GPU keep their old slots; incoming not lost."""
|
||||||
|
phy_replicas_idx = _make_phy_replicas_idx_from_phy2log(new_phy2log)
|
||||||
|
|
||||||
|
post_phy2log, post_phy_replicas_idx = DefaultEplbPolicy.preserve_intragpu_slots(
|
||||||
|
new_phy2log, phy_replicas_idx, num_ranks, old_phy2log
|
||||||
|
)
|
||||||
|
|
||||||
|
# Shapes preserved
|
||||||
|
assert post_phy2log.shape == new_phy2log.shape
|
||||||
|
assert post_phy_replicas_idx.shape == phy_replicas_idx.shape
|
||||||
|
|
||||||
|
_validate_intragpu_rearrangement(
|
||||||
|
old_phy2log,
|
||||||
|
new_phy2log,
|
||||||
|
phy_replicas_idx,
|
||||||
|
post_phy2log,
|
||||||
|
post_phy_replicas_idx,
|
||||||
|
num_ranks,
|
||||||
|
slots_per_gpu,
|
||||||
|
)
|
||||||
|
|||||||
@ -286,32 +286,32 @@ def _test_async_transfer_layer_without_mtp_worker(
|
|||||||
device,
|
device,
|
||||||
old_indices,
|
old_indices,
|
||||||
)
|
)
|
||||||
|
old_indices_cpu = old_indices.cpu()
|
||||||
|
new_indices_cpu = new_indices.cpu()
|
||||||
|
|
||||||
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
expert_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
||||||
cuda_stream = torch.cuda.Stream(device=device)
|
cuda_stream = torch.cuda.Stream(device=device)
|
||||||
|
|
||||||
for layer_idx in range(num_layers):
|
for layer_idx in range(num_layers):
|
||||||
is_unchanged, is_received_locally, experts_recv_loc = asyncio.run(
|
is_unchanged, is_received_locally, recv_metadata = asyncio.run(
|
||||||
transfer_layer(
|
transfer_layer(
|
||||||
old_global_expert_indices=old_indices,
|
old_layer_indices=old_indices_cpu[layer_idx],
|
||||||
new_global_expert_indices=new_indices,
|
new_layer_indices=new_indices_cpu[layer_idx],
|
||||||
expert_weights=expert_weights,
|
expert_weights=expert_weights[layer_idx],
|
||||||
expert_weights_buffer=expert_buffer,
|
expert_weights_buffer=expert_buffer,
|
||||||
ep_group=ep_group,
|
ep_group=ep_group,
|
||||||
layer=layer_idx,
|
|
||||||
cuda_stream=cuda_stream,
|
cuda_stream=cuda_stream,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
cuda_stream.synchronize()
|
cuda_stream.synchronize()
|
||||||
move_from_buffer(
|
move_from_buffer(
|
||||||
expert_weights=expert_weights[layer_idx],
|
expert_weights=expert_weights[layer_idx],
|
||||||
expert_weights_buffer=expert_buffer,
|
expert_weights_buffers=expert_buffer,
|
||||||
is_unchanged=is_unchanged,
|
is_unchanged=is_unchanged,
|
||||||
is_received_locally=is_received_locally,
|
is_received_locally=is_received_locally,
|
||||||
experts_recv_loc=experts_recv_loc,
|
recv_metadata=recv_metadata,
|
||||||
new_indices=new_indices[layer_idx].tolist(),
|
new_indices=new_indices_cpu[layer_idx],
|
||||||
ep_group=ep_group,
|
ep_rank=ep_group.rank(),
|
||||||
)
|
)
|
||||||
|
|
||||||
verify_expert_weights_after_shuffle(
|
verify_expert_weights_after_shuffle(
|
||||||
|
|||||||
@ -69,6 +69,10 @@ class EPLBConfig:
|
|||||||
Log the balancedness each step of expert parallelism.
|
Log the balancedness each step of expert parallelism.
|
||||||
This is turned off by default since it will cause communication overhead.
|
This is turned off by default since it will cause communication overhead.
|
||||||
"""
|
"""
|
||||||
|
log_balancedness_interval: int = 1
|
||||||
|
"""
|
||||||
|
Interval for logging the balancedness.
|
||||||
|
"""
|
||||||
use_async: bool = False
|
use_async: bool = False
|
||||||
"""
|
"""
|
||||||
Whether to use non-blocking EPLB.
|
Whether to use non-blocking EPLB.
|
||||||
|
|||||||
@ -17,7 +17,7 @@ from vllm.logger import init_logger
|
|||||||
from .rebalance_execute import transfer_layer
|
from .rebalance_execute import transfer_layer
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .eplb_state import EplbState
|
from .eplb_state import EplbModelState, EplbState
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -57,6 +57,44 @@ def start_async_worker(
|
|||||||
return thread
|
return thread
|
||||||
|
|
||||||
|
|
||||||
|
def run_rebalance_experts(
|
||||||
|
model_state: "EplbModelState",
|
||||||
|
eplb_state: "EplbState",
|
||||||
|
) -> None:
|
||||||
|
assert model_state.eplb_stats is not None
|
||||||
|
eplb_stats = model_state.eplb_stats
|
||||||
|
# Move the global expert load window to CPU for computation.
|
||||||
|
# It has to be done in the main stream to avoid race condition
|
||||||
|
# with the main thread.
|
||||||
|
global_expert_load_window = eplb_stats.global_expert_load_window.cpu()
|
||||||
|
# Compute new expert mappings for the model
|
||||||
|
(
|
||||||
|
new_physical_to_logical_map,
|
||||||
|
new_logical_to_physical_map,
|
||||||
|
new_logical_replica_count,
|
||||||
|
) = eplb_state.policy.rebalance_experts(
|
||||||
|
global_expert_load_window,
|
||||||
|
eplb_stats.num_replicas,
|
||||||
|
eplb_stats.num_groups,
|
||||||
|
eplb_stats.num_nodes,
|
||||||
|
eplb_stats.num_gpus,
|
||||||
|
model_state.physical_to_logical_map,
|
||||||
|
)
|
||||||
|
assert new_physical_to_logical_map.device == torch.device("cpu")
|
||||||
|
|
||||||
|
model_state.new_physical_to_logical_map = new_physical_to_logical_map
|
||||||
|
|
||||||
|
max_slots = model_state.logical_to_physical_map.shape[-1]
|
||||||
|
padded_logical = torch.nn.functional.pad(
|
||||||
|
new_logical_to_physical_map,
|
||||||
|
(0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
|
||||||
|
value=-1,
|
||||||
|
).to(model_state.logical_to_physical_map.device)
|
||||||
|
new_replica = new_logical_replica_count.to(model_state.logical_replica_count.device)
|
||||||
|
model_state.new_logical_to_physical_map = padded_logical
|
||||||
|
model_state.new_logical_replica_count = new_replica
|
||||||
|
|
||||||
|
|
||||||
async def transfer_run_periodically(
|
async def transfer_run_periodically(
|
||||||
state: "EplbState",
|
state: "EplbState",
|
||||||
ep_group: ProcessGroup,
|
ep_group: ProcessGroup,
|
||||||
@ -71,33 +109,46 @@ async def transfer_run_periodically(
|
|||||||
for model_state in state.model_states.values():
|
for model_state in state.model_states.values():
|
||||||
if not model_state.is_async_enabled:
|
if not model_state.is_async_enabled:
|
||||||
continue
|
continue
|
||||||
|
rebalancing_algorithm_executed = False
|
||||||
|
logger.info(
|
||||||
|
"Async worker computed new indices for model %s",
|
||||||
|
model_state.model_name,
|
||||||
|
)
|
||||||
|
|
||||||
current_num_layers = model_state.model.num_moe_layers
|
current_num_layers = model_state.model.num_moe_layers
|
||||||
while (
|
while (
|
||||||
model_state.rebalanced
|
model_state.rebalanced
|
||||||
and model_state.layer_to_transfer < current_num_layers
|
and model_state.layer_to_transfer < current_num_layers
|
||||||
):
|
):
|
||||||
if (
|
if not model_state.ep_buffer_ready and model_state.rebalanced:
|
||||||
not model_state.ep_buffer_ready
|
|
||||||
and model_state.rebalanced
|
|
||||||
and model_state.new_physical_to_logical_map is not None
|
|
||||||
):
|
|
||||||
await asyncio.to_thread(model_state.buffer_lock.acquire)
|
await asyncio.to_thread(model_state.buffer_lock.acquire)
|
||||||
try:
|
try:
|
||||||
if model_state.layer_to_transfer >= current_num_layers:
|
if model_state.layer_to_transfer >= current_num_layers:
|
||||||
break
|
break
|
||||||
|
if not rebalancing_algorithm_executed:
|
||||||
|
run_rebalance_experts(model_state, state)
|
||||||
|
rebalancing_algorithm_executed = True
|
||||||
|
assert model_state.new_physical_to_logical_map is not None
|
||||||
|
|
||||||
|
layer_idx = model_state.layer_to_transfer
|
||||||
|
old_layer_indices = model_state.physical_to_logical_map[
|
||||||
|
layer_idx
|
||||||
|
]
|
||||||
|
new_layer_indices = model_state.new_physical_to_logical_map[
|
||||||
|
layer_idx
|
||||||
|
]
|
||||||
|
|
||||||
(
|
(
|
||||||
model_state.is_unchanged,
|
model_state.is_unchanged,
|
||||||
model_state.is_received_locally,
|
model_state.is_received_locally,
|
||||||
model_state.experts_recv_loc,
|
model_state.recv_metadata,
|
||||||
) = await transfer_layer(
|
) = await transfer_layer(
|
||||||
old_global_expert_indices=model_state.physical_to_logical_map,
|
old_layer_indices=old_layer_indices,
|
||||||
new_global_expert_indices=model_state.new_physical_to_logical_map,
|
new_layer_indices=new_layer_indices,
|
||||||
expert_weights=model_state.model.expert_weights,
|
expert_weights=model_state.model.expert_weights[layer_idx],
|
||||||
expert_weights_buffer=model_state.expert_buffer,
|
expert_weights_buffer=model_state.expert_buffer,
|
||||||
ep_group=ep_group,
|
ep_group=ep_group,
|
||||||
is_profile=is_profile,
|
is_profile=is_profile,
|
||||||
layer=model_state.layer_to_transfer,
|
|
||||||
cuda_stream=cuda_stream,
|
cuda_stream=cuda_stream,
|
||||||
rank_mapping=rank_mapping,
|
rank_mapping=rank_mapping,
|
||||||
)
|
)
|
||||||
|
|||||||
@ -27,10 +27,10 @@ physical experts.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.distributed import ProcessGroup, all_reduce
|
from torch.distributed import ProcessGroup, all_reduce
|
||||||
|
|
||||||
@ -46,11 +46,44 @@ from vllm.model_executor.models.interfaces import MixtureOfExperts
|
|||||||
|
|
||||||
from .async_worker import start_async_worker
|
from .async_worker import start_async_worker
|
||||||
from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy
|
from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy
|
||||||
from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace
|
from .rebalance_execute import (
|
||||||
|
RecvMetadata,
|
||||||
|
move_from_buffer,
|
||||||
|
rearrange_expert_weights_inplace,
|
||||||
|
)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EplbStats:
|
||||||
|
"""
|
||||||
|
Model stats used in EPLB rebalanding algorithm.
|
||||||
|
"""
|
||||||
|
|
||||||
|
global_expert_load_window: torch.Tensor
|
||||||
|
"""
|
||||||
|
Experts load window.
|
||||||
|
Shape: (window_size, num_moe_layers, num_physical_experts)
|
||||||
|
"""
|
||||||
|
num_replicas: int
|
||||||
|
"""
|
||||||
|
Number of physical experts.
|
||||||
|
"""
|
||||||
|
num_groups: int
|
||||||
|
"""
|
||||||
|
Number of expert groups.
|
||||||
|
"""
|
||||||
|
num_nodes: int
|
||||||
|
"""
|
||||||
|
Number of nodes.
|
||||||
|
"""
|
||||||
|
num_gpus: int
|
||||||
|
"""
|
||||||
|
Number of GPUs.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EplbModelState:
|
class EplbModelState:
|
||||||
"""EPLB metrics."""
|
"""EPLB metrics."""
|
||||||
@ -164,20 +197,23 @@ class EplbModelState:
|
|||||||
"""
|
"""
|
||||||
Whether the async EPLB needs to poll peers for buffer readiness.
|
Whether the async EPLB needs to poll peers for buffer readiness.
|
||||||
"""
|
"""
|
||||||
is_unchanged: list[bool]
|
eplb_stats: EplbStats | None
|
||||||
|
"""
|
||||||
|
EPLB stats for the model.
|
||||||
|
"""
|
||||||
|
is_unchanged: np.ndarray
|
||||||
"""
|
"""
|
||||||
intermediate variable between `move_to_buffer` and `move_to_workspace`.
|
intermediate variable between `move_to_buffer` and `move_to_workspace`.
|
||||||
The size is same as the num of physical experts in the current layer.
|
The size is same as the num of physical experts in the current layer.
|
||||||
"""
|
"""
|
||||||
is_received_locally: list[bool]
|
is_received_locally: np.ndarray
|
||||||
"""
|
"""
|
||||||
intermediate variable between `move_to_buffer` and `move_to_workspace`.
|
intermediate variable between `move_to_buffer` and `move_to_workspace`.
|
||||||
The size is same as the num of physical experts in the current layer.
|
The size is same as the num of physical experts in the current layer.
|
||||||
"""
|
"""
|
||||||
experts_recv_loc: dict[int, int]
|
recv_metadata: RecvMetadata
|
||||||
"""
|
"""
|
||||||
intermediate variable between `move_to_buffer` and `move_to_workspace`.
|
intermediate variable between `move_to_buffer` and `move_to_workspace`.
|
||||||
The size is same as the num of physical experts in the current layer.
|
|
||||||
"""
|
"""
|
||||||
is_async_enabled: bool
|
is_async_enabled: bool
|
||||||
"""
|
"""
|
||||||
@ -507,9 +543,15 @@ class EplbState:
|
|||||||
layer_to_transfer=0,
|
layer_to_transfer=0,
|
||||||
rebalanced=False,
|
rebalanced=False,
|
||||||
pending_global_ready_check=False,
|
pending_global_ready_check=False,
|
||||||
is_unchanged=[],
|
eplb_stats=None,
|
||||||
is_received_locally=[],
|
is_unchanged=np.array([]),
|
||||||
experts_recv_loc={},
|
is_received_locally=np.array([]),
|
||||||
|
recv_metadata=RecvMetadata(
|
||||||
|
recv_primary_mask=np.array([]),
|
||||||
|
recv_count=0,
|
||||||
|
recv_expert_ids=np.array([]),
|
||||||
|
recv_dst_rows=np.array([]),
|
||||||
|
),
|
||||||
is_async_enabled=self.is_async,
|
is_async_enabled=self.is_async,
|
||||||
cuda_device_index=self.cuda_device_index,
|
cuda_device_index=self.cuda_device_index,
|
||||||
new_physical_to_logical_map=new_physical_to_logical_map,
|
new_physical_to_logical_map=new_physical_to_logical_map,
|
||||||
@ -553,7 +595,12 @@ class EplbState:
|
|||||||
for eplb_model_state in self.model_states.values():
|
for eplb_model_state in self.model_states.values():
|
||||||
eplb_model_state.expert_load_pass.zero_()
|
eplb_model_state.expert_load_pass.zero_()
|
||||||
|
|
||||||
if log_stats:
|
if (
|
||||||
|
log_stats
|
||||||
|
and self.expert_rearrangement_step
|
||||||
|
% self.parallel_config.eplb_config.log_balancedness_interval
|
||||||
|
== 0
|
||||||
|
):
|
||||||
# Sync the expert load pass for each model (main and drafter).
|
# Sync the expert load pass for each model (main and drafter).
|
||||||
# expert_load_pass: (num_moe_layers, num_physical_experts)
|
# expert_load_pass: (num_moe_layers, num_physical_experts)
|
||||||
expert_load_pass_list = self._sync_load_pass()
|
expert_load_pass_list = self._sync_load_pass()
|
||||||
@ -585,9 +632,10 @@ class EplbState:
|
|||||||
|
|
||||||
if ep_group.rank() == 0:
|
if ep_group.rank() == 0:
|
||||||
logger.info(
|
logger.info(
|
||||||
"EPLB step: %d for model %s: avg_tokens=%.2f, "
|
"EPLB step: %d/%d for model %s: avg_tokens=%.2f, "
|
||||||
"max_tokens=%d, balancedness=%.4f",
|
"max_tokens=%d, balancedness=%.4f",
|
||||||
self.expert_rearrangement_step,
|
self.expert_rearrangement_step,
|
||||||
|
self.expert_rearrangement_step_interval,
|
||||||
eplb_model_state.model_name,
|
eplb_model_state.model_name,
|
||||||
avg_tokens,
|
avg_tokens,
|
||||||
max_tokens,
|
max_tokens,
|
||||||
@ -684,11 +732,14 @@ class EplbState:
|
|||||||
ep_group = get_ep_group().device_group
|
ep_group = get_ep_group().device_group
|
||||||
ep_rank = ep_group.rank()
|
ep_rank = ep_group.rank()
|
||||||
|
|
||||||
time_start = None
|
start_event = None
|
||||||
|
end_event = None
|
||||||
is_main_rank = ep_rank == 0
|
is_main_rank = ep_rank == 0
|
||||||
if is_main_rank:
|
if is_main_rank:
|
||||||
torch.cuda.synchronize()
|
if not self.is_async or is_profile:
|
||||||
time_start = time.perf_counter()
|
start_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
end_event = torch.cuda.Event(enable_timing=True)
|
||||||
|
start_event.record()
|
||||||
logger.info(
|
logger.info(
|
||||||
"Rearranging experts %s %s...",
|
"Rearranging experts %s %s...",
|
||||||
"(async mode)" if self.is_async else "sync mode",
|
"(async mode)" if self.is_async else "sync mode",
|
||||||
@ -789,20 +840,21 @@ class EplbState:
|
|||||||
for eplb_model_state, global_expert_load_window in zip(
|
for eplb_model_state, global_expert_load_window in zip(
|
||||||
self.model_states.values(), global_expert_load_windows
|
self.model_states.values(), global_expert_load_windows
|
||||||
):
|
):
|
||||||
# Get new expert mappings for the model
|
|
||||||
(
|
|
||||||
new_physical_to_logical_map,
|
|
||||||
new_logical_to_physical_map,
|
|
||||||
new_logical_replica_count,
|
|
||||||
) = self.policy.rebalance_experts(
|
|
||||||
global_expert_load_window,
|
|
||||||
num_replicas,
|
|
||||||
num_groups,
|
|
||||||
num_nodes,
|
|
||||||
num_gpus,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not eplb_model_state.is_async_enabled or is_profile:
|
if not eplb_model_state.is_async_enabled or is_profile:
|
||||||
|
# Get new expert mappings for the model
|
||||||
|
(
|
||||||
|
new_physical_to_logical_map,
|
||||||
|
new_logical_to_physical_map,
|
||||||
|
new_logical_replica_count,
|
||||||
|
) = self.policy.rebalance_experts(
|
||||||
|
global_expert_load_window,
|
||||||
|
num_replicas,
|
||||||
|
num_groups,
|
||||||
|
num_nodes,
|
||||||
|
num_gpus,
|
||||||
|
eplb_model_state.physical_to_logical_map,
|
||||||
|
)
|
||||||
|
|
||||||
# Update expert weights
|
# Update expert weights
|
||||||
rearrange_expert_weights_inplace(
|
rearrange_expert_weights_inplace(
|
||||||
eplb_model_state.physical_to_logical_map,
|
eplb_model_state.physical_to_logical_map,
|
||||||
@ -848,35 +900,29 @@ class EplbState:
|
|||||||
new_logical_replica_count
|
new_logical_replica_count
|
||||||
)
|
)
|
||||||
if is_main_rank:
|
if is_main_rank:
|
||||||
assert time_start is not None
|
assert start_event is not None
|
||||||
torch.cuda.synchronize()
|
assert end_event is not None
|
||||||
time_end = time.perf_counter()
|
end_event.record()
|
||||||
|
end_event.synchronize()
|
||||||
|
gpu_elapsed = start_event.elapsed_time(end_event) / 1000.0
|
||||||
logger.info(
|
logger.info(
|
||||||
"Rearranged experts%sin %.2f seconds.",
|
"Rearranged experts %s in %.2f s.",
|
||||||
" (profile) " if is_profile else " ",
|
" (profile) " if is_profile else " ",
|
||||||
time_end - time_start,
|
gpu_elapsed,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
device = eplb_model_state.physical_to_logical_map.device
|
eplb_model_state.eplb_stats = EplbStats(
|
||||||
new_physical = new_physical_to_logical_map.to(device)
|
# We copy the tensor to snapshot the workload on the main
|
||||||
max_slots = eplb_model_state.logical_to_physical_map.shape[-1]
|
# thread to be used on the async thread.
|
||||||
padded_logical = torch.nn.functional.pad(
|
global_expert_load_window=global_expert_load_window.clone(),
|
||||||
new_logical_to_physical_map,
|
num_replicas=num_replicas,
|
||||||
(0, max(0, max_slots - new_logical_to_physical_map.shape[-1])),
|
num_groups=num_groups,
|
||||||
value=-1,
|
num_nodes=num_nodes,
|
||||||
).to(eplb_model_state.logical_to_physical_map.device)
|
num_gpus=num_gpus,
|
||||||
new_replica = new_logical_replica_count.to(
|
|
||||||
eplb_model_state.logical_replica_count.device
|
|
||||||
)
|
)
|
||||||
|
|
||||||
eplb_model_state.new_physical_to_logical_map = new_physical
|
|
||||||
eplb_model_state.new_logical_to_physical_map = padded_logical
|
|
||||||
eplb_model_state.new_logical_replica_count = new_replica
|
|
||||||
|
|
||||||
eplb_model_state.rebalanced = True
|
eplb_model_state.rebalanced = True
|
||||||
eplb_model_state.layer_to_transfer = 0
|
eplb_model_state.layer_to_transfer = 0
|
||||||
eplb_model_state.pending_global_ready_check = True
|
eplb_model_state.pending_global_ready_check = True
|
||||||
|
|
||||||
# Signal async thread to start transferring layers
|
# Signal async thread to start transferring layers
|
||||||
if self.is_async and (not is_profile):
|
if self.is_async and (not is_profile):
|
||||||
self.rearrange_event.set()
|
self.rearrange_event.set()
|
||||||
@ -908,11 +954,13 @@ class EplbState:
|
|||||||
|
|
||||||
target_device = model_state.physical_to_logical_map.device
|
target_device = model_state.physical_to_logical_map.device
|
||||||
new_physical = model_state.new_physical_to_logical_map
|
new_physical = model_state.new_physical_to_logical_map
|
||||||
|
# In order to avoid race condition with async eplb worker,
|
||||||
|
# we need to copy blocking in case of updated EP size.
|
||||||
if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]:
|
if model_state.physical_to_logical_map.shape[1] != new_physical.shape[1]:
|
||||||
model_state.physical_to_logical_map = new_physical.to(target_device)
|
model_state.physical_to_logical_map = new_physical.to(target_device)
|
||||||
else:
|
else:
|
||||||
model_state.physical_to_logical_map[layer].copy_(
|
model_state.physical_to_logical_map[layer].copy_(
|
||||||
new_physical[layer].to(target_device)
|
new_physical[layer].to(target_device, non_blocking=True)
|
||||||
)
|
)
|
||||||
|
|
||||||
logical_device = model_state.logical_to_physical_map.device
|
logical_device = model_state.logical_to_physical_map.device
|
||||||
@ -968,25 +1016,28 @@ class EplbState:
|
|||||||
stream = torch.cuda.current_stream(device=device_index)
|
stream = torch.cuda.current_stream(device=device_index)
|
||||||
stream.wait_event(model_state.buffer_ready_event)
|
stream.wait_event(model_state.buffer_ready_event)
|
||||||
model_state.buffer_ready_event = None
|
model_state.buffer_ready_event = None
|
||||||
|
expert_weights = model_state.model.expert_weights[
|
||||||
|
model_state.layer_to_transfer
|
||||||
|
]
|
||||||
|
expert_weights_buffer = model_state.expert_buffer
|
||||||
|
new_indices = model_state.new_physical_to_logical_map[
|
||||||
|
model_state.layer_to_transfer
|
||||||
|
].numpy()
|
||||||
move_from_buffer(
|
move_from_buffer(
|
||||||
expert_weights=model_state.model.expert_weights[
|
expert_weights=expert_weights,
|
||||||
model_state.layer_to_transfer
|
expert_weights_buffers=expert_weights_buffer,
|
||||||
],
|
|
||||||
expert_weights_buffer=model_state.expert_buffer,
|
|
||||||
is_unchanged=model_state.is_unchanged,
|
is_unchanged=model_state.is_unchanged,
|
||||||
is_received_locally=model_state.is_received_locally,
|
is_received_locally=model_state.is_received_locally,
|
||||||
experts_recv_loc=model_state.experts_recv_loc,
|
recv_metadata=model_state.recv_metadata,
|
||||||
new_indices=model_state.new_physical_to_logical_map[
|
new_indices=new_indices,
|
||||||
model_state.layer_to_transfer
|
ep_rank=ep_group.rank(),
|
||||||
].tolist(),
|
|
||||||
ep_group=ep_group,
|
|
||||||
)
|
)
|
||||||
transferred_layer = model_state.layer_to_transfer
|
transferred_layer = model_state.layer_to_transfer
|
||||||
self._update_layer_mapping_from_new(model_state, transferred_layer)
|
self._update_layer_mapping_from_new(model_state, transferred_layer)
|
||||||
# After the main thread consumes, advance layer_to_transfer
|
# After the main thread consumes, advance layer_to_transfer
|
||||||
model_state.layer_to_transfer += 1
|
model_state.layer_to_transfer += 1
|
||||||
model_state.ep_buffer_ready = 0
|
model_state.ep_buffer_ready = 0
|
||||||
logger.info(
|
logger.debug(
|
||||||
"model %s successfully move_to_workspace layer %d",
|
"model %s successfully move_to_workspace layer %d",
|
||||||
model_state.model_name,
|
model_state.model_name,
|
||||||
transferred_layer,
|
transferred_layer,
|
||||||
@ -1005,9 +1056,7 @@ class EplbState:
|
|||||||
assert model_state.new_physical_to_logical_map is not None
|
assert model_state.new_physical_to_logical_map is not None
|
||||||
assert model_state.new_logical_to_physical_map is not None
|
assert model_state.new_logical_to_physical_map is not None
|
||||||
assert model_state.new_logical_replica_count is not None
|
assert model_state.new_logical_replica_count is not None
|
||||||
if not is_profile:
|
|
||||||
for layer_idx in range(model_state.physical_to_logical_map.shape[0]):
|
|
||||||
self._update_layer_mapping_from_new(model_state, layer_idx)
|
|
||||||
model_state.new_physical_to_logical_map = None
|
model_state.new_physical_to_logical_map = None
|
||||||
model_state.new_logical_to_physical_map = None
|
model_state.new_logical_to_physical_map = None
|
||||||
model_state.new_logical_replica_count = None
|
model_state.new_logical_replica_count = None
|
||||||
|
|||||||
@ -16,6 +16,7 @@ class AbstractEplbPolicy(ABC):
|
|||||||
num_groups: int,
|
num_groups: int,
|
||||||
num_nodes: int,
|
num_nodes: int,
|
||||||
num_ranks: int,
|
num_ranks: int,
|
||||||
|
old_global_expert_indices: torch.Tensor | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Entry point for expert-parallelism load balancer.
|
Entry point for expert-parallelism load balancer.
|
||||||
@ -28,7 +29,9 @@ class AbstractEplbPolicy(ABC):
|
|||||||
num_groups: number of expert groups
|
num_groups: number of expert groups
|
||||||
num_nodes: number of server nodes
|
num_nodes: number of server nodes
|
||||||
num_ranks: number of ranks, must be a multiple of `num_nodes`
|
num_ranks: number of ranks, must be a multiple of `num_nodes`
|
||||||
|
old_global_expert_indices: [layers, num_logical_experts], the old global
|
||||||
|
expert indices. Used to avoid unnecessary weight copying
|
||||||
|
for experts moving within one rank.
|
||||||
Returns:
|
Returns:
|
||||||
physical_to_logical_map: [layers, num_replicas], the expert
|
physical_to_logical_map: [layers, num_replicas], the expert
|
||||||
index of each replica
|
index of each replica
|
||||||
|
|||||||
@ -21,8 +21,8 @@ from .abstract import AbstractEplbPolicy
|
|||||||
class DefaultEplbPolicy(AbstractEplbPolicy):
|
class DefaultEplbPolicy(AbstractEplbPolicy):
|
||||||
@classmethod
|
@classmethod
|
||||||
def balanced_packing(
|
def balanced_packing(
|
||||||
cls, weight: torch.Tensor, num_packs: int
|
cls, weight: np.ndarray, num_packs: int
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Pack n weighted objects to m packs, such that each bin contains exactly
|
Pack n weighted objects to m packs, such that each bin contains exactly
|
||||||
n/m objects and the weights of all packs are as balanced as possible.
|
n/m objects and the weights of all packs are as balanced as possible.
|
||||||
@ -39,50 +39,43 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
assert num_groups % num_packs == 0
|
assert num_groups % num_packs == 0
|
||||||
groups_per_pack = num_groups // num_packs
|
groups_per_pack = num_groups // num_packs
|
||||||
|
|
||||||
device = weight.device
|
|
||||||
|
|
||||||
if groups_per_pack == 1:
|
if groups_per_pack == 1:
|
||||||
pack_index = torch.arange(
|
pack_index = np.tile(np.arange(num_groups, dtype=np.int64), (num_layers, 1))
|
||||||
weight.size(-1), dtype=torch.int64, device=device
|
rank_in_pack = np.zeros_like(pack_index, dtype=np.int64)
|
||||||
).expand(weight.shape)
|
|
||||||
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
|
|
||||||
return pack_index, rank_in_pack
|
return pack_index, rank_in_pack
|
||||||
|
|
||||||
weight_np = weight.cpu().numpy()
|
|
||||||
|
|
||||||
# Sort and get indices in decending order
|
# Sort and get indices in decending order
|
||||||
indices_np = np.argsort(-weight_np, axis=-1)
|
indices = np.argsort(-weight, axis=-1)
|
||||||
|
|
||||||
pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
|
pack_index = np.full((num_layers, num_groups), -1, dtype=np.int64)
|
||||||
rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
|
rank_in_pack = np.full((num_layers, num_groups), -1, dtype=np.int64)
|
||||||
|
|
||||||
|
pack_weights = np.zeros((num_layers, num_packs), dtype=np.float64)
|
||||||
|
pack_items = np.zeros((num_layers, num_packs), dtype=np.int64)
|
||||||
|
|
||||||
# Run the packing algorithm
|
# Run the packing algorithm
|
||||||
for i in range(num_layers):
|
for layer_idx in range(num_layers):
|
||||||
pack_weights = [0.0] * num_packs
|
weights_row = pack_weights[layer_idx]
|
||||||
pack_items = [0] * num_packs
|
items_row = pack_items[layer_idx]
|
||||||
|
|
||||||
for group in indices_np[i]:
|
for group in indices[layer_idx]:
|
||||||
# Find a pack with capacity that has the lowest weight
|
# Pick the lightest pack; full packs are masked out by inf.
|
||||||
pack = min(
|
pack = int(np.argmin(weights_row))
|
||||||
(j for j in range(num_packs) if pack_items[j] < groups_per_pack),
|
|
||||||
key=pack_weights.__getitem__,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert pack_items[pack] < groups_per_pack
|
pack_index[layer_idx, group] = pack
|
||||||
pack_index_np[i, group] = pack
|
rank_in_pack[layer_idx, group] = items_row[pack]
|
||||||
rank_in_pack_np[i, group] = pack_items[pack]
|
weights_row[pack] += weight[layer_idx, group]
|
||||||
pack_weights[pack] += weight_np[i, group]
|
items_row[pack] += 1
|
||||||
pack_items[pack] += 1
|
if items_row[pack] == groups_per_pack:
|
||||||
|
# Mark as unavailable for future selections.
|
||||||
pack_index = torch.from_numpy(pack_index_np).to(device)
|
weights_row[pack] = np.inf
|
||||||
rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
|
|
||||||
|
|
||||||
return pack_index, rank_in_pack
|
return pack_index, rank_in_pack
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def replicate_experts(
|
def replicate_experts(
|
||||||
cls, weight: torch.Tensor, num_phy: int
|
cls, weight: np.ndarray, num_phy: int
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
|
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
|
||||||
load of all replicas is minimized.
|
load of all replicas is minimized.
|
||||||
@ -93,33 +86,32 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
phy2log: [X, num_phy], logical expert id of each physical expert
|
phy2log: [X, num_phy], logical expert id of each physical expert
|
||||||
rank: [X, num_phy], the replica rank
|
replica_idx: [X, num_phy], the index of the replica for each logical expert
|
||||||
logcnt: [X, num_log], number of replicas for each logical expert
|
logcnt: [X, num_log], number of replicas for each logical expert
|
||||||
"""
|
"""
|
||||||
n, num_log = weight.shape
|
n, num_log = weight.shape
|
||||||
num_redundant = num_phy - num_log
|
num_redundant = num_phy - num_log
|
||||||
assert num_redundant >= 0
|
assert num_redundant >= 0
|
||||||
device = weight.device
|
phy2log = np.tile(np.arange(num_phy, dtype=np.int64), (n, 1))
|
||||||
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
|
replica_idx = np.zeros((n, num_phy), dtype=np.int64)
|
||||||
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
|
logcnt = np.ones((n, num_log), dtype=np.int64)
|
||||||
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
|
arangen = np.arange(n, dtype=np.int64)
|
||||||
arangen = torch.arange(n, dtype=torch.int64, device=device)
|
|
||||||
for i in range(num_log, num_phy):
|
for i in range(num_log, num_phy):
|
||||||
redundant_indices = (weight / logcnt).max(dim=-1).indices
|
redundant_indices = np.argmax(weight / logcnt, axis=-1)
|
||||||
phy2log[:, i] = redundant_indices
|
phy2log[:, i] = redundant_indices
|
||||||
rank[:, i] = logcnt[arangen, redundant_indices]
|
replica_idx[:, i] = logcnt[arangen, redundant_indices]
|
||||||
logcnt[arangen, redundant_indices] += 1
|
logcnt[arangen, redundant_indices] += 1
|
||||||
return phy2log, rank, logcnt
|
return phy2log, replica_idx, logcnt
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def rebalance_experts_hierarchical(
|
def rebalance_experts_hierarchical(
|
||||||
cls,
|
cls,
|
||||||
weight: torch.Tensor,
|
weight: np.ndarray,
|
||||||
num_physical_experts: int,
|
num_physical_experts: int,
|
||||||
num_groups: int,
|
num_groups: int,
|
||||||
num_nodes: int,
|
num_nodes: int,
|
||||||
num_gpus: int,
|
num_gpus: int,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||||
"""
|
"""
|
||||||
Parameters:
|
Parameters:
|
||||||
weight: [num_moe_layers, num_logical_experts]
|
weight: [num_moe_layers, num_logical_experts]
|
||||||
@ -132,7 +124,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
Returns:
|
Returns:
|
||||||
phy2log: [layers, num_replicas], the expert
|
phy2log: [layers, num_replicas], the expert
|
||||||
index of each replica
|
index of each replica
|
||||||
log2phy: [layers, num_logical_experts, X],
|
pphy_replicas_idx: [layers, num_logical_experts, X],
|
||||||
the replica indices for each expert
|
the replica indices for each expert
|
||||||
logcnt: [layers, num_logical_experts], number of
|
logcnt: [layers, num_logical_experts], number of
|
||||||
physical replicas for each logical expert
|
physical replicas for each logical expert
|
||||||
@ -146,66 +138,160 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
assert num_physical_experts % num_gpus == 0
|
assert num_physical_experts % num_gpus == 0
|
||||||
phy_experts_per_gpu = num_physical_experts // num_gpus
|
phy_experts_per_gpu = num_physical_experts // num_gpus
|
||||||
|
|
||||||
def inverse(perm: torch.Tensor) -> torch.Tensor:
|
def inverse(perm: np.ndarray) -> np.ndarray:
|
||||||
inv = torch.empty_like(perm)
|
inv = np.empty_like(perm)
|
||||||
inv.scatter_(
|
row_idx = np.arange(perm.shape[0])[:, None]
|
||||||
1,
|
col_idx = np.arange(perm.shape[1], dtype=np.int64)
|
||||||
perm,
|
inv[row_idx, perm] = col_idx
|
||||||
torch.arange(
|
|
||||||
perm.size(1), dtype=torch.int64, device=perm.device
|
|
||||||
).expand(perm.shape),
|
|
||||||
)
|
|
||||||
return inv
|
return inv
|
||||||
|
|
||||||
# Step 1: pack groups to nodes
|
# Step 1: pack groups to nodes
|
||||||
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
|
tokens_per_group = weight.reshape(num_layers, num_groups, group_size).sum(
|
||||||
|
axis=-1
|
||||||
|
)
|
||||||
group_pack_index, group_rank_in_pack = cls.balanced_packing(
|
group_pack_index, group_rank_in_pack = cls.balanced_packing(
|
||||||
tokens_per_group, num_nodes
|
tokens_per_group, num_nodes
|
||||||
)
|
)
|
||||||
|
# Map each logical expert into a node-local ordering based on packed groups.
|
||||||
log2mlog = (
|
log2mlog = (
|
||||||
(
|
(
|
||||||
(group_pack_index * groups_per_node + group_rank_in_pack) * group_size
|
(group_pack_index * groups_per_node + group_rank_in_pack)[..., None]
|
||||||
).unsqueeze(-1)
|
* group_size
|
||||||
+ torch.arange(
|
|
||||||
group_size, dtype=torch.int64, device=group_pack_index.device
|
|
||||||
)
|
)
|
||||||
).flatten(-2)
|
+ np.arange(group_size, dtype=np.int64)
|
||||||
|
).reshape(num_layers, num_logical_experts)
|
||||||
mlog2log = inverse(log2mlog)
|
mlog2log = inverse(log2mlog)
|
||||||
|
|
||||||
# Step 2: construct redundant experts within nodes
|
# Step 2: construct redundant experts within nodes
|
||||||
# [num_layers * num_nodes, num_logical_experts // num_nodes]
|
# Reorder weights into the node-local layout so replication is done per node.
|
||||||
tokens_per_mlog = weight.gather(-1, mlog2log).view(
|
tokens_per_mlog = np.take_along_axis(weight, mlog2log, axis=1).reshape(
|
||||||
-1, num_logical_experts // num_nodes
|
-1, num_logical_experts // num_nodes
|
||||||
)
|
)
|
||||||
phy2mlog, phyrank, mlogcnt = cls.replicate_experts(
|
phy2mlog, replicas_idx, mlogcnt = cls.replicate_experts(
|
||||||
tokens_per_mlog, num_physical_experts // num_nodes
|
tokens_per_mlog, num_physical_experts // num_nodes
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 3: pack physical_experts to GPUs
|
# Step 3: pack physical_experts to GPUs
|
||||||
# [num_layers * num_nodes, num_physical_experts // num_nodes]
|
# Effective per-physical load = logical load divided by replica count.
|
||||||
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
|
tokens_per_phy = np.take_along_axis(tokens_per_mlog / mlogcnt, phy2mlog, axis=1)
|
||||||
pack_index, rank_in_pack = cls.balanced_packing(
|
pack_index, rank_in_pack = cls.balanced_packing(
|
||||||
tokens_per_phy, num_gpus // num_nodes
|
tokens_per_phy, num_gpus // num_nodes
|
||||||
)
|
)
|
||||||
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
|
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
|
||||||
pphy2phy = inverse(phy2pphy)
|
pphy2phy = inverse(phy2pphy)
|
||||||
|
|
||||||
pphy2mlog = phy2mlog.gather(
|
# Reorder node-local logical indices into the post-packing physical order.
|
||||||
-1, pphy2phy
|
pphy2mlog = np.take_along_axis(phy2mlog, pphy2phy, axis=1)
|
||||||
) # [num_layers * num_nodes, num_log_per_nodes]
|
|
||||||
pphy2mlog = (
|
pphy2mlog = (
|
||||||
pphy2mlog.view(num_layers, num_nodes, -1)
|
pphy2mlog.reshape(num_layers, num_nodes, -1)
|
||||||
+ torch.arange(
|
+ np.arange(
|
||||||
0,
|
0,
|
||||||
num_logical_experts,
|
num_logical_experts,
|
||||||
num_logical_experts // num_nodes,
|
num_logical_experts // num_nodes,
|
||||||
device=group_pack_index.device,
|
dtype=np.int64,
|
||||||
).view(1, -1, 1)
|
)[None, :, None]
|
||||||
).flatten(-2)
|
).reshape(num_layers, -1)
|
||||||
pphy2log = mlog2log.gather(-1, pphy2mlog)
|
# Map node-local logical indices back to global logical expert ids.
|
||||||
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
|
pphy2log = np.take_along_axis(mlog2log, pphy2mlog, axis=1)
|
||||||
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
|
# Reorder replica ranks to the post-packing physical ordering.
|
||||||
return pphy2log, pphyrank, logcnt
|
pphy_replicas_idx = np.take_along_axis(replicas_idx, pphy2phy, axis=1).reshape(
|
||||||
|
num_layers, -1
|
||||||
|
)
|
||||||
|
# Convert replica counts back to the original logical ordering.
|
||||||
|
logcnt = np.take_along_axis(mlogcnt.reshape(num_layers, -1), log2mlog, axis=1)
|
||||||
|
return pphy2log, pphy_replicas_idx, logcnt
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def preserve_intragpu_slots(
|
||||||
|
cls,
|
||||||
|
phy2log: np.ndarray,
|
||||||
|
phy_replicas_idx: np.ndarray,
|
||||||
|
num_ranks: int,
|
||||||
|
old_phy2log: np.ndarray,
|
||||||
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
|
"""
|
||||||
|
Reorder the new mapping per GPU so that experts that remain on the same GPU
|
||||||
|
keep their previous slot positions when possible. Incoming experts to that GPU
|
||||||
|
fill any remaining available slots. This is applied only when the number of GPUs
|
||||||
|
is unchanged and the slots per GPU remain the same between
|
||||||
|
the old and new mappings.
|
||||||
|
"""
|
||||||
|
num_phy_experts = phy2log.shape[1]
|
||||||
|
if num_ranks <= 0 or num_phy_experts % num_ranks != 0:
|
||||||
|
return phy2log, phy_replicas_idx
|
||||||
|
|
||||||
|
# Move to CPU and convert to NumPy for processing
|
||||||
|
slots_per_gpu = num_phy_experts // num_ranks
|
||||||
|
num_layers = phy2log.shape[0]
|
||||||
|
|
||||||
|
post_phy2log = phy2log.copy()
|
||||||
|
post_phy_replicas_idx = phy_replicas_idx.copy()
|
||||||
|
|
||||||
|
for gpu_idx in range(num_ranks):
|
||||||
|
start = gpu_idx * slots_per_gpu
|
||||||
|
end = start + slots_per_gpu
|
||||||
|
# Experts across all layers for this GPU
|
||||||
|
old_local = old_phy2log[:, start:end] # [layers, slots]
|
||||||
|
new_local = phy2log[:, start:end] # [layers, slots]
|
||||||
|
new_ridx = phy_replicas_idx[:, start:end] # [layers, slots]
|
||||||
|
|
||||||
|
used_new_indices = np.zeros((num_layers, slots_per_gpu), dtype=bool)
|
||||||
|
preserved_positions = np.zeros((num_layers, slots_per_gpu), dtype=bool)
|
||||||
|
|
||||||
|
# First pass: preserve same-logical experts in their previous slots
|
||||||
|
for slot_idx in range(slots_per_gpu):
|
||||||
|
# matches: [layers, slots], True where new local experts have
|
||||||
|
# the same logical value as the old from 'slot_idx' and not checked yet
|
||||||
|
matches = (new_local == old_local[:, slot_idx][:, None]) & (
|
||||||
|
~used_new_indices
|
||||||
|
)
|
||||||
|
has_any = matches.any(axis=1)
|
||||||
|
if np.any(has_any):
|
||||||
|
first_idx = np.argmax(matches, axis=1)
|
||||||
|
layer_indices = np.nonzero(has_any)[0]
|
||||||
|
matched_new_positions = first_idx[layer_indices]
|
||||||
|
post_phy2log[layer_indices, start + slot_idx] = new_local[
|
||||||
|
layer_indices, matched_new_positions
|
||||||
|
]
|
||||||
|
post_phy_replicas_idx[layer_indices, start + slot_idx] = new_ridx[
|
||||||
|
layer_indices, matched_new_positions
|
||||||
|
]
|
||||||
|
used_new_indices[layer_indices, matched_new_positions] = True
|
||||||
|
preserved_positions[layer_indices, slot_idx] = True
|
||||||
|
|
||||||
|
# Second pass: fill remaining slots with remaining new experts
|
||||||
|
remaining_mask = ~used_new_indices # [layers, slots]
|
||||||
|
fill_mask = ~preserved_positions # [layers, slots]
|
||||||
|
if remaining_mask.any() and fill_mask.any():
|
||||||
|
idx_base = np.tile(np.arange(slots_per_gpu), (num_layers, 1))
|
||||||
|
# Sentinel value for unavailable positions.
|
||||||
|
large = slots_per_gpu + 1
|
||||||
|
# Priorities: keep original index for available spots, set sentinel
|
||||||
|
# for unavailable; lower is earlier.
|
||||||
|
remaining_priority = np.where(remaining_mask, idx_base, large)
|
||||||
|
fill_priority = np.where(fill_mask, idx_base, large)
|
||||||
|
# Sort to get ordered indices of available src/dst positions per layer.
|
||||||
|
remaining_indices = np.argsort(remaining_priority, axis=1)
|
||||||
|
fill_indices = np.argsort(fill_priority, axis=1)
|
||||||
|
# Fill count per layer (cannot exceed either side).
|
||||||
|
remaining_counts = remaining_mask.sum(axis=1)
|
||||||
|
fill_counts = fill_mask.sum(axis=1)
|
||||||
|
take_counts = np.minimum(remaining_counts, fill_counts)
|
||||||
|
# Assign remaining new experts to remaining slots per layer.
|
||||||
|
for layer_idx in range(num_layers):
|
||||||
|
k = int(take_counts[layer_idx])
|
||||||
|
if k <= 0:
|
||||||
|
continue
|
||||||
|
src_pos = remaining_indices[layer_idx, :k]
|
||||||
|
dst_pos = fill_indices[layer_idx, :k]
|
||||||
|
post_phy2log[layer_idx, start + dst_pos] = new_local[
|
||||||
|
layer_idx, src_pos
|
||||||
|
]
|
||||||
|
post_phy_replicas_idx[layer_idx, start + dst_pos] = new_ridx[
|
||||||
|
layer_idx, src_pos
|
||||||
|
]
|
||||||
|
|
||||||
|
return post_phy2log, post_phy_replicas_idx
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def rebalance_experts(
|
def rebalance_experts(
|
||||||
@ -215,6 +301,7 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
num_groups: int,
|
num_groups: int,
|
||||||
num_nodes: int,
|
num_nodes: int,
|
||||||
num_ranks: int,
|
num_ranks: int,
|
||||||
|
old_global_expert_indices: torch.Tensor | None = None,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
"""
|
"""
|
||||||
Entry point for expert-parallelism load balancer.
|
Entry point for expert-parallelism load balancer.
|
||||||
@ -228,7 +315,9 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
num_nodes: number of server nodes, where the intra-node network
|
num_nodes: number of server nodes, where the intra-node network
|
||||||
(e.g, NVLink) is faster
|
(e.g, NVLink) is faster
|
||||||
num_ranks: number of ranks, must be a multiple of `num_nodes`
|
num_ranks: number of ranks, must be a multiple of `num_nodes`
|
||||||
|
old_global_expert_indices: [layers, num_logical_experts], the old global
|
||||||
|
expert indices. Used to avoid unnecessary weight copying
|
||||||
|
for experts moving within one rank.
|
||||||
Returns:
|
Returns:
|
||||||
phy2log: [layers, num_replicas], the expert
|
phy2log: [layers, num_replicas], the expert
|
||||||
index of each replica
|
index of each replica
|
||||||
@ -237,31 +326,51 @@ class DefaultEplbPolicy(AbstractEplbPolicy):
|
|||||||
logcnt: [layers, num_logical_experts], number of
|
logcnt: [layers, num_logical_experts], number of
|
||||||
physical replicas for each logical expert
|
physical replicas for each logical expert
|
||||||
"""
|
"""
|
||||||
|
device = weight.device
|
||||||
num_layers, num_logical_experts = weight.shape
|
num_layers, num_logical_experts = weight.shape
|
||||||
weight = weight.float()
|
weight_np = weight.float().cpu().numpy()
|
||||||
|
old_phy2log_np = (
|
||||||
|
old_global_expert_indices.cpu().numpy()
|
||||||
|
if old_global_expert_indices is not None
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
if num_groups % num_nodes == 0:
|
if num_groups % num_nodes == 0:
|
||||||
# use hierarchical load-balance policy
|
# use hierarchical load-balance policy
|
||||||
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
|
phy2log_np, phy_replicas_idx_np, logcnt_np = (
|
||||||
weight, num_replicas, num_groups, num_nodes, num_ranks
|
cls.rebalance_experts_hierarchical(
|
||||||
|
weight_np, num_replicas, num_groups, num_nodes, num_ranks
|
||||||
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# use global load-balance policy
|
# use global load-balance policy
|
||||||
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
|
phy2log_np, phy_replicas_idx_np, logcnt_np = (
|
||||||
weight, num_replicas, 1, 1, num_ranks
|
cls.rebalance_experts_hierarchical(
|
||||||
|
weight_np, num_replicas, 1, 1, num_ranks
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optional postprocessing to preserve slots for experts moving
|
||||||
|
# within the same GPU
|
||||||
|
# Only apply when the number of GPUs and slots per GPU remain unchanged.
|
||||||
|
# Helps to avoid unnecessary weight copying when experts move
|
||||||
|
# within the same GPU.
|
||||||
|
if old_global_expert_indices is not None:
|
||||||
|
phy2log_np, phy_replicas_idx_np = cls.preserve_intragpu_slots(
|
||||||
|
phy2log_np, phy_replicas_idx_np, num_ranks, old_phy2log_np
|
||||||
)
|
)
|
||||||
num_redundant_experts = num_replicas - num_logical_experts
|
num_redundant_experts = num_replicas - num_logical_experts
|
||||||
maxlogcnt = num_redundant_experts + 1
|
maxlogcnt = num_redundant_experts + 1
|
||||||
log2phy: torch.Tensor = torch.full(
|
log2phy_np = np.full(
|
||||||
(num_layers, num_logical_experts, maxlogcnt),
|
(num_layers, num_logical_experts, maxlogcnt), -1, dtype=np.int64
|
||||||
-1,
|
|
||||||
dtype=torch.int64,
|
|
||||||
device=logcnt.device,
|
|
||||||
)
|
)
|
||||||
log2phy.view(num_layers, -1).scatter_(
|
layer_indices = np.arange(num_layers)[:, None]
|
||||||
-1,
|
replica_indices = np.tile(
|
||||||
phy2log * maxlogcnt + phyrank,
|
np.arange(num_replicas, dtype=np.int64), (num_layers, 1)
|
||||||
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
|
|
||||||
num_layers, -1
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
log2phy_np[layer_indices, phy2log_np, phy_replicas_idx_np] = replica_indices
|
||||||
|
|
||||||
|
phy2log = torch.from_numpy(phy2log_np).to(device)
|
||||||
|
log2phy = torch.from_numpy(log2phy_np).to(device)
|
||||||
|
logcnt = torch.from_numpy(logcnt_np).to(device)
|
||||||
return phy2log, log2phy, logcnt
|
return phy2log, log2phy, logcnt
|
||||||
|
|||||||
@ -6,9 +6,10 @@ The actual execution of the rearrangement.
|
|||||||
This involves the exchange of expert weights between GPUs.
|
This involves the exchange of expert weights between GPUs.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from collections.abc import Iterable, MutableSequence, Sequence
|
from collections.abc import Iterable, Sequence
|
||||||
from functools import partial
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from torch.distributed import (
|
from torch.distributed import (
|
||||||
P2POp,
|
P2POp,
|
||||||
@ -18,214 +19,318 @@ from torch.distributed import (
|
|||||||
get_global_rank,
|
get_global_rank,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from vllm.logger import init_logger
|
||||||
|
|
||||||
def idx_local_to_global(
|
logger = init_logger(__name__)
|
||||||
local_idx: int,
|
|
||||||
local_cnt: int,
|
|
||||||
ep_rank: int,
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Convert a local expert index to a global expert index.
|
|
||||||
"""
|
|
||||||
return ep_rank * local_cnt + local_idx
|
|
||||||
|
|
||||||
|
|
||||||
def idx_global_to_local(
|
@dataclass
|
||||||
global_idx: int,
|
class RecvMetadata:
|
||||||
local_cnt: int,
|
"""Metadata describing remote receives during EPLB rebalancing."""
|
||||||
ep_rank: int,
|
|
||||||
) -> int:
|
recv_primary_mask: np.ndarray
|
||||||
"""
|
"""Mask of (num_local_experts,) indicating primary experts received."""
|
||||||
Convert a global expert index to a local expert index.
|
recv_count: int
|
||||||
"""
|
"""Number of received experts for the layer."""
|
||||||
return global_idx - ep_rank * local_cnt
|
recv_expert_ids: np.ndarray
|
||||||
|
"""Expert ids (num_local_experts,) of remote primary experts."""
|
||||||
|
recv_dst_rows: np.ndarray
|
||||||
|
"""Target expert indices (num_local_experts,) in local tensors to send."""
|
||||||
|
|
||||||
|
|
||||||
def global_idx_to_rank(
|
# Type alias for the result of move_to_buffer or transfer_layer
|
||||||
global_idx: int,
|
MoveToBufferResult = tuple[np.ndarray, np.ndarray, RecvMetadata]
|
||||||
local_cnt: int,
|
|
||||||
) -> int:
|
|
||||||
"""
|
|
||||||
Convert a global expert index to a rank index.
|
|
||||||
"""
|
|
||||||
return global_idx // local_cnt
|
|
||||||
|
|
||||||
|
|
||||||
def get_ep_ranks_with_expert(
|
def get_ep_ranks_with_experts_batch(
|
||||||
idx: int,
|
expert_ids: np.ndarray,
|
||||||
num_local_experts: int,
|
num_local_experts: int,
|
||||||
old_indices: Sequence[int],
|
old_indices: np.ndarray,
|
||||||
new_indices: Sequence[int],
|
new_indices: np.ndarray,
|
||||||
) -> tuple[MutableSequence[int], MutableSequence[int]]:
|
) -> tuple[dict[int, list[int]], dict[int, list[int]]]:
|
||||||
"""
|
"""
|
||||||
Get the ranks of the experts that need to be exchanged.
|
Get the ranks of the experts that need to be exchanged.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
idx: The index of the expert.
|
expert_ids: 1D array of expert indices to query.
|
||||||
num_local_experts: The number of local experts.
|
num_local_experts: The number of local experts.
|
||||||
old_indices: The old indices of the experts.
|
old_indices: The old indices of the experts.
|
||||||
new_indices: The new indices of the experts.
|
new_indices: The new indices of the experts.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A tuple of two lists:
|
A tuple of two dictionaries mapping expert_id to:
|
||||||
- The ranks of the experts that need to be sent.
|
- ranks_to_send: The ranks that have this expert and need to send.
|
||||||
- The ranks of the experts that need to be received.
|
- ranks_to_recv: The ranks that need to receive this expert.
|
||||||
"""
|
"""
|
||||||
global2rank = partial(
|
ranks_to_send_map: dict[int, list[int]] = {}
|
||||||
global_idx_to_rank,
|
ranks_to_recv_map: dict[int, list[int]] = {}
|
||||||
local_cnt=num_local_experts,
|
|
||||||
)
|
|
||||||
|
|
||||||
ranks_to_send: list[int] = []
|
# Fast path: if no experts, return empty dicts
|
||||||
ranks_to_recv: list[int] = []
|
if expert_ids.size == 0:
|
||||||
|
return ranks_to_send_map, ranks_to_recv_map
|
||||||
|
|
||||||
for i, e in enumerate(old_indices):
|
unique_experts = np.unique(expert_ids)
|
||||||
if e == idx:
|
num_positions = len(old_indices)
|
||||||
rank = global2rank(i)
|
position_indices = np.arange(num_positions, dtype=np.int32)
|
||||||
if not ranks_to_send or ranks_to_send[-1] != rank:
|
|
||||||
ranks_to_send.append(rank)
|
|
||||||
|
|
||||||
for i, e in enumerate(new_indices):
|
# Vectorized approach: find all positions matching any query expert in one pass
|
||||||
if e == idx:
|
# Use np.isin to get boolean masks for all relevant positions at once
|
||||||
rank = global2rank(i)
|
old_relevant_mask = np.isin(old_indices, unique_experts)
|
||||||
if not ranks_to_recv or ranks_to_recv[-1] != rank:
|
new_relevant_mask = np.isin(new_indices, unique_experts)
|
||||||
ranks_to_recv.append(rank)
|
|
||||||
|
|
||||||
# Remove those ranks that can get this expert locally.
|
# Process old_indices (send ranks)
|
||||||
ranks_to_send_set = set(ranks_to_send)
|
if np.any(old_relevant_mask):
|
||||||
ranks_to_recv_actual = [
|
old_relevant_positions = position_indices[old_relevant_mask]
|
||||||
rank for rank in ranks_to_recv if rank not in ranks_to_send_set
|
old_relevant_experts = old_indices[old_relevant_mask]
|
||||||
]
|
old_relevant_ranks = old_relevant_positions // num_local_experts
|
||||||
|
|
||||||
return ranks_to_send, ranks_to_recv_actual
|
# Sort by expert first, then by position (to maintain first-appearance order)
|
||||||
|
sort_order = np.lexsort((old_relevant_positions, old_relevant_experts))
|
||||||
|
sorted_experts = old_relevant_experts[sort_order]
|
||||||
|
sorted_ranks = old_relevant_ranks[sort_order]
|
||||||
|
|
||||||
|
# Find boundaries where expert changes
|
||||||
|
expert_boundaries = np.concatenate(
|
||||||
|
[[0], np.where(np.diff(sorted_experts) != 0)[0] + 1, [len(sorted_experts)]]
|
||||||
|
)
|
||||||
|
|
||||||
|
# For each expert, extract unique ranks in order of first appearance
|
||||||
|
for i in range(len(expert_boundaries) - 1):
|
||||||
|
start, end = expert_boundaries[i], expert_boundaries[i + 1]
|
||||||
|
expert = int(sorted_experts[start])
|
||||||
|
expert_ranks = sorted_ranks[start:end]
|
||||||
|
|
||||||
|
# Get unique ranks preserving order
|
||||||
|
_, unique_idx = np.unique(expert_ranks, return_index=True)
|
||||||
|
unique_ranks = expert_ranks[np.sort(unique_idx)]
|
||||||
|
ranks_to_send_map[expert] = unique_ranks.tolist()
|
||||||
|
|
||||||
|
# Process new_indices (recv ranks)
|
||||||
|
if np.any(new_relevant_mask):
|
||||||
|
new_relevant_positions = position_indices[new_relevant_mask]
|
||||||
|
new_relevant_experts = new_indices[new_relevant_mask]
|
||||||
|
new_relevant_ranks = new_relevant_positions // num_local_experts
|
||||||
|
|
||||||
|
# Sort by expert first, then by position
|
||||||
|
sort_order = np.lexsort((new_relevant_positions, new_relevant_experts))
|
||||||
|
sorted_experts = new_relevant_experts[sort_order]
|
||||||
|
sorted_ranks = new_relevant_ranks[sort_order]
|
||||||
|
|
||||||
|
# Find boundaries where expert changes
|
||||||
|
expert_boundaries = np.concatenate(
|
||||||
|
[[0], np.where(np.diff(sorted_experts) != 0)[0] + 1, [len(sorted_experts)]]
|
||||||
|
)
|
||||||
|
|
||||||
|
# For each expert, extract unique ranks and exclude local copies
|
||||||
|
for i in range(len(expert_boundaries) - 1):
|
||||||
|
start, end = expert_boundaries[i], expert_boundaries[i + 1]
|
||||||
|
expert = int(sorted_experts[start])
|
||||||
|
expert_ranks = sorted_ranks[start:end]
|
||||||
|
|
||||||
|
# Get unique ranks preserving order
|
||||||
|
_, unique_idx = np.unique(expert_ranks, return_index=True)
|
||||||
|
unique_ranks = expert_ranks[np.sort(unique_idx)]
|
||||||
|
|
||||||
|
# Remove ranks that have local copies (in send map)
|
||||||
|
send_ranks_set = set(ranks_to_send_map.get(expert, []))
|
||||||
|
recv_ranks_actual = [
|
||||||
|
int(r) for r in unique_ranks if r not in send_ranks_set
|
||||||
|
]
|
||||||
|
ranks_to_recv_map[expert] = recv_ranks_actual
|
||||||
|
|
||||||
|
# Handle experts that only appear in old (send only) or new (recv only)
|
||||||
|
for expert in unique_experts:
|
||||||
|
expert = int(expert)
|
||||||
|
if expert not in ranks_to_send_map:
|
||||||
|
ranks_to_send_map[expert] = []
|
||||||
|
if expert not in ranks_to_recv_map:
|
||||||
|
ranks_to_recv_map[expert] = []
|
||||||
|
|
||||||
|
return ranks_to_send_map, ranks_to_recv_map
|
||||||
|
|
||||||
|
|
||||||
def move_to_buffer(
|
def move_to_buffer(
|
||||||
num_local_experts: int,
|
num_local_experts: int,
|
||||||
old_indices: Sequence[int],
|
old_indices: np.ndarray,
|
||||||
new_indices: Sequence[int],
|
new_indices: np.ndarray,
|
||||||
expert_weights: Iterable[torch.Tensor],
|
expert_weights: Iterable[torch.Tensor],
|
||||||
expert_weights_buffer: Sequence[torch.Tensor],
|
expert_weights_buffers: Sequence[torch.Tensor],
|
||||||
cuda_stream: torch.cuda.Stream | None,
|
cuda_stream: torch.cuda.Stream | None,
|
||||||
ep_group: ProcessGroup,
|
ep_group: ProcessGroup,
|
||||||
) -> tuple[list[bool], list[bool], dict[int, int]]:
|
) -> MoveToBufferResult:
|
||||||
"""
|
"""
|
||||||
Perform expert weights rearrangement of one layer.
|
Rearranges expert weights during EPLB rebalancing.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
num_local_experts: Number of local experts.
|
||||||
|
old_indices: (num_experts_total,) ndarray of current (old)
|
||||||
|
global-to-local expert assignments.
|
||||||
|
new_indices: (num_experts_total,) ndarray of desired (new)
|
||||||
|
global-to-local assignments after rebalance.
|
||||||
|
expert_weights: Original expert weights for the layer.
|
||||||
|
expert_weights_buffers: Intermediate buffers (one per tensor).
|
||||||
|
cuda_stream: CUDA stream for async copies (can be None for sync mode).
|
||||||
|
ep_group: Distributed process group for expert parallel comms.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
is_unchanged (np.ndarray): (num_local_experts,), True where an expert row
|
||||||
|
is unchanged after rebalance.
|
||||||
|
is_received_locally (np.ndarray): (num_local_experts,), True where a row
|
||||||
|
can be updated from local data.
|
||||||
|
RecvMetadata: Metadata needed for completing remote weight transfers.
|
||||||
"""
|
"""
|
||||||
|
assert old_indices.shape == new_indices.shape
|
||||||
ep_rank = ep_group.rank()
|
ep_rank = ep_group.rank()
|
||||||
local2global = partial(
|
|
||||||
idx_local_to_global,
|
recv_primary_mask = np.zeros((num_local_experts,), dtype=np.bool_)
|
||||||
local_cnt=num_local_experts,
|
send_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
|
||||||
ep_rank=ep_rank,
|
send_src_rows = np.full((num_local_experts,), -1, dtype=np.int32)
|
||||||
|
recv_expert_ids = np.full((num_local_experts,), -1, dtype=np.int64)
|
||||||
|
recv_dst_rows = np.full((num_local_experts,), -1, dtype=np.int32)
|
||||||
|
|
||||||
|
base = ep_rank * num_local_experts
|
||||||
|
local_rows = np.arange(num_local_experts, dtype=np.int32)
|
||||||
|
local_global = base + local_rows
|
||||||
|
|
||||||
|
old_local_expert_ids = old_indices[local_global]
|
||||||
|
new_local_expert_ids = new_indices[local_global]
|
||||||
|
|
||||||
|
# Unchanged mask
|
||||||
|
is_unchanged = old_local_expert_ids == new_local_expert_ids
|
||||||
|
|
||||||
|
# Local receive eligibility
|
||||||
|
new_valid = new_local_expert_ids != -1
|
||||||
|
can_recv_local = np.isin(
|
||||||
|
new_local_expert_ids, old_local_expert_ids, assume_unique=False
|
||||||
|
)
|
||||||
|
is_received_locally = np.logical_or(
|
||||||
|
is_unchanged, np.logical_and(new_valid, can_recv_local)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 0. Do nothing for experts that did not change.
|
# Send map: first src row per unique expert present locally in old mapping
|
||||||
is_unchanged = [
|
send_count = 0
|
||||||
old_indices[local2global(i)] == new_indices[local2global(i)]
|
valid_old = old_local_expert_ids != -1
|
||||||
for i in range(num_local_experts)
|
if np.any(valid_old):
|
||||||
]
|
uniq_experts, first_idx = np.unique(
|
||||||
|
old_local_expert_ids[valid_old], return_index=True
|
||||||
|
)
|
||||||
|
filtered_rows = local_rows[valid_old]
|
||||||
|
src_rows = filtered_rows[first_idx]
|
||||||
|
send_count = int(uniq_experts.shape[0])
|
||||||
|
send_expert_ids[:send_count] = uniq_experts
|
||||||
|
send_src_rows[:send_count] = src_rows
|
||||||
|
|
||||||
# 1. Perform weight copy inside the local rank.
|
# Recv map: primary dst per unique expert needed remotely
|
||||||
is_received_locally = is_unchanged[:]
|
recv_count = 0
|
||||||
for src in range(num_local_experts):
|
need_recv_mask = np.logical_and(~is_received_locally, new_valid)
|
||||||
src_global = local2global(src)
|
if np.any(need_recv_mask):
|
||||||
for dst in range(num_local_experts):
|
desired_experts = new_local_expert_ids[need_recv_mask]
|
||||||
dst_global = local2global(dst)
|
desired_dsts = local_rows[need_recv_mask]
|
||||||
if is_received_locally[dst]:
|
uniq_recv_experts, uniq_indices = np.unique(desired_experts, return_index=True)
|
||||||
continue
|
dst_rows = desired_dsts[uniq_indices]
|
||||||
if old_indices[src_global] == -1 or new_indices[dst_global] == -1:
|
recv_count = int(uniq_recv_experts.shape[0])
|
||||||
continue
|
recv_expert_ids[:recv_count] = uniq_recv_experts
|
||||||
if old_indices[src_global] == new_indices[dst_global]:
|
recv_dst_rows[:recv_count] = dst_rows
|
||||||
is_received_locally[dst] = True
|
recv_primary_mask[dst_rows] = True
|
||||||
for weight, buffer in zip(expert_weights, expert_weights_buffer):
|
|
||||||
with torch.cuda.stream(cuda_stream):
|
eligible_local_buffer_mask = np.logical_and(~is_unchanged, is_received_locally)
|
||||||
buffer[dst].copy_(weight[src], non_blocking=True)
|
|
||||||
|
# 1. Local moves into tmp buffers
|
||||||
|
if bool(eligible_local_buffer_mask.any()) and send_count > 0:
|
||||||
|
dest_indices = np.nonzero(eligible_local_buffer_mask)[0].tolist()
|
||||||
|
expert_to_src_map = dict(
|
||||||
|
zip(send_expert_ids[:send_count], send_src_rows[:send_count])
|
||||||
|
)
|
||||||
|
for dst in dest_indices:
|
||||||
|
expert = new_local_expert_ids[dst]
|
||||||
|
src_local = expert_to_src_map.get(expert, -1)
|
||||||
|
if src_local != -1:
|
||||||
|
for w, b in zip(expert_weights, expert_weights_buffers):
|
||||||
|
b[dst].copy_(w[src_local], non_blocking=True)
|
||||||
|
|
||||||
p2p_ops: list[P2POp] = []
|
p2p_ops: list[P2POp] = []
|
||||||
|
|
||||||
# 2. Initiate sending of weights.
|
# Pre-compute global ranks mapping
|
||||||
experts_send_loc: dict[int, int] = {}
|
ep_size = ep_group.size()
|
||||||
for src in range(num_local_experts):
|
rank_to_global = {rank: get_global_rank(ep_group, rank) for rank in range(ep_size)}
|
||||||
expert = old_indices[local2global(src)]
|
|
||||||
if expert == -1:
|
|
||||||
continue
|
|
||||||
if expert in experts_send_loc:
|
|
||||||
continue
|
|
||||||
experts_send_loc[expert] = src
|
|
||||||
|
|
||||||
# We need to sort here to match send/recv
|
# 2. Post sends
|
||||||
for expert, src in sorted(experts_send_loc.items()):
|
if send_count > 0:
|
||||||
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
|
experts = send_expert_ids[:send_count]
|
||||||
expert,
|
srcs = send_src_rows[:send_count]
|
||||||
|
order = np.argsort(experts, kind="stable")
|
||||||
|
experts = experts[order]
|
||||||
|
srcs = srcs[order]
|
||||||
|
|
||||||
|
send_map, recv_map = get_ep_ranks_with_experts_batch(
|
||||||
|
experts,
|
||||||
num_local_experts,
|
num_local_experts,
|
||||||
old_indices,
|
old_indices,
|
||||||
new_indices,
|
new_indices,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Calculate the ranks to send by this rank
|
for expert, src in zip(experts.tolist(), srcs.tolist()):
|
||||||
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
|
ranks_to_send = send_map[expert]
|
||||||
sender_pos = ranks_to_send.index(ep_rank)
|
ranks_to_recv = recv_map[expert]
|
||||||
recv_begin = sender_pos * num_dst_per_sender
|
if not ranks_to_send or not ranks_to_recv:
|
||||||
recv_end = recv_begin + num_dst_per_sender
|
continue
|
||||||
recv_ranks = ranks_to_recv[recv_begin:recv_end]
|
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
|
||||||
|
sender_pos = ranks_to_send.index(ep_rank)
|
||||||
|
recv_begin = sender_pos * num_dst_per_sender
|
||||||
|
recv_end = recv_begin + num_dst_per_sender
|
||||||
|
recv_ranks = ranks_to_recv[recv_begin:recv_end]
|
||||||
|
remainder_start = len(ranks_to_send) * num_dst_per_sender
|
||||||
|
recver_pos = remainder_start + sender_pos
|
||||||
|
if recver_pos < len(ranks_to_recv):
|
||||||
|
recv_ranks.append(ranks_to_recv[recver_pos])
|
||||||
|
for dst in recv_ranks:
|
||||||
|
dst_global = rank_to_global[dst]
|
||||||
|
p2p_ops += [
|
||||||
|
P2POp(
|
||||||
|
torch.distributed.isend,
|
||||||
|
w[src],
|
||||||
|
dst_global,
|
||||||
|
)
|
||||||
|
for w in expert_weights
|
||||||
|
]
|
||||||
|
|
||||||
# Tackle remainders
|
# 3. Post recvs
|
||||||
remainder_start = len(ranks_to_send) * num_dst_per_sender
|
if recv_count > 0:
|
||||||
recver_pos = remainder_start + sender_pos
|
experts = recv_expert_ids[:recv_count]
|
||||||
if recver_pos < len(ranks_to_recv):
|
dsts = recv_dst_rows[:recv_count]
|
||||||
recv_ranks.append(ranks_to_recv[recver_pos])
|
order = np.argsort(experts, kind="stable")
|
||||||
|
experts = experts[order]
|
||||||
|
dsts = dsts[order]
|
||||||
|
|
||||||
for dst in recv_ranks:
|
send_map, recv_map = get_ep_ranks_with_experts_batch(
|
||||||
dst_global = get_global_rank(ep_group, dst)
|
experts,
|
||||||
|
num_local_experts,
|
||||||
|
old_indices,
|
||||||
|
new_indices,
|
||||||
|
)
|
||||||
|
|
||||||
|
for expert, dst in zip(experts.tolist(), dsts.tolist()):
|
||||||
|
ranks_to_send = send_map[expert]
|
||||||
|
ranks_to_recv = recv_map[expert]
|
||||||
|
if not ranks_to_send or not ranks_to_recv:
|
||||||
|
continue
|
||||||
|
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
|
||||||
|
recver_pos = ranks_to_recv.index(ep_rank)
|
||||||
|
remainder_start = len(ranks_to_send) * num_dst_per_sender
|
||||||
|
if recver_pos < remainder_start:
|
||||||
|
src = ranks_to_send[recver_pos // num_dst_per_sender]
|
||||||
|
else:
|
||||||
|
src = ranks_to_send[recver_pos - remainder_start]
|
||||||
|
src_global = rank_to_global[src]
|
||||||
p2p_ops += [
|
p2p_ops += [
|
||||||
P2POp(
|
P2POp(
|
||||||
torch.distributed.isend,
|
torch.distributed.irecv,
|
||||||
weight[src],
|
b[dst],
|
||||||
dst_global,
|
src_global,
|
||||||
)
|
)
|
||||||
for weight in expert_weights
|
for b in expert_weights_buffers
|
||||||
]
|
]
|
||||||
|
|
||||||
# 3. Initiate receiving of weights.
|
|
||||||
experts_recv_loc: dict[int, int] = {}
|
|
||||||
for dst in range(num_local_experts):
|
|
||||||
if is_received_locally[dst]:
|
|
||||||
continue
|
|
||||||
expert = new_indices[local2global(dst)]
|
|
||||||
if expert == -1:
|
|
||||||
continue
|
|
||||||
if expert in experts_recv_loc:
|
|
||||||
continue
|
|
||||||
experts_recv_loc[expert] = dst
|
|
||||||
|
|
||||||
# We need to sort here to match send/recv
|
|
||||||
for expert, dst in sorted(experts_recv_loc.items()):
|
|
||||||
ranks_to_send, ranks_to_recv = get_ep_ranks_with_expert(
|
|
||||||
expert,
|
|
||||||
num_local_experts,
|
|
||||||
old_indices,
|
|
||||||
new_indices,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Calculate the rank to recv by this rank
|
|
||||||
num_dst_per_sender = len(ranks_to_recv) // len(ranks_to_send)
|
|
||||||
recver_pos = ranks_to_recv.index(ep_rank)
|
|
||||||
remainder_start = len(ranks_to_send) * num_dst_per_sender
|
|
||||||
if recver_pos < remainder_start:
|
|
||||||
src = ranks_to_send[recver_pos // num_dst_per_sender]
|
|
||||||
else:
|
|
||||||
src = ranks_to_send[recver_pos - remainder_start]
|
|
||||||
|
|
||||||
src_global = get_global_rank(ep_group, src)
|
|
||||||
p2p_ops += [
|
|
||||||
P2POp(
|
|
||||||
torch.distributed.irecv,
|
|
||||||
weight[dst],
|
|
||||||
src_global,
|
|
||||||
)
|
|
||||||
for weight in expert_weights_buffer
|
|
||||||
]
|
|
||||||
|
|
||||||
# 4. Execute the P2P operations. The real communication happens here.
|
# 4. Execute the P2P operations. The real communication happens here.
|
||||||
if p2p_ops and cuda_stream is not None:
|
if p2p_ops and cuda_stream is not None:
|
||||||
with torch.cuda.stream(cuda_stream):
|
with torch.cuda.stream(cuda_stream):
|
||||||
@ -237,51 +342,107 @@ def move_to_buffer(
|
|||||||
for req in reqs:
|
for req in reqs:
|
||||||
req.wait()
|
req.wait()
|
||||||
# wait for the communication to finish
|
# wait for the communication to finish
|
||||||
return is_unchanged, is_received_locally, experts_recv_loc
|
return (
|
||||||
|
is_unchanged,
|
||||||
|
is_received_locally,
|
||||||
|
RecvMetadata(
|
||||||
|
recv_primary_mask=recv_primary_mask,
|
||||||
|
recv_count=recv_count,
|
||||||
|
recv_expert_ids=recv_expert_ids,
|
||||||
|
recv_dst_rows=recv_dst_rows,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def move_from_buffer(
|
def move_from_buffer(
|
||||||
expert_weights: Iterable[torch.Tensor],
|
expert_weights: Iterable[torch.Tensor],
|
||||||
expert_weights_buffer: list[torch.Tensor],
|
expert_weights_buffers: list[torch.Tensor],
|
||||||
is_unchanged: list[bool],
|
is_unchanged: np.ndarray,
|
||||||
is_received_locally: list[bool],
|
is_received_locally: np.ndarray,
|
||||||
experts_recv_loc: dict[int, int],
|
recv_metadata: RecvMetadata,
|
||||||
new_indices: Sequence[int],
|
new_indices: np.ndarray,
|
||||||
ep_group: ProcessGroup,
|
ep_rank: int,
|
||||||
) -> None:
|
) -> None:
|
||||||
ep_rank = ep_group.rank()
|
"""
|
||||||
num_local_experts = len(is_unchanged)
|
Copies expert weights from communication buffers back to the target weight tensors
|
||||||
|
after EPLB rebalancing.
|
||||||
|
|
||||||
local2global = partial(
|
Args:
|
||||||
idx_local_to_global, local_cnt=num_local_experts, ep_rank=ep_rank
|
expert_weights: List of the actual MoE layer weights used in the execution.
|
||||||
|
expert_weights_buffers: Intermediate buffers containing the experts weights
|
||||||
|
after the transfer is completed.
|
||||||
|
is_unchanged: (num_local_experts,), True where an expert row is unchanged.
|
||||||
|
is_received_locally: (num_local_experts,), True where a row is updated locally.
|
||||||
|
recv_metadata: RecvMetadata containing remote receive metadata.
|
||||||
|
new_indices: (num_experts_total,) mapping from local rows to desired
|
||||||
|
(possibly global) expert id, after rebalance.
|
||||||
|
ep_rank: Rank of the process in the expert parallel group.
|
||||||
|
"""
|
||||||
|
recv_primary_mask = recv_metadata.recv_primary_mask
|
||||||
|
recv_count = recv_metadata.recv_count
|
||||||
|
recv_expert_ids = recv_metadata.recv_expert_ids
|
||||||
|
recv_dst_rows = recv_metadata.recv_dst_rows
|
||||||
|
num_local_experts = is_unchanged.shape[0]
|
||||||
|
|
||||||
|
# Mask for rows to copy back from buffers:
|
||||||
|
# copy if locally received OR remote primary recv
|
||||||
|
copy_mask = np.logical_or(is_received_locally, recv_primary_mask)
|
||||||
|
dest_mask_np = np.logical_and(~is_unchanged, copy_mask)
|
||||||
|
if bool(dest_mask_np.any()):
|
||||||
|
dest_indices = np.nonzero(dest_mask_np)[0].tolist()
|
||||||
|
for dst in dest_indices:
|
||||||
|
for w, b in zip(expert_weights, expert_weights_buffers):
|
||||||
|
w[dst].copy_(b[dst], non_blocking=True)
|
||||||
|
|
||||||
|
if recv_count == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Duplicate remote received rows to non-primary duplicate dsts
|
||||||
|
base = ep_rank * num_local_experts
|
||||||
|
local_experts = new_indices[base + np.arange(num_local_experts, dtype=np.int32)]
|
||||||
|
duplicate_mask = np.logical_and(
|
||||||
|
np.logical_and(~is_unchanged, ~is_received_locally),
|
||||||
|
np.logical_and(~recv_primary_mask, local_experts != -1),
|
||||||
)
|
)
|
||||||
|
# All received experts are unique in the destination, so no need to copy duplicates
|
||||||
|
if not bool(duplicate_mask.any()):
|
||||||
|
return
|
||||||
|
|
||||||
for dst in range(num_local_experts):
|
dup_dst_rows = np.nonzero(duplicate_mask)[0]
|
||||||
if is_unchanged[dst]:
|
dup_experts = local_experts[dup_dst_rows]
|
||||||
continue
|
|
||||||
if is_received_locally[dst]:
|
prim_experts = recv_expert_ids[:recv_count]
|
||||||
for weight, buffer in zip(expert_weights, expert_weights_buffer):
|
prim_dsts = recv_dst_rows[:recv_count]
|
||||||
weight[dst].copy_(buffer[dst], non_blocking=True)
|
order = np.argsort(prim_experts, kind="stable")
|
||||||
else:
|
prim_experts_sorted = prim_experts[order]
|
||||||
expert = new_indices[local2global(dst)]
|
prim_dsts_sorted = prim_dsts[order]
|
||||||
if expert == -1:
|
pos = np.searchsorted(prim_experts_sorted, dup_experts)
|
||||||
continue
|
valid = np.logical_and(
|
||||||
src = experts_recv_loc[expert]
|
pos < prim_experts_sorted.shape[0],
|
||||||
for weight, buffer in zip(expert_weights, expert_weights_buffer):
|
prim_experts_sorted[np.minimum(pos, prim_experts_sorted.shape[0] - 1)]
|
||||||
weight[dst].copy_(buffer[src], non_blocking=True)
|
== dup_experts,
|
||||||
|
)
|
||||||
|
if not bool(valid.any()):
|
||||||
|
return
|
||||||
|
|
||||||
|
matched_dst_rows = dup_dst_rows[valid]
|
||||||
|
matched_src_rows = prim_dsts_sorted[pos[valid]]
|
||||||
|
|
||||||
|
for dst, src in zip(matched_dst_rows.tolist(), matched_src_rows.tolist()):
|
||||||
|
for w in expert_weights:
|
||||||
|
w[dst].copy_(w[src], non_blocking=True)
|
||||||
|
|
||||||
|
|
||||||
async def transfer_layer(
|
async def transfer_layer(
|
||||||
old_global_expert_indices: torch.Tensor,
|
old_layer_indices: torch.Tensor,
|
||||||
new_global_expert_indices: torch.Tensor,
|
new_layer_indices: torch.Tensor,
|
||||||
expert_weights: Sequence[Iterable[torch.Tensor]],
|
expert_weights: Iterable[torch.Tensor],
|
||||||
expert_weights_buffer: Sequence[torch.Tensor],
|
expert_weights_buffer: Sequence[torch.Tensor],
|
||||||
ep_group: ProcessGroup,
|
ep_group: ProcessGroup,
|
||||||
is_profile: bool = False,
|
is_profile: bool = False,
|
||||||
layer: int = 0,
|
|
||||||
cuda_stream: torch.cuda.Stream | None = None,
|
cuda_stream: torch.cuda.Stream | None = None,
|
||||||
rank_mapping: dict[int, int] | None = None,
|
rank_mapping: dict[int, int] | None = None,
|
||||||
) -> tuple[list[bool], list[bool], dict[int, int]]:
|
) -> MoveToBufferResult:
|
||||||
"""
|
"""
|
||||||
Rearranges the expert weights in place according to the new expert indices.
|
Rearranges the expert weights in place according to the new expert indices.
|
||||||
|
|
||||||
@ -289,50 +450,68 @@ async def transfer_layer(
|
|||||||
while keys are physical.
|
while keys are physical.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
old_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
|
old_layer_indices: Shape (num_physical_experts,).
|
||||||
new_global_expert_indices: Shape (num_moe_layers, num_physical_experts).
|
new_layer_indices: Shape (num_physical_experts,).
|
||||||
expert_weights: A sequence of shape (num_moe_layers)(weight_count)
|
expert_weights: Iterable of weight tensors for this layer, each with shape
|
||||||
of tensors of shape (num_local_physical_experts, hidden_size_i).
|
(num_local_physical_experts, hidden_size_i).
|
||||||
For example, a linear layer may have up and down projection,
|
For example, a linear layer may have up and down projection.
|
||||||
so weight_count = 2. Each weight's hidden size can be different.
|
expert_weights_buffer: Intermediate buffers (one per weight tensor).
|
||||||
ep_group: The device process group for expert parallelism.
|
ep_group: The device process group for expert parallelism.
|
||||||
is_profile (bool): If `True`, do not perform any actual weight copy.
|
is_profile (bool): If `True`, do not perform any actual weight copy.
|
||||||
This is used during profile run, where we only perform dummy
|
This is used during profile run, where we only perform dummy
|
||||||
communications to reserve enough memory for the buffers.
|
communications to reserve enough memory for the buffers.
|
||||||
|
cuda_stream: CUDA stream for async copies (can be None for sync mode).
|
||||||
|
rank_mapping: Optional rank mapping for elastic expert parallelism.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
is_unchanged (np.ndarray): (num_local_experts,), True where expert
|
||||||
|
is left unchanged.
|
||||||
|
is_received_locally (np.ndarray): (num_local_experts,), True where expert
|
||||||
|
can be received locally.
|
||||||
|
RecvMetadata: Metadata needed for completing remote weight transfers.
|
||||||
"""
|
"""
|
||||||
ep_size = ep_group.size()
|
ep_size = ep_group.size()
|
||||||
if rank_mapping is not None:
|
if rank_mapping is not None:
|
||||||
|
# Add a layer dimension for compatibility with mapping functions
|
||||||
|
old_layer_indices_2d = old_layer_indices.unsqueeze(0)
|
||||||
|
new_layer_indices_2d = new_layer_indices.unsqueeze(0)
|
||||||
|
|
||||||
if len(rank_mapping) == ep_group.size():
|
if len(rank_mapping) == ep_group.size():
|
||||||
# scale down
|
# scale down
|
||||||
new_global_expert_indices = _map_new_expert_indices_with_rank_mapping(
|
new_layer_indices_2d = _map_new_expert_indices_with_rank_mapping(
|
||||||
new_global_expert_indices,
|
new_layer_indices_2d,
|
||||||
rank_mapping,
|
rank_mapping,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# scale up
|
# scale up
|
||||||
old_global_expert_indices = _map_old_expert_indices_with_rank_mapping(
|
old_layer_indices_2d = _map_old_expert_indices_with_rank_mapping(
|
||||||
old_global_expert_indices,
|
old_layer_indices_2d,
|
||||||
rank_mapping,
|
rank_mapping,
|
||||||
ep_group.size(),
|
ep_group.size(),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert old_global_expert_indices.shape[1] == new_global_expert_indices.shape[1]
|
# Remove the layer dimension
|
||||||
num_moe_layers, num_physical_experts = old_global_expert_indices.shape
|
old_layer_indices = old_layer_indices_2d.squeeze(0)
|
||||||
assert len(expert_weights) == num_moe_layers
|
new_layer_indices = new_layer_indices_2d.squeeze(0)
|
||||||
num_local_physical_experts = next(iter(expert_weights[0])).shape[0]
|
|
||||||
assert new_global_expert_indices.shape == (num_moe_layers, num_physical_experts)
|
assert old_layer_indices.shape == new_layer_indices.shape
|
||||||
|
num_physical_experts = old_layer_indices.shape[0]
|
||||||
|
num_local_physical_experts = next(iter(expert_weights)).shape[0]
|
||||||
assert num_physical_experts == ep_size * num_local_physical_experts
|
assert num_physical_experts == ep_size * num_local_physical_experts
|
||||||
|
|
||||||
is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
|
old_layer_indices_np = old_layer_indices.cpu().numpy()
|
||||||
|
new_layer_indices_np = new_layer_indices.cpu().numpy()
|
||||||
|
|
||||||
|
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
|
||||||
num_local_experts=num_local_physical_experts,
|
num_local_experts=num_local_physical_experts,
|
||||||
old_indices=old_global_expert_indices[layer].tolist(),
|
old_indices=old_layer_indices_np,
|
||||||
new_indices=new_global_expert_indices[layer].tolist(),
|
new_indices=new_layer_indices_np,
|
||||||
expert_weights=expert_weights[layer],
|
expert_weights=expert_weights,
|
||||||
expert_weights_buffer=expert_weights_buffer,
|
expert_weights_buffers=expert_weights_buffer,
|
||||||
cuda_stream=cuda_stream,
|
cuda_stream=cuda_stream,
|
||||||
ep_group=ep_group,
|
ep_group=ep_group,
|
||||||
)
|
)
|
||||||
return is_unchanged, is_received_locally, experts_recv_loc
|
return is_unchanged, is_received_locally, recv_metadata
|
||||||
|
|
||||||
|
|
||||||
def rearrange_expert_weights_inplace(
|
def rearrange_expert_weights_inplace(
|
||||||
@ -388,19 +567,17 @@ def rearrange_expert_weights_inplace(
|
|||||||
ep_size = ep_group.size()
|
ep_size = ep_group.size()
|
||||||
assert num_physical_experts == ep_size * num_local_physical_experts
|
assert num_physical_experts == ep_size * num_local_physical_experts
|
||||||
|
|
||||||
# A buffer to hold the expert weights in one layer during the exchange.
|
first_layer_weights = list(expert_weights[0])
|
||||||
|
# Buffers to hold the expert weights during the exchange.
|
||||||
# NOTE: Currently we assume the same weights across different layers
|
# NOTE: Currently we assume the same weights across different layers
|
||||||
# have the same shape.
|
# have the same shape.
|
||||||
expert_weights_buffer = [torch.empty_like(w) for w in expert_weights[0]]
|
weights_buffer: list[torch.Tensor] = [
|
||||||
|
torch.empty_like(w) for w in first_layer_weights
|
||||||
|
]
|
||||||
if is_profile:
|
if is_profile:
|
||||||
# Maximum send size is to send all local experts to all ranks,
|
# Reserve communication buffers via a minimal dummy all_gather on first layer
|
||||||
# So we use a dummy `all_gather` to reserve enough communication buffer
|
for weight, buffer in zip(expert_weights[0], weights_buffer):
|
||||||
for weight, buffer in zip(expert_weights[0], expert_weights_buffer):
|
|
||||||
# A `/dev/null`-like buffer to avoid real memory allocation
|
|
||||||
dummy_recv_buffer = [buffer for _ in range(ep_size)]
|
dummy_recv_buffer = [buffer for _ in range(ep_size)]
|
||||||
# NOTE(bowen): Needed this barrier to avoid OOM during actual
|
|
||||||
# execution. I'm not very sure why this is needed
|
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
all_gather(
|
all_gather(
|
||||||
dummy_recv_buffer,
|
dummy_recv_buffer,
|
||||||
@ -409,32 +586,32 @@ def rearrange_expert_weights_inplace(
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
old_global_expert_indices_cpu = old_global_expert_indices.cpu()
|
|
||||||
new_global_expert_indices_cpu = new_global_expert_indices.cpu()
|
|
||||||
|
|
||||||
# NOTE(bowen): We need this synchronize to run, but I don't know why.
|
# NOTE(bowen): We need this synchronize to run, but I don't know why.
|
||||||
# If you figure out the reason, please let me know -- thank you!
|
# If you figure out the reason, please let me know -- thank you!
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
for layer in range(num_moe_layers):
|
old_global_expert_indices_cpu = old_global_expert_indices.cpu().numpy()
|
||||||
is_unchanged, is_received_locally, experts_recv_loc = move_to_buffer(
|
new_global_expert_indices_cpu = new_global_expert_indices.cpu().numpy()
|
||||||
|
|
||||||
|
for layer_idx in range(num_moe_layers):
|
||||||
|
is_unchanged, is_received_locally, recv_metadata = move_to_buffer(
|
||||||
num_local_experts=num_local_physical_experts,
|
num_local_experts=num_local_physical_experts,
|
||||||
old_indices=old_global_expert_indices_cpu[layer].tolist(),
|
old_indices=old_global_expert_indices_cpu[layer_idx],
|
||||||
new_indices=new_global_expert_indices_cpu[layer].tolist(),
|
new_indices=new_global_expert_indices_cpu[layer_idx],
|
||||||
expert_weights=expert_weights[layer],
|
expert_weights=expert_weights[layer_idx],
|
||||||
expert_weights_buffer=expert_weights_buffer,
|
expert_weights_buffers=weights_buffer,
|
||||||
cuda_stream=None,
|
cuda_stream=None,
|
||||||
ep_group=ep_group,
|
ep_group=ep_group,
|
||||||
)
|
)
|
||||||
|
|
||||||
move_from_buffer(
|
move_from_buffer(
|
||||||
expert_weights=expert_weights[layer],
|
expert_weights=expert_weights[layer_idx],
|
||||||
expert_weights_buffer=expert_weights_buffer,
|
expert_weights_buffers=weights_buffer,
|
||||||
is_unchanged=is_unchanged,
|
is_unchanged=is_unchanged,
|
||||||
is_received_locally=is_received_locally,
|
is_received_locally=is_received_locally,
|
||||||
experts_recv_loc=experts_recv_loc,
|
recv_metadata=recv_metadata,
|
||||||
new_indices=new_global_expert_indices[layer].tolist(),
|
new_indices=new_global_expert_indices_cpu[layer_idx],
|
||||||
ep_group=ep_group,
|
ep_rank=ep_group.rank(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -526,4 +703,4 @@ def _map_new_expert_indices_with_rank_mapping(
|
|||||||
return mapped_expert_indices
|
return mapped_expert_indices
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["transfer_layer", "move_from_buffer"]
|
__all__ = ["transfer_layer", "move_from_buffer", "RecvMetadata"]
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user