Abstract eplb algo (#26471)

Signed-off-by: Che Ruan <cr623@ic.ac.uk>
Signed-off-by: mengxingkongzhouhan <117415539+mengxingkongzhouhan@users.noreply.github.com>
Signed-off-by: Mercykid-bash <ruanche0218@gmail.com>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Che Ruan <cr623@ic.ac.uk>
Co-authored-by: mengxingkongzhouhan <117415539+mengxingkongzhouhan@users.noreply.github.com>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Mercykid-bash 2025-12-05 03:09:09 +08:00 committed by GitHub
parent e10c84e06a
commit 1119f6e47a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 364 additions and 285 deletions

View File

@ -4,7 +4,7 @@
import pytest
import torch
from vllm.distributed.eplb.rebalance_algo import rebalance_experts
from vllm.distributed.eplb.policy.default import DefaultEplbPolicy
def test_basic_rebalance():
@ -23,7 +23,7 @@ def test_basic_rebalance():
num_nodes = 2
num_gpus = 8
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
@ -77,7 +77,7 @@ def test_single_gpu_case():
num_nodes = 1
num_gpus = 1
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
@ -99,7 +99,7 @@ def test_equal_weights():
num_nodes = 2
num_gpus = 4
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
@ -122,7 +122,7 @@ def test_extreme_weight_imbalance():
num_nodes = 2
num_gpus = 4
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
@ -150,7 +150,7 @@ def test_multiple_layers():
num_nodes = 2
num_gpus = 4
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
@ -175,14 +175,14 @@ def test_parameter_validation():
# Test non-divisible case - this should handle normally without throwing
# errors because the function will fall back to global load balancing
# strategy
phy2log, log2phy, logcnt = rebalance_experts(weight, 8, 3, 2, 4)
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4)
assert phy2log.shape == (1, 8)
assert logcnt.shape == (1, 4)
# Test cases that will actually cause errors:
# num_physical_experts not divisible by num_gpus
with pytest.raises(AssertionError):
rebalance_experts(weight, 7, 2, 2, 4) # 7 not divisible by 4
DefaultEplbPolicy.rebalance_experts(weight, 7, 2, 2, 4) # 7 not divisible by 4
def test_small_scale_hierarchical():
@ -197,7 +197,7 @@ def test_small_scale_hierarchical():
num_nodes = 2 # 2 nodes
num_gpus = 4 # 4 GPUs
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
@ -224,7 +224,7 @@ def test_global_load_balance_fallback():
num_nodes = 2
num_gpus = 4
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
@ -246,7 +246,7 @@ def test_device_compatibility(device):
num_nodes = 1
num_gpus = 2
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
@ -263,7 +263,9 @@ def test_additional_cases():
weight1 = torch.tensor(
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
)
phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8)
phy2log1, log2phy1, logcnt1 = DefaultEplbPolicy.rebalance_experts(
weight1, 24, 8, 4, 8
)
assert phy2log1.shape == (1, 24)
assert logcnt1.shape == (1, 16)
@ -276,7 +278,9 @@ def test_additional_cases():
[12, 25, 50, 100, 150, 200], # Increasing weights
]
)
phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2)
phy2log2, log2phy2, logcnt2 = DefaultEplbPolicy.rebalance_experts(
weight2, 10, 3, 1, 2
)
assert phy2log2.shape == (2, 10)
assert logcnt2.shape == (2, 6)
@ -300,7 +304,7 @@ if __name__ == "__main__":
num_nodes = 2
num_gpus = 8
phy2log, log2phy, logcnt = rebalance_experts(
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
print(phy2log)

View File

@ -35,6 +35,7 @@ logger = init_logger(__name__)
ExpertPlacementStrategy = Literal["linear", "round_robin"]
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
DataParallelBackend = Literal["ray", "mp"]
EPLBPolicyOption = Literal["default"]
@config
@ -65,6 +66,9 @@ class EPLBConfig:
Whether to use non-blocking EPLB.
"""
policy: EPLBPolicyOption = "default"
"""The policy type for expert parallel load balancing (EPLB)."""
@config
@dataclass

View File

@ -1,8 +1,3 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB).
"""
from .eplb_state import *
from .rebalance_algo import *
"""Expert parallelism load balancer (EPLB)."""

View File

@ -45,7 +45,7 @@ from vllm.logger import init_logger
from vllm.model_executor.models.interfaces import MixtureOfExperts
from .async_worker import start_async_worker
from .rebalance_algo import rebalance_experts
from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy
from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace
logger = init_logger(__name__)
@ -213,18 +213,23 @@ class EplbState:
self.parallel_config = parallel_config
self.device = device
self.model_states: dict[str, EplbModelState] = {}
self.policy: type[AbstractEplbPolicy] = DefaultEplbPolicy
"""
Selected EPLB algorithm class
"""
self.expert_load_window_step: int = 0
"""
Current step in the sliding window.
Different from `expert_rearrangement_step`,
each EP rank may have its own `expert_load_window_step`.
"""
self.expert_load_window_step: int = 0
self.expert_load_window_size: int = 0
"""
Size of the expert load sliding window.
This is a constant and is taken from the config.
"""
self.expert_load_window_size: int = 0
self.expert_rearrangement_step: int = 0
"""
Steps after last rearrangement.
Will trigger a rearrangement if it exceeds the threshold.
@ -415,6 +420,10 @@ class EplbState:
)
self.expert_rearrangement_step_interval = eplb_step_interval
# Set the policy based on the selected eplb algorithm type.
policy_type = self.parallel_config.eplb_config.policy
self.policy = EPLB_POLICIES[policy_type]
logger.debug("Selected EPLB policy: %d", policy_type)
if global_expert_load is not None:
ep_group = get_ep_group().device_group
assert global_expert_load.shape == (
@ -441,7 +450,7 @@ class EplbState:
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = rebalance_experts(
) = self.policy.rebalance_experts(
global_expert_load,
num_replicas,
num_groups,
@ -776,6 +785,7 @@ class EplbState:
f"{num_gpus=}, {num_nodes=}"
)
# Get new expert mappings
for eplb_model_state, global_expert_load_window in zip(
self.model_states.values(), global_expert_load_windows
):
@ -784,7 +794,7 @@ class EplbState:
new_physical_to_logical_map,
new_logical_to_physical_map,
new_logical_replica_count,
) = rebalance_experts(
) = self.policy.rebalance_experts(
global_expert_load_window,
num_replicas,
num_groups,

View File

@ -0,0 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import get_args
from vllm.config.parallel import EPLBPolicyOption
from .abstract import AbstractEplbPolicy
from .default import DefaultEplbPolicy
EPLB_POLICIES = {"default": DefaultEplbPolicy}
# Ensure that the EPLB_POLICIES keys match the EPLBPolicyOption values
assert set(EPLB_POLICIES.keys()) == set(get_args(EPLBPolicyOption))
__all__ = [
"AbstractEplbPolicy",
"DefaultEplbPolicy",
"EPLB_POLICIES",
]

View File

@ -0,0 +1,40 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
class AbstractEplbPolicy(ABC):
@classmethod
@abstractmethod
def rebalance_experts(
cls,
weight: torch.Tensor,
num_replicas: int,
num_groups: int,
num_nodes: int,
num_ranks: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Entry point for expert-parallelism load balancer.
Parameters:
weight: [layers, num_logical_experts], the load statistics
for all logical experts
num_replicas: number of physical experts, must be a multiple of
`num_ranks`
num_groups: number of expert groups
num_nodes: number of server nodes
num_ranks: number of ranks, must be a multiple of `num_nodes`
Returns:
physical_to_logical_map: [layers, num_replicas], the expert
index of each replica
logical_to_physical_map: [layers, num_logical_experts, X],
the replica indices for each expert
expert_count: [layers, num_logical_experts], number of
physical replicas for each logical expert
"""
raise NotImplementedError

View File

@ -0,0 +1,267 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB) for vLLM.
This module implements the core rearrangement algorithm.
The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
on how the EPLB algorithm works.
"""
import numpy as np
import torch
from .abstract import AbstractEplbPolicy
class DefaultEplbPolicy(AbstractEplbPolicy):
@classmethod
def balanced_packing(
cls, weight: torch.Tensor, num_packs: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Pack n weighted objects to m packs, such that each bin contains exactly
n/m objects and the weights of all packs are as balanced as possible.
Parameters:
weight: [X, n], the weight of each item
num_packs: number of packs
Returns:
pack_index: [X, n], the pack index of each item
rank_in_pack: [X, n], the rank of the item in the pack
"""
num_layers, num_groups = weight.shape
assert num_groups % num_packs == 0
groups_per_pack = num_groups // num_packs
device = weight.device
if groups_per_pack == 1:
pack_index = torch.arange(
weight.size(-1), dtype=torch.int64, device=device
).expand(weight.shape)
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
return pack_index, rank_in_pack
weight_np = weight.cpu().numpy()
# Sort and get indices in decending order
indices_np = np.argsort(-weight_np, axis=-1)
pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
# Run the packing algorithm
for i in range(num_layers):
pack_weights = [0.0] * num_packs
pack_items = [0] * num_packs
for group in indices_np[i]:
# Find a pack with capacity that has the lowest weight
pack = min(
(j for j in range(num_packs) if pack_items[j] < groups_per_pack),
key=pack_weights.__getitem__,
)
assert pack_items[pack] < groups_per_pack
pack_index_np[i, group] = pack
rank_in_pack_np[i, group] = pack_items[pack]
pack_weights[pack] += weight_np[i, group]
pack_items[pack] += 1
pack_index = torch.from_numpy(pack_index_np).to(device)
rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
return pack_index, rank_in_pack
@classmethod
def replicate_experts(
cls, weight: torch.Tensor, num_phy: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
load of all replicas is minimized.
Parameters:
weight: [X, num_log]
num_phy: total number of experts after replication
Returns:
phy2log: [X, num_phy], logical expert id of each physical expert
rank: [X, num_phy], the replica rank
logcnt: [X, num_log], number of replicas for each logical expert
"""
n, num_log = weight.shape
num_redundant = num_phy - num_log
assert num_redundant >= 0
device = weight.device
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
arangen = torch.arange(n, dtype=torch.int64, device=device)
for i in range(num_log, num_phy):
redundant_indices = (weight / logcnt).max(dim=-1).indices
phy2log[:, i] = redundant_indices
rank[:, i] = logcnt[arangen, redundant_indices]
logcnt[arangen, redundant_indices] += 1
return phy2log, rank, logcnt
@classmethod
def rebalance_experts_hierarchical(
cls,
weight: torch.Tensor,
num_physical_experts: int,
num_groups: int,
num_nodes: int,
num_gpus: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Parameters:
weight: [num_moe_layers, num_logical_experts]
num_physical_experts: number of physical experts after replication
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g, NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
phy2log: [layers, num_replicas], the expert
index of each replica
log2phy: [layers, num_logical_experts, X],
the replica indices for each expert
logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert
"""
num_layers, num_logical_experts = weight.shape
assert num_logical_experts % num_groups == 0
group_size = num_logical_experts // num_groups
assert num_groups % num_nodes == 0
groups_per_node = num_groups // num_nodes
assert num_gpus % num_nodes == 0
assert num_physical_experts % num_gpus == 0
phy_experts_per_gpu = num_physical_experts // num_gpus
def inverse(perm: torch.Tensor) -> torch.Tensor:
inv = torch.empty_like(perm)
inv.scatter_(
1,
perm,
torch.arange(
perm.size(1), dtype=torch.int64, device=perm.device
).expand(perm.shape),
)
return inv
# Step 1: pack groups to nodes
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
group_pack_index, group_rank_in_pack = cls.balanced_packing(
tokens_per_group, num_nodes
)
log2mlog = (
(
(group_pack_index * groups_per_node + group_rank_in_pack) * group_size
).unsqueeze(-1)
+ torch.arange(
group_size, dtype=torch.int64, device=group_pack_index.device
)
).flatten(-2)
mlog2log = inverse(log2mlog)
# Step 2: construct redundant experts within nodes
# [num_layers * num_nodes, num_logical_experts // num_nodes]
tokens_per_mlog = weight.gather(-1, mlog2log).view(
-1, num_logical_experts // num_nodes
)
phy2mlog, phyrank, mlogcnt = cls.replicate_experts(
tokens_per_mlog, num_physical_experts // num_nodes
)
# Step 3: pack physical_experts to GPUs
# [num_layers * num_nodes, num_physical_experts // num_nodes]
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
pack_index, rank_in_pack = cls.balanced_packing(
tokens_per_phy, num_gpus // num_nodes
)
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
pphy2phy = inverse(phy2pphy)
pphy2mlog = phy2mlog.gather(
-1, pphy2phy
) # [num_layers * num_nodes, num_log_per_nodes]
pphy2mlog = (
pphy2mlog.view(num_layers, num_nodes, -1)
+ torch.arange(
0,
num_logical_experts,
num_logical_experts // num_nodes,
device=group_pack_index.device,
).view(1, -1, 1)
).flatten(-2)
pphy2log = mlog2log.gather(-1, pphy2mlog)
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
return pphy2log, pphyrank, logcnt
@classmethod
def rebalance_experts(
cls,
weight: torch.Tensor,
num_replicas: int,
num_groups: int,
num_nodes: int,
num_ranks: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Entry point for expert-parallelism load balancer.
Parameters:
weight: [layers, num_logical_experts], the load statistics for all
logical experts
num_replicas: number of physical experts, must be a multiple of
`num_gpus`
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g, NVLink) is faster
num_ranks: number of ranks, must be a multiple of `num_nodes`
Returns:
phy2log: [layers, num_replicas], the expert
index of each replica
log2phy: [layers, num_logical_experts, X],
the replica indices for each expert
logcnt: [layers, num_logical_experts], number of
physical replicas for each logical expert
"""
num_layers, num_logical_experts = weight.shape
weight = weight.float()
if num_groups % num_nodes == 0:
# use hierarchical load-balance policy
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
weight, num_replicas, num_groups, num_nodes, num_ranks
)
else:
# use global load-balance policy
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
weight, num_replicas, 1, 1, num_ranks
)
num_redundant_experts = num_replicas - num_logical_experts
maxlogcnt = num_redundant_experts + 1
log2phy: torch.Tensor = torch.full(
(num_layers, num_logical_experts, maxlogcnt),
-1,
dtype=torch.int64,
device=logcnt.device,
)
log2phy.view(num_layers, -1).scatter_(
-1,
phy2log * maxlogcnt + phyrank,
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
num_layers, -1
),
)
return phy2log, log2phy, logcnt

View File

@ -1,260 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Expert parallelism load balancer (EPLB) for vLLM.
This module implements the core rearrangement algorithm.
The rearrangement algorithm is adapted from
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
on how the EPLB algorithm works.
"""
import numpy as np
import torch
def balanced_packing(
weight: torch.Tensor, num_packs: int
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Pack n weighted objects to m packs, such that each bin contains exactly
n/m objects and the weights of all packs are as balanced as possible.
Parameters:
weight: [X, n], the weight of each item
num_packs: number of packs
Returns:
pack_index: [X, n], the pack index of each item
rank_in_pack: [X, n], the rank of the item in the pack
"""
num_layers, num_groups = weight.shape
assert num_groups % num_packs == 0
groups_per_pack = num_groups // num_packs
device = weight.device
if groups_per_pack == 1:
pack_index = torch.arange(
weight.size(-1), dtype=torch.int64, device=device
).expand(weight.shape)
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
return pack_index, rank_in_pack
weight_np = weight.cpu().numpy()
# Sort and get indices in decending order
indices_np = np.argsort(-weight_np, axis=-1)
pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
# Run the packing algorithm
for i in range(num_layers):
pack_weights = [0.0] * num_packs
pack_items = [0] * num_packs
for group in indices_np[i]:
# Find a pack with capacity that has the lowest weight
pack = min(
(j for j in range(num_packs) if pack_items[j] < groups_per_pack),
key=pack_weights.__getitem__,
)
assert pack_items[pack] < groups_per_pack
pack_index_np[i, group] = pack
rank_in_pack_np[i, group] = pack_items[pack]
pack_weights[pack] += weight_np[i, group]
pack_items[pack] += 1
pack_index = torch.from_numpy(pack_index_np).to(device)
rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
return pack_index, rank_in_pack
def replicate_experts(
weight: torch.Tensor, num_phy: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
load of all replicas is minimized.
Parameters:
weight: [X, num_log]
num_phy: total number of experts after replication
Returns:
phy2log: [X, num_phy], logical expert id of each physical expert
rank: [X, num_phy], the replica rank
logcnt: [X, num_log], number of replicas for each logical expert
"""
n, num_log = weight.shape
num_redundant = num_phy - num_log
assert num_redundant >= 0
device = weight.device
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
arangen = torch.arange(n, dtype=torch.int64, device=device)
for i in range(num_log, num_phy):
redundant_indices = (weight / logcnt).max(dim=-1).indices
phy2log[:, i] = redundant_indices
rank[:, i] = logcnt[arangen, redundant_indices]
logcnt[arangen, redundant_indices] += 1
return phy2log, rank, logcnt
def rebalance_experts_hierarchical(
weight: torch.Tensor,
num_physical_experts: int,
num_groups: int,
num_nodes: int,
num_gpus: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Parameters:
weight: [num_moe_layers, num_logical_experts]
num_physical_experts: number of physical experts after replication
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g., NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
physical_to_logical_map (torch.Tensor):
[num_moe_layers, num_physical_experts]
logical_to_physical_map (torch.Tensor):
[num_moe_layers, num_logical_experts, X]
logical_count (torch.Tensor):
[num_moe_layers, num_logical_experts]
"""
num_layers, num_logical_experts = weight.shape
assert num_logical_experts % num_groups == 0
group_size = num_logical_experts // num_groups
assert num_groups % num_nodes == 0
groups_per_node = num_groups // num_nodes
assert num_gpus % num_nodes == 0
assert num_physical_experts % num_gpus == 0
phy_experts_per_gpu = num_physical_experts // num_gpus
def inverse(perm: torch.Tensor) -> torch.Tensor:
inv = torch.empty_like(perm)
inv.scatter_(
1,
perm,
torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(
perm.shape
),
)
return inv
# Step 1: pack groups to nodes
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
log2mlog = (
(
(group_pack_index * groups_per_node + group_rank_in_pack) * group_size
).unsqueeze(-1)
+ torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)
).flatten(-2)
mlog2log = inverse(log2mlog)
# Step 2: construct redundant experts within nodes
# [num_layers * num_nodes, num_logical_experts // num_nodes]
tokens_per_mlog = weight.gather(-1, mlog2log).view(
-1, num_logical_experts // num_nodes
)
phy2mlog, phyrank, mlogcnt = replicate_experts(
tokens_per_mlog, num_physical_experts // num_nodes
)
# Step 3: pack physical_experts to GPUs
# [num_layers * num_nodes, num_physical_experts // num_nodes]
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
pphy2phy = inverse(phy2pphy)
pphy2mlog = phy2mlog.gather(
-1, pphy2phy
) # [num_layers * num_nodes, num_log_per_nodes]
pphy2mlog = (
pphy2mlog.view(num_layers, num_nodes, -1)
+ torch.arange(
0,
num_logical_experts,
num_logical_experts // num_nodes,
device=group_pack_index.device,
).view(1, -1, 1)
).flatten(-2)
pphy2log = mlog2log.gather(-1, pphy2mlog)
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
return pphy2log, pphyrank, logcnt
def rebalance_experts(
weight: torch.Tensor,
num_replicas: int,
num_groups: int,
num_nodes: int,
num_gpus: int,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Entry point for expert-parallelism load balancer.
Parameters:
weight: [layers, num_logical_experts], the load statistics for all
logical experts
num_replicas: number of physical experts, must be a multiple of
`num_gpus`
num_groups: number of expert groups
num_nodes: number of server nodes, where the intra-node network
(e.g, NVLink) is faster
num_gpus: number of GPUs, must be a multiple of `num_nodes`
Returns:
physical_to_logical_map:
[layers, num_replicas], the expert index of each replica
logical_to_physical_map:
[layers, num_logical_experts, X], the replica indices for each
expert
expert_count:
[layers, num_logical_experts], number of physical
replicas for each logical expert
"""
num_layers, num_logical_experts = weight.shape
weight = weight.float()
if num_groups % num_nodes == 0:
# use hierarchical load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
weight, num_replicas, num_groups, num_nodes, num_gpus
)
else:
# use global load-balance policy
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
weight, num_replicas, 1, 1, num_gpus
)
num_redundant_experts = num_replicas - num_logical_experts
maxlogcnt = num_redundant_experts + 1
log2phy: torch.Tensor = torch.full(
(num_layers, num_logical_experts, maxlogcnt),
-1,
dtype=torch.int64,
device=logcnt.device,
)
log2phy.view(num_layers, -1).scatter_(
-1,
phy2log * maxlogcnt + phyrank,
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
num_layers, -1
),
)
return phy2log, log2phy, logcnt
__all__ = ["rebalance_experts"]