diff --git a/tests/distributed/test_eplb_algo.py b/tests/distributed/test_eplb_algo.py
index 79805a7cce53..a53a61840e79 100644
--- a/tests/distributed/test_eplb_algo.py
+++ b/tests/distributed/test_eplb_algo.py
@@ -4,7 +4,7 @@
 import pytest
 import torch

-from vllm.distributed.eplb.rebalance_algo import rebalance_experts
+from vllm.distributed.eplb.policy.default import DefaultEplbPolicy


 def test_basic_rebalance():
@@ -23,7 +23,7 @@ def test_basic_rebalance():
     num_nodes = 2
     num_gpus = 8

-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )

@@ -77,7 +77,7 @@ def test_single_gpu_case():
     num_nodes = 1
     num_gpus = 1

-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )

@@ -99,7 +99,7 @@ def test_equal_weights():
     num_nodes = 2
     num_gpus = 4

-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )

@@ -122,7 +122,7 @@ def test_extreme_weight_imbalance():
     num_nodes = 2
     num_gpus = 4

-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )

@@ -150,7 +150,7 @@ def test_multiple_layers():
     num_nodes = 2
     num_gpus = 4

-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )

@@ -175,14 +175,14 @@ def test_parameter_validation():
     # Test non-divisible case - this should handle normally without throwing
     # errors because the function will fall back to global load balancing
     # strategy
-    phy2log, log2phy, logcnt = rebalance_experts(weight, 8, 3, 2, 4)
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(weight, 8, 3, 2, 4)
     assert phy2log.shape == (1, 8)
     assert logcnt.shape == (1, 4)

     # Test cases that will actually cause errors:
     # num_physical_experts not divisible by num_gpus
     with pytest.raises(AssertionError):
-        rebalance_experts(weight, 7, 2, 2, 4)  # 7 not divisible by 4
+        DefaultEplbPolicy.rebalance_experts(weight, 7, 2, 2, 4)  # 7 not divisible by 4


 def test_small_scale_hierarchical():
@@ -197,7 +197,7 @@ def test_small_scale_hierarchical():
     num_nodes = 2  # 2 nodes
     num_gpus = 4  # 4 GPUs

-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )

@@ -224,7 +224,7 @@ def test_global_load_balance_fallback():
     num_nodes = 2
     num_gpus = 4

-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )

@@ -246,7 +246,7 @@ def test_device_compatibility(device):
     num_nodes = 1
     num_gpus = 2

-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )

@@ -263,7 +263,9 @@ def test_additional_cases():
     weight1 = torch.tensor(
         [[50, 100, 75, 120, 90, 60, 80, 110, 40, 70, 95, 85, 65, 55, 45, 35]]
     )
-    phy2log1, log2phy1, logcnt1 = rebalance_experts(weight1, 24, 8, 4, 8)
+    phy2log1, log2phy1, logcnt1 = DefaultEplbPolicy.rebalance_experts(
+        weight1, 24, 8, 4, 8
+    )

     assert phy2log1.shape == (1, 24)
     assert logcnt1.shape == (1, 16)
@@ -276,7 +278,9 @@ def test_additional_cases():
             [12, 25, 50, 100, 150, 200],  # Increasing weights
         ]
     )
-    phy2log2, log2phy2, logcnt2 = rebalance_experts(weight2, 10, 3, 1, 2)
+    phy2log2, log2phy2, logcnt2 = DefaultEplbPolicy.rebalance_experts(
+        weight2, 10, 3, 1, 2
+    )

     assert phy2log2.shape == (2, 10)
     assert logcnt2.shape == (2, 6)
@@ -300,7 +304,7 @@ if __name__ == "__main__":
     num_nodes = 2
     num_gpus = 8

-    phy2log, log2phy, logcnt = rebalance_experts(
+    phy2log, log2phy, logcnt = DefaultEplbPolicy.rebalance_experts(
         weight, num_replicas, num_groups, num_nodes, num_gpus
     )
     print(phy2log)
diff --git a/vllm/config/parallel.py b/vllm/config/parallel.py
index 20de67225710..3a768bcd4f2c 100644
--- a/vllm/config/parallel.py
+++ b/vllm/config/parallel.py
@@ -35,6 +35,7 @@ logger = init_logger(__name__)
 ExpertPlacementStrategy = Literal["linear", "round_robin"]
 DistributedExecutorBackend = Literal["ray", "mp", "uni", "external_launcher"]
 DataParallelBackend = Literal["ray", "mp"]
+EPLBPolicyOption = Literal["default"]


 @config
@@ -65,6 +66,9 @@ class EPLBConfig:
     Whether to use non-blocking EPLB.
     """

+    policy: EPLBPolicyOption = "default"
+    """The policy type for expert parallel load balancing (EPLB)."""
+

 @config
 @dataclass
diff --git a/vllm/distributed/eplb/__init__.py b/vllm/distributed/eplb/__init__.py
index 4cd51dd384ad..12e6cd417c50 100644
--- a/vllm/distributed/eplb/__init__.py
+++ b/vllm/distributed/eplb/__init__.py
@@ -1,8 +1,3 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-Expert parallelism load balancer (EPLB).
-"""
-
-from .eplb_state import *
-from .rebalance_algo import *
+"""Expert parallelism load balancer (EPLB)."""
diff --git a/vllm/distributed/eplb/eplb_state.py b/vllm/distributed/eplb/eplb_state.py
index 9f8798a96a2f..c5654659b79d 100644
--- a/vllm/distributed/eplb/eplb_state.py
+++ b/vllm/distributed/eplb/eplb_state.py
@@ -45,7 +45,7 @@ from vllm.logger import init_logger
 from vllm.model_executor.models.interfaces import MixtureOfExperts

 from .async_worker import start_async_worker
-from .rebalance_algo import rebalance_experts
+from .policy import EPLB_POLICIES, AbstractEplbPolicy, DefaultEplbPolicy
 from .rebalance_execute import move_from_buffer, rearrange_expert_weights_inplace

 logger = init_logger(__name__)
@@ -213,18 +213,23 @@ class EplbState:
         self.parallel_config = parallel_config
         self.device = device
         self.model_states: dict[str, EplbModelState] = {}
-
+        self.policy: type[AbstractEplbPolicy] = DefaultEplbPolicy
+        """
+        Selected EPLB algorithm class
+        """
+
+        self.expert_load_window_step: int = 0
         """
         Current step in the sliding window.
         Different from `expert_rearrangement_step`, each EP rank may have its
         own `expert_load_window_step`.
         """
-        self.expert_load_window_step: int = 0
+        self.expert_load_window_size: int = 0
         """
         Size of the expert load sliding window.
         This is a constant and is taken from the config.
         """
-        self.expert_load_window_size: int = 0
+        self.expert_rearrangement_step: int = 0
         """
         Steps after last rearrangement.
         Will trigger a rearrangement if it exceeds the threshold.
@@ -415,6 +420,10 @@ class EplbState:
             )
         self.expert_rearrangement_step_interval = eplb_step_interval

+        # Set the policy based on the selected eplb algorithm type.
+        policy_type = self.parallel_config.eplb_config.policy
+        self.policy = EPLB_POLICIES[policy_type]
+        logger.debug("Selected EPLB policy: %s", policy_type)
         if global_expert_load is not None:
             ep_group = get_ep_group().device_group
             assert global_expert_load.shape == (
@@ -441,7 +450,7 @@ class EplbState:
                 new_physical_to_logical_map,
                 new_logical_to_physical_map,
                 new_logical_replica_count,
-            ) = rebalance_experts(
+            ) = self.policy.rebalance_experts(
                 global_expert_load,
                 num_replicas,
                 num_groups,
@@ -776,6 +785,7 @@ class EplbState:
             f"{num_gpus=}, {num_nodes=}"
         )

+        # Get new expert mappings
         for eplb_model_state, global_expert_load_window in zip(
             self.model_states.values(), global_expert_load_windows
         ):
@@ -784,7 +794,7 @@ class EplbState:
                 new_physical_to_logical_map,
                 new_logical_to_physical_map,
                 new_logical_replica_count,
-            ) = rebalance_experts(
+            ) = self.policy.rebalance_experts(
                 global_expert_load_window,
                 num_replicas,
                 num_groups,
diff --git a/vllm/distributed/eplb/policy/__init__.py b/vllm/distributed/eplb/policy/__init__.py
new file mode 100644
index 000000000000..8e78d7bac0e3
--- /dev/null
+++ b/vllm/distributed/eplb/policy/__init__.py
@@ -0,0 +1,19 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import get_args
+
+from vllm.config.parallel import EPLBPolicyOption
+
+from .abstract import AbstractEplbPolicy
+from .default import DefaultEplbPolicy
+
+EPLB_POLICIES = {"default": DefaultEplbPolicy}
+
+# Ensure that the EPLB_POLICIES keys match the EPLBPolicyOption values
+assert set(EPLB_POLICIES.keys()) == set(get_args(EPLBPolicyOption))
+
+__all__ = [
+    "AbstractEplbPolicy",
+    "DefaultEplbPolicy",
+    "EPLB_POLICIES",
+]
diff --git a/vllm/distributed/eplb/policy/abstract.py b/vllm/distributed/eplb/policy/abstract.py
new file mode 100644
index 000000000000..40ed621c8489
--- /dev/null
+++ b/vllm/distributed/eplb/policy/abstract.py
@@ -0,0 +1,40 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from abc import ABC, abstractmethod
+
+import torch
+
+
+class AbstractEplbPolicy(ABC):
+    @classmethod
+    @abstractmethod
+    def rebalance_experts(
+        cls,
+        weight: torch.Tensor,
+        num_replicas: int,
+        num_groups: int,
+        num_nodes: int,
+        num_ranks: int,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Entry point for expert-parallelism load balancer.
+
+        Parameters:
+            weight: [layers, num_logical_experts], the load statistics
+                for all logical experts
+            num_replicas: number of physical experts, must be a multiple of
+                `num_ranks`
+            num_groups: number of expert groups
+            num_nodes: number of server nodes
+            num_ranks: number of ranks, must be a multiple of `num_nodes`
+
+        Returns:
+            physical_to_logical_map: [layers, num_replicas], the expert
+                index of each replica
+            logical_to_physical_map: [layers, num_logical_experts, X],
+                the replica indices for each expert
+            expert_count: [layers, num_logical_experts], number of
+                physical replicas for each logical expert
+        """
+        raise NotImplementedError
diff --git a/vllm/distributed/eplb/policy/default.py b/vllm/distributed/eplb/policy/default.py
new file mode 100644
index 000000000000..6127ec703184
--- /dev/null
+++ b/vllm/distributed/eplb/policy/default.py
@@ -0,0 +1,267 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""
+Expert parallelism load balancer (EPLB) for vLLM.
+
+This module implements the core rearrangement algorithm.
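+
+At a high level, the algorithm packs expert groups onto nodes, replicates the
+most heavily loaded experts within each node, and then packs the resulting
+physical experts onto GPUs; see `rebalance_experts_hierarchical` below.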
+
+The rearrangement algorithm is adapted from
+[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
+
+Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
+on how the EPLB algorithm works.
+"""
+
+import numpy as np
+import torch
+
+from .abstract import AbstractEplbPolicy
+
+
+class DefaultEplbPolicy(AbstractEplbPolicy):
+    @classmethod
+    def balanced_packing(
+        cls, weight: torch.Tensor, num_packs: int
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        """
+        Pack n weighted objects to m packs, such that each pack contains exactly
+        n/m objects and the weights of all packs are as balanced as possible.
+
+        Parameters:
+            weight: [X, n], the weight of each item
+            num_packs: number of packs
+
+        Returns:
+            pack_index: [X, n], the pack index of each item
+            rank_in_pack: [X, n], the rank of the item in the pack
+        """
+        num_layers, num_groups = weight.shape
+        assert num_groups % num_packs == 0
+        groups_per_pack = num_groups // num_packs
+
+        device = weight.device
+
+        if groups_per_pack == 1:
+            pack_index = torch.arange(
+                weight.size(-1), dtype=torch.int64, device=device
+            ).expand(weight.shape)
+            rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
+            return pack_index, rank_in_pack
+
+        weight_np = weight.cpu().numpy()
+
+        # Sort and get indices in descending order
+        indices_np = np.argsort(-weight_np, axis=-1)
+
+        pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
+        rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
+
+        # Run the packing algorithm
+        for i in range(num_layers):
+            pack_weights = [0.0] * num_packs
+            pack_items = [0] * num_packs
+
+            for group in indices_np[i]:
+                # Find a pack with capacity that has the lowest weight
+                pack = min(
+                    (j for j in range(num_packs) if pack_items[j] < groups_per_pack),
+                    key=pack_weights.__getitem__,
+                )
+
+                assert pack_items[pack] < groups_per_pack
+                pack_index_np[i, group] = pack
+                rank_in_pack_np[i, group] = pack_items[pack]
+                pack_weights[pack] += weight_np[i, group]
+                pack_items[pack] += 1
+
+        pack_index = torch.from_numpy(pack_index_np).to(device)
+        rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
+
+        return pack_index, rank_in_pack
+
+    @classmethod
+    def replicate_experts(
+        cls, weight: torch.Tensor, num_phy: int
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Replicate `num_log` experts to `num_phy` replicas, such that the maximum
+        load of all replicas is minimized.
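+
+        For example, with weight = [[10, 30, 20]] and num_phy = 4, the single
+        redundant slot goes to the heaviest expert (index 1), giving
+        phy2log = [[0, 1, 2, 1]] and logcnt = [[1, 2, 1]].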
+
+        Parameters:
+            weight: [X, num_log]
+            num_phy: total number of experts after replication
+
+        Returns:
+            phy2log: [X, num_phy], logical expert id of each physical expert
+            rank: [X, num_phy], the replica rank
+            logcnt: [X, num_log], number of replicas for each logical expert
+        """
+        n, num_log = weight.shape
+        num_redundant = num_phy - num_log
+        assert num_redundant >= 0
+        device = weight.device
+        phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
+        rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
+        logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
+        arangen = torch.arange(n, dtype=torch.int64, device=device)
+        for i in range(num_log, num_phy):
+            redundant_indices = (weight / logcnt).max(dim=-1).indices
+            phy2log[:, i] = redundant_indices
+            rank[:, i] = logcnt[arangen, redundant_indices]
+            logcnt[arangen, redundant_indices] += 1
+        return phy2log, rank, logcnt
+
+    @classmethod
+    def rebalance_experts_hierarchical(
+        cls,
+        weight: torch.Tensor,
+        num_physical_experts: int,
+        num_groups: int,
+        num_nodes: int,
+        num_gpus: int,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Parameters:
+            weight: [num_moe_layers, num_logical_experts]
+            num_physical_experts: number of physical experts after replication
+            num_groups: number of expert groups
+            num_nodes: number of server nodes, where the intra-node network
+                (e.g., NVLink) is faster
+            num_gpus: number of GPUs, must be a multiple of `num_nodes`
+
+        Returns:
+            phy2log: [layers, num_replicas], the expert
+                index of each replica
+            log2phy: [layers, num_logical_experts, X],
+                the replica indices for each expert
+            logcnt: [layers, num_logical_experts], number of
+                physical replicas for each logical expert
+        """
+        num_layers, num_logical_experts = weight.shape
+        assert num_logical_experts % num_groups == 0
+        group_size = num_logical_experts // num_groups
+        assert num_groups % num_nodes == 0
+        groups_per_node = num_groups // num_nodes
+        assert num_gpus % num_nodes == 0
+        assert num_physical_experts % num_gpus == 0
+        phy_experts_per_gpu = num_physical_experts // num_gpus
+
+        def inverse(perm: torch.Tensor) -> torch.Tensor:
+            inv = torch.empty_like(perm)
+            inv.scatter_(
+                1,
+                perm,
+                torch.arange(
+                    perm.size(1), dtype=torch.int64, device=perm.device
+                ).expand(perm.shape),
+            )
+            return inv
+
+        # Step 1: pack groups to nodes
+        tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
+        group_pack_index, group_rank_in_pack = cls.balanced_packing(
+            tokens_per_group, num_nodes
+        )
+        log2mlog = (
+            (
+                (group_pack_index * groups_per_node + group_rank_in_pack) * group_size
+            ).unsqueeze(-1)
+            + torch.arange(
+                group_size, dtype=torch.int64, device=group_pack_index.device
+            )
+        ).flatten(-2)
+        mlog2log = inverse(log2mlog)
+
+        # Step 2: construct redundant experts within nodes
+        # [num_layers * num_nodes, num_logical_experts // num_nodes]
+        tokens_per_mlog = weight.gather(-1, mlog2log).view(
+            -1, num_logical_experts // num_nodes
+        )
+        phy2mlog, phyrank, mlogcnt = cls.replicate_experts(
+            tokens_per_mlog, num_physical_experts // num_nodes
+        )
+
+        # Step 3: pack physical_experts to GPUs
+        # [num_layers * num_nodes, num_physical_experts // num_nodes]
+        tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
+        pack_index, rank_in_pack = cls.balanced_packing(
+            tokens_per_phy, num_gpus // num_nodes
+        )
+        phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
+        pphy2phy = inverse(phy2pphy)
+
+        pphy2mlog = phy2mlog.gather(
+            -1, pphy2phy
+        )  # [num_layers * num_nodes, num_log_per_nodes]
+        pphy2mlog = (
+            pphy2mlog.view(num_layers, num_nodes, -1)
+            + torch.arange(
+                0,
+                num_logical_experts,
+                num_logical_experts // num_nodes,
+                device=group_pack_index.device,
+            ).view(1, -1, 1)
+        ).flatten(-2)
+        pphy2log = mlog2log.gather(-1, pphy2mlog)
+        pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
+        logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
+        return pphy2log, pphyrank, logcnt
+
+    @classmethod
+    def rebalance_experts(
+        cls,
+        weight: torch.Tensor,
+        num_replicas: int,
+        num_groups: int,
+        num_nodes: int,
+        num_ranks: int,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """
+        Entry point for expert-parallelism load balancer.
+
+        Parameters:
+            weight: [layers, num_logical_experts], the load statistics for all
+                logical experts
+            num_replicas: number of physical experts, must be a multiple of
+                `num_ranks`
+            num_groups: number of expert groups
+            num_nodes: number of server nodes, where the intra-node network
+                (e.g., NVLink) is faster
+            num_ranks: number of ranks, must be a multiple of `num_nodes`
+
+        Returns:
+            phy2log: [layers, num_replicas], the expert
+                index of each replica
+            log2phy: [layers, num_logical_experts, X],
+                the replica indices for each expert
+            logcnt: [layers, num_logical_experts], number of
+                physical replicas for each logical expert
+        """
+        num_layers, num_logical_experts = weight.shape
+        weight = weight.float()
+        if num_groups % num_nodes == 0:
+            # use hierarchical load-balance policy
+            phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
+                weight, num_replicas, num_groups, num_nodes, num_ranks
+            )
+        else:
+            # use global load-balance policy
+            phy2log, phyrank, logcnt = cls.rebalance_experts_hierarchical(
+                weight, num_replicas, 1, 1, num_ranks
+            )
+        num_redundant_experts = num_replicas - num_logical_experts
+        maxlogcnt = num_redundant_experts + 1
+        log2phy: torch.Tensor = torch.full(
+            (num_layers, num_logical_experts, maxlogcnt),
+            -1,
+            dtype=torch.int64,
+            device=logcnt.device,
+        )
+        log2phy.view(num_layers, -1).scatter_(
+            -1,
+            phy2log * maxlogcnt + phyrank,
+            torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
+                num_layers, -1
+            ),
+        )
+        return phy2log, log2phy, logcnt
diff --git a/vllm/distributed/eplb/rebalance_algo.py b/vllm/distributed/eplb/rebalance_algo.py
deleted file mode 100644
index e6645e524cc3..000000000000
--- a/vllm/distributed/eplb/rebalance_algo.py
+++ /dev/null
@@ -1,260 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-"""
-Expert parallelism load balancer (EPLB) for vLLM.
-
-This module implements the core rearrangement algorithm.
-
-The rearrangement algorithm is adapted from
-[DeepSeek EPLB](https://github.com/deepseek-ai/eplb).
-
-Please find at [#12](https://github.com/deepseek-ai/EPLB/issues/12) an example
-on how the EPLB algorithm works.
-"""
-
-import numpy as np
-import torch
-
-
-def balanced_packing(
-    weight: torch.Tensor, num_packs: int
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Pack n weighted objects to m packs, such that each bin contains exactly
-    n/m objects and the weights of all packs are as balanced as possible.
-
-    Parameters:
-        weight: [X, n], the weight of each item
-        num_packs: number of packs
-
-    Returns:
-        pack_index: [X, n], the pack index of each item
-        rank_in_pack: [X, n], the rank of the item in the pack
-    """
-    num_layers, num_groups = weight.shape
-    assert num_groups % num_packs == 0
-    groups_per_pack = num_groups // num_packs
-
-    device = weight.device
-
-    if groups_per_pack == 1:
-        pack_index = torch.arange(
-            weight.size(-1), dtype=torch.int64, device=device
-        ).expand(weight.shape)
-        rank_in_pack = torch.zeros_like(weight, dtype=torch.int64, device=device)
-        return pack_index, rank_in_pack
-
-    weight_np = weight.cpu().numpy()
-
-    # Sort and get indices in decending order
-    indices_np = np.argsort(-weight_np, axis=-1)
-
-    pack_index_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
-    rank_in_pack_np = np.full((num_layers, num_groups), -1, dtype=np.int64)
-
-    # Run the packing algorithm
-    for i in range(num_layers):
-        pack_weights = [0.0] * num_packs
-        pack_items = [0] * num_packs
-
-        for group in indices_np[i]:
-            # Find a pack with capacity that has the lowest weight
-            pack = min(
-                (j for j in range(num_packs) if pack_items[j] < groups_per_pack),
-                key=pack_weights.__getitem__,
-            )
-
-            assert pack_items[pack] < groups_per_pack
-            pack_index_np[i, group] = pack
-            rank_in_pack_np[i, group] = pack_items[pack]
-            pack_weights[pack] += weight_np[i, group]
-            pack_items[pack] += 1
-
-    pack_index = torch.from_numpy(pack_index_np).to(device)
-    rank_in_pack = torch.from_numpy(rank_in_pack_np).to(device)
-
-    return pack_index, rank_in_pack
-
-
-def replicate_experts(
-    weight: torch.Tensor, num_phy: int
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """
-    Replicate `num_log` experts to `num_phy` replicas, such that the maximum
-    load of all replicas is minimized.
-
-    Parameters:
-        weight: [X, num_log]
-        num_phy: total number of experts after replication
-
-    Returns:
-        phy2log: [X, num_phy], logical expert id of each physical expert
-        rank: [X, num_phy], the replica rank
-        logcnt: [X, num_log], number of replicas for each logical expert
-    """
-    n, num_log = weight.shape
-    num_redundant = num_phy - num_log
-    assert num_redundant >= 0
-    device = weight.device
-    phy2log = torch.arange(num_phy, dtype=torch.int64, device=device).repeat(n, 1)
-    rank = torch.zeros(n, num_phy, dtype=torch.int64, device=device)
-    logcnt = torch.ones(n, num_log, dtype=torch.int64, device=device)
-    arangen = torch.arange(n, dtype=torch.int64, device=device)
-    for i in range(num_log, num_phy):
-        redundant_indices = (weight / logcnt).max(dim=-1).indices
-        phy2log[:, i] = redundant_indices
-        rank[:, i] = logcnt[arangen, redundant_indices]
-        logcnt[arangen, redundant_indices] += 1
-    return phy2log, rank, logcnt
-
-
-def rebalance_experts_hierarchical(
-    weight: torch.Tensor,
-    num_physical_experts: int,
-    num_groups: int,
-    num_nodes: int,
-    num_gpus: int,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """
-    Parameters:
-        weight: [num_moe_layers, num_logical_experts]
-        num_physical_experts: number of physical experts after replication
-        num_groups: number of expert groups
-        num_nodes: number of server nodes, where the intra-node network
-            (e.g., NVLink) is faster
-        num_gpus: number of GPUs, must be a multiple of `num_nodes`
-
-    Returns:
-        physical_to_logical_map (torch.Tensor):
-            [num_moe_layers, num_physical_experts]
-        logical_to_physical_map (torch.Tensor):
-            [num_moe_layers, num_logical_experts, X]
-        logical_count (torch.Tensor):
-            [num_moe_layers, num_logical_experts]
-    """
-    num_layers, num_logical_experts = weight.shape
-    assert num_logical_experts % num_groups == 0
-    group_size = num_logical_experts // num_groups
-    assert num_groups % num_nodes == 0
-    groups_per_node = num_groups // num_nodes
-    assert num_gpus % num_nodes == 0
-    assert num_physical_experts % num_gpus == 0
-    phy_experts_per_gpu = num_physical_experts // num_gpus
-
-    def inverse(perm: torch.Tensor) -> torch.Tensor:
-        inv = torch.empty_like(perm)
-        inv.scatter_(
-            1,
-            perm,
-            torch.arange(perm.size(1), dtype=torch.int64, device=perm.device).expand(
-                perm.shape
-            ),
-        )
-        return inv
-
-    # Step 1: pack groups to nodes
-    tokens_per_group = weight.unflatten(-1, (num_groups, group_size)).sum(-1)
-    group_pack_index, group_rank_in_pack = balanced_packing(tokens_per_group, num_nodes)
-    log2mlog = (
-        (
-            (group_pack_index * groups_per_node + group_rank_in_pack) * group_size
-        ).unsqueeze(-1)
-        + torch.arange(group_size, dtype=torch.int64, device=group_pack_index.device)
-    ).flatten(-2)
-    mlog2log = inverse(log2mlog)
-
-    # Step 2: construct redundant experts within nodes
-    # [num_layers * num_nodes, num_logical_experts // num_nodes]
-    tokens_per_mlog = weight.gather(-1, mlog2log).view(
-        -1, num_logical_experts // num_nodes
-    )
-    phy2mlog, phyrank, mlogcnt = replicate_experts(
-        tokens_per_mlog, num_physical_experts // num_nodes
-    )
-
-    # Step 3: pack physical_experts to GPUs
-    # [num_layers * num_nodes, num_physical_experts // num_nodes]
-    tokens_per_phy = (tokens_per_mlog / mlogcnt).gather(-1, phy2mlog)
-    pack_index, rank_in_pack = balanced_packing(tokens_per_phy, num_gpus // num_nodes)
-    phy2pphy = pack_index * phy_experts_per_gpu + rank_in_pack
-    pphy2phy = inverse(phy2pphy)
-
-    pphy2mlog = phy2mlog.gather(
-        -1, pphy2phy
-    )  # [num_layers * num_nodes, num_log_per_nodes]
-    pphy2mlog = (
-        pphy2mlog.view(num_layers, num_nodes, -1)
-        + torch.arange(
-            0,
-            num_logical_experts,
-            num_logical_experts // num_nodes,
-            device=group_pack_index.device,
-        ).view(1, -1, 1)
-    ).flatten(-2)
-    pphy2log = mlog2log.gather(-1, pphy2mlog)
-    pphyrank = phyrank.gather(-1, pphy2phy).view(num_layers, -1)
-    logcnt = mlogcnt.view(num_layers, -1).gather(-1, log2mlog)
-    return pphy2log, pphyrank, logcnt
-
-
-def rebalance_experts(
-    weight: torch.Tensor,
-    num_replicas: int,
-    num_groups: int,
-    num_nodes: int,
-    num_gpus: int,
-) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-    """
-    Entry point for expert-parallelism load balancer.
-
-    Parameters:
-        weight: [layers, num_logical_experts], the load statistics for all
-            logical experts
-        num_replicas: number of physical experts, must be a multiple of
-            `num_gpus`
-        num_groups: number of expert groups
-        num_nodes: number of server nodes, where the intra-node network
-            (e.g, NVLink) is faster
-        num_gpus: number of GPUs, must be a multiple of `num_nodes`
-
-    Returns:
-        physical_to_logical_map:
-            [layers, num_replicas], the expert index of each replica
-        logical_to_physical_map:
-            [layers, num_logical_experts, X], the replica indices for each
-            expert
-        expert_count:
-            [layers, num_logical_experts], number of physical
-            replicas for each logical expert
-    """
-    num_layers, num_logical_experts = weight.shape
-    weight = weight.float()
-    if num_groups % num_nodes == 0:
-        # use hierarchical load-balance policy
-        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
-            weight, num_replicas, num_groups, num_nodes, num_gpus
-        )
-    else:
-        # use global load-balance policy
-        phy2log, phyrank, logcnt = rebalance_experts_hierarchical(
-            weight, num_replicas, 1, 1, num_gpus
-        )
-    num_redundant_experts = num_replicas - num_logical_experts
-    maxlogcnt = num_redundant_experts + 1
-    log2phy: torch.Tensor = torch.full(
-        (num_layers, num_logical_experts, maxlogcnt),
-        -1,
-        dtype=torch.int64,
-        device=logcnt.device,
-    )
-    log2phy.view(num_layers, -1).scatter_(
-        -1,
-        phy2log * maxlogcnt + phyrank,
-        torch.arange(num_replicas, dtype=torch.int64, device=log2phy.device).expand(
-            num_layers, -1
-        ),
-    )
-    return phy2log, log2phy, logcnt
-
-
-__all__ = ["rebalance_experts"]
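
The new `policy` package turns the rebalancing algorithm into a pluggable interface: `EplbState` looks up `EPLB_POLICIES[eplb_config.policy]` and calls `rebalance_experts` on the selected class. As a rough sketch of that extension point (not part of this change), the snippet below shows what an alternative policy could look like. `RoundRobinEplbPolicy` is a hypothetical name; a real implementation would also have to be registered in `EPLB_POLICIES` and added to the `EPLBPolicyOption` literal in `vllm/config/parallel.py`, so that the consistency assert in `policy/__init__.py` keeps passing.

```python
import torch

from vllm.distributed.eplb.policy.abstract import AbstractEplbPolicy


class RoundRobinEplbPolicy(AbstractEplbPolicy):
    """Toy policy: ignore the observed load and replicate experts round-robin."""

    @classmethod
    def rebalance_experts(
        cls,
        weight: torch.Tensor,
        num_replicas: int,
        num_groups: int,
        num_nodes: int,
        num_ranks: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        num_layers, num_logical = weight.shape
        device = weight.device
        # Physical slot i serves logical expert i % num_logical.
        phy2log = (
            torch.arange(num_replicas, dtype=torch.int64, device=device) % num_logical
        ).expand(num_layers, -1).contiguous()
        # Replica rank of each physical slot within its logical expert.
        phyrank = torch.arange(num_replicas, dtype=torch.int64, device=device) // num_logical
        # Number of physical replicas per logical expert.
        logcnt = torch.zeros(num_layers, num_logical, dtype=torch.int64, device=device)
        logcnt.scatter_add_(-1, phy2log, torch.ones_like(phy2log))
        # Inverse mapping, padded with -1 up to the largest replica count.
        maxcnt = int(logcnt.max())
        log2phy = torch.full(
            (num_layers, num_logical, maxcnt), -1, dtype=torch.int64, device=device
        )
        log2phy.view(num_layers, -1).scatter_(
            -1,
            phy2log * maxcnt + phyrank,
            torch.arange(num_replicas, dtype=torch.int64, device=device).expand(
                num_layers, -1
            ),
        )
        return phy2log, log2phy, logcnt
```

Because it ignores the measured load, this is only useful as a placeholder, but it returns mappings that follow the same shape contract documented on `AbstractEplbPolicy.rebalance_experts`.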