mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-10 08:04:58 +08:00
Abstract eplb algo (#26471)
Signed-off-by: Che Ruan <cr623@ic.ac.uk> Signed-off-by: mengxingkongzhouhan <117415539+mengxingkongzhouhan@users.noreply.github.com> Signed-off-by: Mercykid-bash <ruanche0218@gmail.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Che Ruan <cr623@ic.ac.uk> Co-authored-by: mengxingkongzhouhan <117415539+mengxingkongzhouhan@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
parent
e10c84e06a
commit
1119f6e47a
@ -4,7 +4,7 @@
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.distributed.eplb.rebalance_algo import rebalance_experts
|
||||
from vllm.distributed.eplb.policy.default import DefaultEplbPolicy
|
||||
|
||||
|
||||
def test_basic_rebalance():
|
||||
@ -23,7 +23,7 @@ def test_basic_rebalance():
|
||||
num_nodes = 2
|
||||
num_gpus = 8
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
@ -77,7 +77,7 @@ def test_single_gpu_case():
|
||||
num_nodes = 1
|
||||
num_gpus = 1
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
@ -99,7 +99,7 @@ def test_equal_weights():
|
||||
num_nodes = 2
|
||||
num_gpus = 4
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
@ -122,7 +122,7 @@ def test_extreme_weight_imbalance():
|
||||
num_nodes = 2
|
||||
num_gpus = 4
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
@ -150,7 +150,7 @@ def test_multiple_layers():
|
||||
num_nodes = 2
|
||||
num_gpus = 4
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
@ -175,14 +175,14 @@ def test_parameter_validation():
|
||||
# Test non-divisible case - this should handle normally without throwing
|
||||
# errors because the function will fall back to global load balancing
|
||||
# strategy
|
||||
phy2log, log2phy, logcnt = rebalance_experts(weight, 8, 3, 2, 4)
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4)
|
||||
assert phy2log.shape == (1, 8)
|
||||
assert logcnt.shape == (1, 4)
|
||||
|
||||
# Test cases that will actually cause errors:
|
||||
# num_physical_experts not divisible by num_gpus
|
||||
with pytest.raises(AssertionError):
|
||||
rebalance_experts(weight, 7, 2, 2, 4) # 7 not divisible by 4
|
||||
DefaultEplbPolicy.rebalance_experts(weight, 7, 2, 2, 4) # 7 not divisible by 4
|
||||
|
||||
|
||||
def test_small_scale_hierarchical():
|
||||
@ -197,7 +197,7 @@ def test_small_scale_hierarchical():
|
||||
num_nodes = 2 # 2 nodes
|
||||
num_gpus = 4 # 4 GPUs
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
@ -224,7 +224,7 @@ def test_global_load_balance_fallback():
|
||||
num_nodes = 2
|
||||
num_gpus = 4
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
@ -246,7 +246,7 @@ def test_device_compatibility(device):
|
||||
num_nodes = 1
|
||||
num_gpus = 2
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
|
||||
@ -263,7 +263,9 @@ def test_additional_cases():
|
||||
weight1 = torch.tensor(
|
||||
[[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
|
||||
)
|
||||
phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8)
|
||||
phy2log1, log2phy1, logcnt1 = DefaultEplbPolicy.rebalance_experts(
|
||||
weight1, 24, 8, 4, 8
|
||||
)
|
||||
|
||||
assert phy2log1.shape == (1, 24)
|
||||
assert logcnt1.shape == (1, 16)
|
||||
@ -276,7 +278,9 @@ def test_additional_cases():
|
||||
[12, 25, 50, 100, 150, 200], # Increasing weights
|
||||
]
|
||||
)
|
||||
phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2)
|
||||
phy2log2, log2phy2, logcnt2 = DefaultEplbPolicy.rebalance_experts(
|
||||
weight2, 10, 3, 1, 2
|
||||
)
|
||||
|
||||
assert phy2log2.shape == (2, 10)
|
||||
assert logcnt2.shape == (2, 6)
|
||||
@ -300,7 +304,7 @@ if __name__ == "__main__":
|
||||
num_nodes = 2
|
||||
num_gpus = 8
|
||||
|
||||
phy2log, log2phy, logcnt = rebalance_experts(
|
||||
phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
print(phy2log)
|
||||
|
||||
@ -35,6 +35,7 @@ logger = init_logger(__name__)
|
||||
ExpertPlacementStrategy = Literal["linear", "round_robin"]
|
||||
DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
|
||||
DataParallelBackend = Literal["ray", "mp"]
|
||||
EPLBPolicyOption = Literal["default"]
|
||||
|
||||
|
||||
@config
|
||||
@ -65,6 +66,9 @@ class EPLBConfig:
|
||||
Whether to use non-blocking EPLB.
|
||||
"""
|
||||
|
||||
policy: EPLBPolicyOption = "default"
|
||||
"""The policy type for expert parallel load balancing (EPLB)."""
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
|
||||
@ -1,8 +1,3 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Expert parallelism load balancer (EPLB).
|
||||
"""
|
||||
|
||||
from .eplb_state import *
|
||||
from .rebalance_algo import *
|
||||
"""Expert parallelism load balancer (EPLB)."""
|
||||
|
||||
@ -45,7 +45,7 @@ from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.interfaces import MixtureOfExperts
|
||||
|
||||
from .async_worker import start_async_worker
|
||||
from .rebalance_algo import rebalance_experts
|
||||
from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy
|
||||
from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -213,18 +213,23 @@ class EplbState:
|
||||
self.parallel_config = parallel_config
|
||||
self.device = device
|
||||
self.model_states: dict[str, EplbModelState] = {}
|
||||
self.policy: type[AbstractEplbPolicy] = DefaultEplbPolicy
|
||||
"""
|
||||
Selected EPLB algorithm class
|
||||
"""
|
||||
self.expert_load_window_step: int = 0
|
||||
"""
|
||||
Current step in the sliding window.
|
||||
|
||||
Different from `expert_rearrangement_step`,
|
||||
each EP rank may have its own `expert_load_window_step`.
|
||||
"""
|
||||
self.expert_load_window_step: int = 0
|
||||
self.expert_load_window_size: int = 0
|
||||
"""
|
||||
Size of the expert load sliding window.
|
||||
This is a constant and is taken from the config.
|
||||
"""
|
||||
self.expert_load_window_size: int = 0
|
||||
self.expert_rearrangement_step: int = 0
|
||||
"""
|
||||
Steps after last rearrangement.
|
||||
Will trigger a rearrangement if it exceeds the threshold.
|
||||
@ -415,6 +420,10 @@ class EplbState:
|
||||
)
|
||||
self.expert_rearrangement_step_interval = eplb_step_interval
|
||||
|
||||
# Set the policy based on the selected eplb algorithm type.
|
||||
policy_type = self.parallel_config.eplb_config.policy
|
||||
self.policy = EPLB_POLICIES[policy_type]
|
||||
logger.debug("Selected EPLB policy: %d", policy_type)
|
||||
if global_expert_load is not None:
|
||||
ep_group = get_ep_group().device_group
|
||||
assert global_expert_load.shape == (
|
||||
@ -441,7 +450,7 @@ class EplbState:
|
||||
new_physical_to_logical_map,
|
||||
new_logical_to_physical_map,
|
||||
new_logical_replica_count,
|
||||
) = rebalance_experts(
|
||||
) = self.policy.rebalance_experts(
|
||||
global_expert_load,
|
||||
num_replicas,
|
||||
num_groups,
|
||||
@ -776,6 +785,7 @@ class EplbState:
|
||||
f"{num_gpus=}, {num_nodes=}"
|
||||
)
|
||||
|
||||
# Get new expert mappings
|
||||
for eplb_model_state, global_expert_load_window in zip(
|
||||
self.model_states.values(), global_expert_load_windows
|
||||
):
|
||||
@ -784,7 +794,7 @@ class EplbState:
|
||||
new_physical_to_logical_map,
|
||||
new_logical_to_physical_map,
|
||||
new_logical_replica_count,
|
||||
) = rebalance_experts(
|
||||
) = self.policy.rebalance_experts(
|
||||
global_expert_load_window,
|
||||
num_replicas,
|
||||
num_groups,
|
||||
|
||||
19
vllm/distributed/eplb/policy/__init__.py
Normal file
19
vllm/distributed/eplb/policy/__init__.py
Normal file
@ -0,0 +1,19 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import get_args
|
||||
|
||||
from vllm.config.parallel import EPLBPolicyOption
|
||||
|
||||
from .abstract import AbstractEplbPolicy
|
||||
from .default import DefaultEplbPolicy
|
||||
|
||||
EPLB_POLICIES = {"default": DefaultEplbPolicy}
|
||||
|
||||
# Ensure that the EPLB_POLICIES keys match the EPLBPolicyOption values
|
||||
assert set(EPLB_POLICIES.keys()) == set(get_args(EPLBPolicyOption))
|
||||
|
||||
__all__ = [
|
||||
"AbstractEplbPolicy",
|
||||
"DefaultEplbPolicy",
|
||||
"EPLB_POLICIES",
|
||||
]
|
||||
40
vllm/distributed/eplb/policy/abstract.py
Normal file
40
vllm/distributed/eplb/policy/abstract.py
Normal file
@ -0,0 +1,40 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class AbstractEplbPolicy(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def rebalance_experts(
|
||||
cls,
|
||||
weight: torch.Tensor,
|
||||
num_replicas: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_ranks: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Entry point for expert-parallelism load balancer.
|
||||
|
||||
Parameters:
|
||||
weight: [layers, num_logical_experts], the load statistics
|
||||
for all logical experts
|
||||
num_replicas: number of physical experts, must be a multiple of
|
||||
`num_ranks`
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes
|
||||
num_ranks: number of ranks, must be a multiple of `num_nodes`
|
||||
|
||||
Returns:
|
||||
physical_to_logical_map: [layers, num_replicas], the expert
|
||||
index of each replica
|
||||
logical_to_physical_map: [layers, num_logical_experts, X],
|
||||
the replica indices for each expert
|
||||
expert_count: [layers, num_logical_experts], number of
|
||||
physical replicas for each logical expert
|
||||
"""
|
||||
raise NotImplementedError
|
||||
267
vllm/distributed/eplb/policy/default.py
Normal file
267
vllm/distributed/eplb/policy/default.py
Normal file
@ -0,0 +1,267 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Expert parallelism load balancer (EPLB) for vLLM.
|
||||
|
||||
This module implements the core rearrangement algorithm.
|
||||
|
||||
The rearrangement algorithm is adapted from
|
||||
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
|
||||
|
||||
Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
|
||||
on how the EPLB algorithm works.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .abstract import AbstractEplbPolicy
|
||||
|
||||
|
||||
class DefaultEplbPolicy(AbstractEplbPolicy):
|
||||
@classmethod
|
||||
def balanced_packing(
|
||||
cls, weight: torch.Tensor, num_packs: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Pack n weighted objects to m packs, such that each bin contains exactly
|
||||
n/m objects and the weights of all packs are as balanced as possible.
|
||||
|
||||
Parameters:
|
||||
weight: [X, n], the weight of each item
|
||||
num_packs: number of packs
|
||||
|
||||
Returns:
|
||||
pack_index: [X, n], the pack index of each item
|
||||
rank_in_pack: [X, n], the rank of the item in the pack
|
||||
"""
|
||||
num_layers, num_groups = weight.shape
|
||||
assert num_groups % num_packs == 0
|
||||
groups_per_pack = num_groups // num_packs
|
||||
|
||||
device = weight.device
|
||||
|
||||
if groups_per_pack == 1:
|
||||
pack_index = torch.arange(
|
||||
weight.size(-1), dtype=torch.int64, device=device
|
||||
).expand(weight.shape)
|
||||
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
weight_np = weight.cpu().numpy()
|
||||
|
||||
# Sort and get indices in decending order
|
||||
indices_np = np.argsort(-weight_np, axis=-1)
|
||||
|
||||
pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
|
||||
rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
|
||||
|
||||
# Run the packing algorithm
|
||||
for i in range(num_layers):
|
||||
pack_weights = [0.0] * num_packs
|
||||
pack_items = [0] * num_packs
|
||||
|
||||
for group in indices_np[i]:
|
||||
# Find a pack with capacity that has the lowest weight
|
||||
pack = min(
|
||||
(j for j in range(num_packs) if pack_items[j] < groups_per_pack),
|
||||
key=pack_weights.__getitem__,
|
||||
)
|
||||
|
||||
assert pack_items[pack] < groups_per_pack
|
||||
pack_index_np[i, group] = pack
|
||||
rank_in_pack_np[i, group] = pack_items[pack]
|
||||
pack_weights[pack] += weight_np[i, group]
|
||||
pack_items[pack] += 1
|
||||
|
||||
pack_index = torch.from_numpy(pack_index_np).to(device)
|
||||
rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
|
||||
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
@classmethod
|
||||
def replicate_experts(
|
||||
cls, weight: torch.Tensor, num_phy: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
|
||||
load of all replicas is minimized.
|
||||
|
||||
Parameters:
|
||||
weight: [X, num_log]
|
||||
num_phy: total number of experts after replication
|
||||
|
||||
Returns:
|
||||
phy2log: [X, num_phy], logical expert id of each physical expert
|
||||
rank: [X, num_phy], the replica rank
|
||||
logcnt: [X, num_log], number of replicas for each logical expert
|
||||
"""
|
||||
n, num_log = weight.shape
|
||||
num_redundant = num_phy - num_log
|
||||
assert num_redundant >= 0
|
||||
device = weight.device
|
||||
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
|
||||
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
|
||||
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
|
||||
arangen = torch.arange(n, dtype=torch.int64, device=device)
|
||||
for i in range(num_log, num_phy):
|
||||
redundant_indices = (weight / logcnt).max(dim=-1).indices
|
||||
phy2log[:, i] = redundant_indices
|
||||
rank[:, i] = logcnt[arangen, redundant_indices]
|
||||
logcnt[arangen, redundant_indices] += 1
|
||||
return phy2log, rank, logcnt
|
||||
|
||||
@classmethod
|
||||
def rebalance_experts_hierarchical(
|
||||
cls,
|
||||
weight: torch.Tensor,
|
||||
num_physical_experts: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_gpus: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Parameters:
|
||||
weight: [num_moe_layers, num_logical_experts]
|
||||
num_physical_experts: number of physical experts after replication
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes, where the intra-node network
|
||||
(e.g, NVLink) is faster
|
||||
num_gpus: number of GPUs, must be a multiple of `num_nodes`
|
||||
|
||||
Returns:
|
||||
phy2log: [layers, num_replicas], the expert
|
||||
index of each replica
|
||||
log2phy: [layers, num_logical_experts, X],
|
||||
the replica indices for each expert
|
||||
logcnt: [layers, num_logical_experts], number of
|
||||
physical replicas for each logical expert
|
||||
"""
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
assert num_logical_experts % num_groups == 0
|
||||
group_size = num_logical_experts // num_groups
|
||||
assert num_groups % num_nodes == 0
|
||||
groups_per_node = num_groups // num_nodes
|
||||
assert num_gpus % num_nodes == 0
|
||||
assert num_physical_experts % num_gpus == 0
|
||||
phy_experts_per_gpu = num_physical_experts // num_gpus
|
||||
|
||||
def inverse(perm: torch.Tensor) -> torch.Tensor:
|
||||
inv = torch.empty_like(perm)
|
||||
inv.scatter_(
|
||||
1,
|
||||
perm,
|
||||
torch.arange(
|
||||
perm.size(1), dtype=torch.int64, device=perm.device
|
||||
).expand(perm.shape),
|
||||
)
|
||||
return inv
|
||||
|
||||
# Step 1: pack groups to nodes
|
||||
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
|
||||
group_pack_index, group_rank_in_pack = cls.balanced_packing(
|
||||
tokens_per_group, num_nodes
|
||||
)
|
||||
log2mlog = (
|
||||
(
|
||||
(group_pack_index * groups_per_node + group_rank_in_pack) * group_size
|
||||
).unsqueeze(-1)
|
||||
+ torch.arange(
|
||||
group_size, dtype=torch.int64, device=group_pack_index.device
|
||||
)
|
||||
).flatten(-2)
|
||||
mlog2log = inverse(log2mlog)
|
||||
|
||||
# Step 2: construct redundant experts within nodes
|
||||
# [num_layers * num_nodes, num_logical_experts // num_nodes]
|
||||
tokens_per_mlog = weight.gather(-1, mlog2log).view(
|
||||
-1, num_logical_experts // num_nodes
|
||||
)
|
||||
phy2mlog, phyrank, mlogcnt = cls.replicate_experts(
|
||||
tokens_per_mlog, num_physical_experts // num_nodes
|
||||
)
|
||||
|
||||
# Step 3: pack physical_experts to GPUs
|
||||
# [num_layers * num_nodes, num_physical_experts // num_nodes]
|
||||
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
|
||||
pack_index, rank_in_pack = cls.balanced_packing(
|
||||
tokens_per_phy, num_gpus // num_nodes
|
||||
)
|
||||
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
|
||||
pphy2phy = inverse(phy2pphy)
|
||||
|
||||
pphy2mlog = phy2mlog.gather(
|
||||
-1, pphy2phy
|
||||
) # [num_layers * num_nodes, num_log_per_nodes]
|
||||
pphy2mlog = (
|
||||
pphy2mlog.view(num_layers, num_nodes, -1)
|
||||
+ torch.arange(
|
||||
0,
|
||||
num_logical_experts,
|
||||
num_logical_experts // num_nodes,
|
||||
device=group_pack_index.device,
|
||||
).view(1, -1, 1)
|
||||
).flatten(-2)
|
||||
pphy2log = mlog2log.gather(-1, pphy2mlog)
|
||||
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
|
||||
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
|
||||
return pphy2log, pphyrank, logcnt
|
||||
|
||||
@classmethod
|
||||
def rebalance_experts(
|
||||
cls,
|
||||
weight: torch.Tensor,
|
||||
num_replicas: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_ranks: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Entry point for expert-parallelism load balancer.
|
||||
|
||||
Parameters:
|
||||
weight: [layers, num_logical_experts], the load statistics for all
|
||||
logical experts
|
||||
num_replicas: number of physical experts, must be a multiple of
|
||||
`num_gpus`
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes, where the intra-node network
|
||||
(e.g, NVLink) is faster
|
||||
num_ranks: number of ranks, must be a multiple of `num_nodes`
|
||||
|
||||
Returns:
|
||||
phy2log: [layers, num_replicas], the expert
|
||||
index of each replica
|
||||
log2phy: [layers, num_logical_experts, X],
|
||||
the replica indices for each expert
|
||||
logcnt: [layers, num_logical_experts], number of
|
||||
physical replicas for each logical expert
|
||||
"""
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
weight = weight.float()
|
||||
if num_groups % num_nodes == 0:
|
||||
# use hierarchical load-balance policy
|
||||
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
|
||||
weight, num_replicas, num_groups, num_nodes, num_ranks
|
||||
)
|
||||
else:
|
||||
# use global load-balance policy
|
||||
phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
|
||||
weight, num_replicas, 1, 1, num_ranks
|
||||
)
|
||||
num_redundant_experts = num_replicas - num_logical_experts
|
||||
maxlogcnt = num_redundant_experts + 1
|
||||
log2phy: torch.Tensor = torch.full(
|
||||
(num_layers, num_logical_experts, maxlogcnt),
|
||||
-1,
|
||||
dtype=torch.int64,
|
||||
device=logcnt.device,
|
||||
)
|
||||
log2phy.view(num_layers, -1).scatter_(
|
||||
-1,
|
||||
phy2log * maxlogcnt + phyrank,
|
||||
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
|
||||
num_layers, -1
|
||||
),
|
||||
)
|
||||
return phy2log, log2phy, logcnt
|
||||
@ -1,260 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""
|
||||
Expert parallelism load balancer (EPLB) for vLLM.
|
||||
|
||||
This module implements the core rearrangement algorithm.
|
||||
|
||||
The rearrangement algorithm is adapted from
|
||||
[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
|
||||
|
||||
Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
|
||||
on how the EPLB algorithm works.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def balanced_packing(
|
||||
weight: torch.Tensor, num_packs: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Pack n weighted objects to m packs, such that each bin contains exactly
|
||||
n/m objects and the weights of all packs are as balanced as possible.
|
||||
|
||||
Parameters:
|
||||
weight: [X, n], the weight of each item
|
||||
num_packs: number of packs
|
||||
|
||||
Returns:
|
||||
pack_index: [X, n], the pack index of each item
|
||||
rank_in_pack: [X, n], the rank of the item in the pack
|
||||
"""
|
||||
num_layers, num_groups = weight.shape
|
||||
assert num_groups % num_packs == 0
|
||||
groups_per_pack = num_groups // num_packs
|
||||
|
||||
device = weight.device
|
||||
|
||||
if groups_per_pack == 1:
|
||||
pack_index = torch.arange(
|
||||
weight.size(-1), dtype=torch.int64, device=device
|
||||
).expand(weight.shape)
|
||||
rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
weight_np = weight.cpu().numpy()
|
||||
|
||||
# Sort and get indices in decending order
|
||||
indices_np = np.argsort(-weight_np, axis=-1)
|
||||
|
||||
pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
|
||||
rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
|
||||
|
||||
# Run the packing algorithm
|
||||
for i in range(num_layers):
|
||||
pack_weights = [0.0] * num_packs
|
||||
pack_items = [0] * num_packs
|
||||
|
||||
for group in indices_np[i]:
|
||||
# Find a pack with capacity that has the lowest weight
|
||||
pack = min(
|
||||
(j for j in range(num_packs) if pack_items[j] < groups_per_pack),
|
||||
key=pack_weights.__getitem__,
|
||||
)
|
||||
|
||||
assert pack_items[pack] < groups_per_pack
|
||||
pack_index_np[i, group] = pack
|
||||
rank_in_pack_np[i, group] = pack_items[pack]
|
||||
pack_weights[pack] += weight_np[i, group]
|
||||
pack_items[pack] += 1
|
||||
|
||||
pack_index = torch.from_numpy(pack_index_np).to(device)
|
||||
rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
|
||||
|
||||
return pack_index, rank_in_pack
|
||||
|
||||
|
||||
def replicate_experts(
|
||||
weight: torch.Tensor, num_phy: int
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Replicate `num_log` experts to `num_phy` replicas, such that the maximum
|
||||
load of all replicas is minimized.
|
||||
|
||||
Parameters:
|
||||
weight: [X, num_log]
|
||||
num_phy: total number of experts after replication
|
||||
|
||||
Returns:
|
||||
phy2log: [X, num_phy], logical expert id of each physical expert
|
||||
rank: [X, num_phy], the replica rank
|
||||
logcnt: [X, num_log], number of replicas for each logical expert
|
||||
"""
|
||||
n, num_log = weight.shape
|
||||
num_redundant = num_phy - num_log
|
||||
assert num_redundant >= 0
|
||||
device = weight.device
|
||||
phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
|
||||
rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
|
||||
logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
|
||||
arangen = torch.arange(n, dtype=torch.int64, device=device)
|
||||
for i in range(num_log, num_phy):
|
||||
redundant_indices = (weight / logcnt).max(dim=-1).indices
|
||||
phy2log[:, i] = redundant_indices
|
||||
rank[:, i] = logcnt[arangen, redundant_indices]
|
||||
logcnt[arangen, redundant_indices] += 1
|
||||
return phy2log, rank, logcnt
|
||||
|
||||
|
||||
def rebalance_experts_hierarchical(
|
||||
weight: torch.Tensor,
|
||||
num_physical_experts: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_gpus: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Parameters:
|
||||
weight: [num_moe_layers, num_logical_experts]
|
||||
num_physical_experts: number of physical experts after replication
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes, where the intra-node network
|
||||
(e.g., NVLink) is faster
|
||||
num_gpus: number of GPUs, must be a multiple of `num_nodes`
|
||||
|
||||
Returns:
|
||||
physical_to_logical_map (torch.Tensor):
|
||||
[num_moe_layers, num_physical_experts]
|
||||
logical_to_physical_map (torch.Tensor):
|
||||
[num_moe_layers, num_logical_experts, X]
|
||||
logical_count (torch.Tensor):
|
||||
[num_moe_layers, num_logical_experts]
|
||||
"""
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
assert num_logical_experts % num_groups == 0
|
||||
group_size = num_logical_experts // num_groups
|
||||
assert num_groups % num_nodes == 0
|
||||
groups_per_node = num_groups // num_nodes
|
||||
assert num_gpus % num_nodes == 0
|
||||
assert num_physical_experts % num_gpus == 0
|
||||
phy_experts_per_gpu = num_physical_experts // num_gpus
|
||||
|
||||
def inverse(perm: torch.Tensor) -> torch.Tensor:
|
||||
inv = torch.empty_like(perm)
|
||||
inv.scatter_(
|
||||
1,
|
||||
perm,
|
||||
torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(
|
||||
perm.shape
|
||||
),
|
||||
)
|
||||
return inv
|
||||
|
||||
# Step 1: pack groups to nodes
|
||||
tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
|
||||
group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
|
||||
log2mlog = (
|
||||
(
|
||||
(group_pack_index * groups_per_node + group_rank_in_pack) * group_size
|
||||
).unsqueeze(-1)
|
||||
+ torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)
|
||||
).flatten(-2)
|
||||
mlog2log = inverse(log2mlog)
|
||||
|
||||
# Step 2: construct redundant experts within nodes
|
||||
# [num_layers * num_nodes, num_logical_experts // num_nodes]
|
||||
tokens_per_mlog = weight.gather(-1, mlog2log).view(
|
||||
-1, num_logical_experts // num_nodes
|
||||
)
|
||||
phy2mlog, phyrank, mlogcnt = replicate_experts(
|
||||
tokens_per_mlog, num_physical_experts // num_nodes
|
||||
)
|
||||
|
||||
# Step 3: pack physical_experts to GPUs
|
||||
# [num_layers * num_nodes, num_physical_experts // num_nodes]
|
||||
tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
|
||||
pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
|
||||
phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
|
||||
pphy2phy = inverse(phy2pphy)
|
||||
|
||||
pphy2mlog = phy2mlog.gather(
|
||||
-1, pphy2phy
|
||||
) # [num_layers * num_nodes, num_log_per_nodes]
|
||||
pphy2mlog = (
|
||||
pphy2mlog.view(num_layers, num_nodes, -1)
|
||||
+ torch.arange(
|
||||
0,
|
||||
num_logical_experts,
|
||||
num_logical_experts // num_nodes,
|
||||
device=group_pack_index.device,
|
||||
).view(1, -1, 1)
|
||||
).flatten(-2)
|
||||
pphy2log = mlog2log.gather(-1, pphy2mlog)
|
||||
pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
|
||||
logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
|
||||
return pphy2log, pphyrank, logcnt
|
||||
|
||||
|
||||
def rebalance_experts(
|
||||
weight: torch.Tensor,
|
||||
num_replicas: int,
|
||||
num_groups: int,
|
||||
num_nodes: int,
|
||||
num_gpus: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Entry point for expert-parallelism load balancer.
|
||||
|
||||
Parameters:
|
||||
weight: [layers, num_logical_experts], the load statistics for all
|
||||
logical experts
|
||||
num_replicas: number of physical experts, must be a multiple of
|
||||
`num_gpus`
|
||||
num_groups: number of expert groups
|
||||
num_nodes: number of server nodes, where the intra-node network
|
||||
(e.g, NVLink) is faster
|
||||
num_gpus: number of GPUs, must be a multiple of `num_nodes`
|
||||
|
||||
Returns:
|
||||
physical_to_logical_map:
|
||||
[layers, num_replicas], the expert index of each replica
|
||||
logical_to_physical_map:
|
||||
[layers, num_logical_experts, X], the replica indices for each
|
||||
expert
|
||||
expert_count:
|
||||
[layers, num_logical_experts], number of physical
|
||||
replicas for each logical expert
|
||||
"""
|
||||
num_layers, num_logical_experts = weight.shape
|
||||
weight = weight.float()
|
||||
if num_groups % num_nodes == 0:
|
||||
# use hierarchical load-balance policy
|
||||
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
|
||||
weight, num_replicas, num_groups, num_nodes, num_gpus
|
||||
)
|
||||
else:
|
||||
# use global load-balance policy
|
||||
phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
|
||||
weight, num_replicas, 1, 1, num_gpus
|
||||
)
|
||||
num_redundant_experts = num_replicas - num_logical_experts
|
||||
maxlogcnt = num_redundant_experts + 1
|
||||
log2phy: torch.Tensor = torch.full(
|
||||
(num_layers, num_logical_experts, maxlogcnt),
|
||||
-1,
|
||||
dtype=torch.int64,
|
||||
device=logcnt.device,
|
||||
)
|
||||
log2phy.view(num_layers, -1).scatter_(
|
||||
-1,
|
||||
phy2log * maxlogcnt + phyrank,
|
||||
torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
|
||||
num_layers, -1
|
||||
),
|
||||
)
|
||||
return phy2log, log2phy, logcnt
|
||||
|
||||
|
||||
__all__ = ["rebalance_experts"]
|
||||
Loading…
x
Reference in New Issue
Block a user