mirror of
https://git.datalinker.icu/vllm-project/vllm.git
synced 2025-12-24 13:25:01 +08:00
Signed-off-by: Che Ruan <cr623@ic.ac.uk> Signed-off-by: mengxingkongzhouhan <117415539+mengxingkongzhouhan@users.noreply.github.com> Signed-off-by: Mercykid-bash <ruanche0218@gmail.com> Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com> Co-authored-by: Che Ruan <cr623@ic.ac.uk> Co-authored-by: mengxingkongzhouhan <117415539+mengxingkongzhouhan@users.noreply.github.com> Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from abc import ABC, abstractmethod
|
|
|
|
import torch
|
|
|
|
|
|
class AbstractEplbPolicy(ABC):
    """Abstract interface for an expert-parallelism load balancing policy."""

    @classmethod
    @abstractmethod
    def rebalance_experts(
        cls,
        weight: torch.Tensor,
        num_replicas: int,
        num_groups: int,
        num_nodes: int,
        num_ranks: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Main entry into the expert-parallelism load balancer.

        This base-class version performs no balancing; it only defines the
        contract and raises `NotImplementedError`.

        Args:
            weight: Load statistics for every logical expert, laid out as
                [layers, num_logical_experts].
            num_replicas: How many physical experts there are. Has to be
                a multiple of `num_ranks`.
            num_groups: How many expert groups there are.
            num_nodes: How many server nodes there are.
            num_ranks: How many ranks there are. Has to be a multiple of
                `num_nodes`.

        Returns:
            A 3-tuple of tensors, in this order:

            physical_to_logical_map: [layers, num_replicas]; for each
                replica, the index of the expert it holds.
            logical_to_physical_map: [layers, num_logical_experts, X];
                for each expert, the indices of its replicas.
            expert_count: [layers, num_logical_experts]; for each logical
                expert, how many physical replicas it has.

        Raises:
            NotImplementedError: Always, in this abstract definition.
        """
        raise NotImplementedError
|